Skip to content

Commit

Permalink
Merge branch 'wavewave-flakes-clang-TH' into wavewave-flakes-clang-2
Browse files Browse the repository at this point in the history
  • Loading branch information
wavewave committed May 31, 2022
2 parents 809966f + 55a3009 commit 60b73bc
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 2 deletions.
4 changes: 4 additions & 0 deletions Categorifier/C/CExpr/File.hs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ data FunctionGenErrorInfo
= BogusInputNodes (Arrays (Compose IntMap (CExprF (Const Int))))
| InvalidGraph GraphFailure

instance Show FunctionGenErrorInfo where
show BogusInputNodes {} = "BogusInputNodes"
show InvalidGraph {} = "GraphFailure"

-- | AssignmentError GenError

data FunctionText ann = FunctionText
Expand Down
10 changes: 8 additions & 2 deletions Categorifier/C/CExpr/Function.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import Categorifier.C.Graph.Reify
)
import Categorifier.C.Prim (Arrays, IsPrimitive)
import Categorifier.C.Recursion (hembed)
import Categorifier.Common.IO.Exception (Exception)
import Control.Monad (unless, (<=<))
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except (ExceptT, except, runExceptT)
Expand Down Expand Up @@ -72,9 +73,14 @@ data FunctionGenError = FunctionGenError
{ functionGenErrorName :: Text,
functionGenErrorInfo :: FunctionGenErrorInfo
}
deriving (Show)

instance Exception FunctionGenError

generateTopLevelFunction' ::
Text -> Set ReadyToGenerate -> Either FunctionGenError (FunctionText ann)
Text ->
Set ReadyToGenerate ->
Either FunctionGenError (FunctionText ann)
generateTopLevelFunction' functionName rtg = do
FunctionText topLevelHeader topLevelSource <-
fmap mconcat . traverse (first (FunctionGenError functionName) . generateFunctionText) $
Expand All @@ -92,7 +98,7 @@ generateTopLevelFunction' functionName rtg = do
foldr
(\inc acc -> acc <> Doc.line <> includeUserHeader inc)
mempty
(Doc.pretty (functionName <> ".h") : funcallHeaders)
funcallHeaders
fullHdr = Doc.vcat [spam, cExprHeaders, wrapWithExternC topLevelHeader]
fullSrc = Doc.vcat [spam, cExprHeaders, userHeaders, topLevelSource]
pure $ FunctionText fullHdr fullSrc
Expand Down
32 changes: 32 additions & 0 deletions examples/categorifier-c-examples.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,35 @@ executable c-wrappers
, categorifier-unconcat-category
, categorifier-unconcat-integration
default-language: Haskell2010

executable th-compile
hs-source-dirs: th-compile
main-is: Main.hs
other-modules:
F
ghc-options:
-O0
-fexpose-all-unfoldings
-fmax-simplifier-iterations=0
-fno-ignore-interface-pragmas
-fno-omit-interface-pragmas
-Wall
-fplugin Categorifier
-fplugin-opt Categorifier:hierarchy:Categorifier.Hierarchy.UnconCat.hierarchy
-fplugin-opt Categorifier:hierarchy:Categorifier.Hierarchy.ConCat.functionHierarchy
-fplugin-opt Categorifier:hierarchy:Categorifier.Hierarchy.ConCatExtensions.hierarchy
build-depends:
base
, ghc-prim
, concat-classes
, categorifier-c
, categorifier-c-test-lib
, categorifier-category
, categorifier-client
, categorifier-concat-extensions-category
, categorifier-concat-extensions-integration
, categorifier-concat-integration
, categorifier-plugin
, categorifier-unconcat-category
, categorifier-unconcat-integration
default-language: Haskell2010
64 changes: 64 additions & 0 deletions examples/th-compile/F.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module F
( Input (..),
Output (..),
wrap_f,
)
where

import qualified Categorifier.C.CExpr.Cat as C
import Categorifier.C.CExpr.Cat.TargetOb (TargetOb)
import Categorifier.C.CTypes.CGeneric (CGeneric)
import qualified Categorifier.C.CTypes.CGeneric as CG
import Categorifier.C.CTypes.GArrays (GArrays)
import Categorifier.C.KTypes.C (C)
import qualified Categorifier.Categorify as Categorify
import Categorifier.Client (deriveHasRep)
import Data.Int (Int32)
import Data.Word (Word64)
import GHC.Generics (Generic)

data Input = Input
{ iInt32 :: C Int32,
iDouble :: C Double
}
deriving (Generic, Show)

deriveHasRep ''Input

instance CGeneric Input

instance GArrays C Input

type instance TargetOb Input = TargetOb (CG.Rep Input ())

data Output = Output
{ oWord64 :: C Word64,
oFloat :: C Float,
oBool :: Bool
}
deriving (Generic, Show)

deriveHasRep ''Output

instance CGeneric Output

instance GArrays C Output

type instance TargetOb Output = TargetOb (CG.Rep Output ())

f :: Input -> Output
f inp =
Output
{ oWord64 = if odd (iInt32 inp) then fromIntegral (iInt32 inp) + 5 else 42,
oFloat = realToFrac $ min (iDouble inp) 3.14,
oBool = iDouble inp > 0
}

$(Categorify.function 'f [t|C.Cat|] [])
16 changes: 16 additions & 0 deletions examples/th-compile/Main.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}

module Main (main) where

import Categorifier.C.Codegen.FFI.TH (embedFunction)
import F (Input (..), Output (..), wrap_f)

$(embedFunction "simple_example" wrap_f)

main :: IO ()
main = do
let input = Input 1 5.0
output <- hs_simple_example input
print output
97 changes: 97 additions & 0 deletions test-lib/Categorifier/C/Codegen/FFI/TH.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-- Data.Text.Prettyprint.Doc.Render.Text is deprecated.
{-# OPTIONS_GHC -fno-warn-deprecations #-}

module Categorifier.C.Codegen.FFI.TH (embedFunction) where

import qualified Categorifier.C.CExpr.Cat as C
import Categorifier.C.CExpr.Cat.TargetOb (TargetOb)
import qualified Categorifier.C.CExpr.File as CExpr (FunctionText (..))
import qualified Categorifier.C.CExpr.IO as CExpr (layoutOptions)
import Categorifier.C.CExpr.Types.Core (CExpr)
import Categorifier.C.Codegen.FFI.ArraysCC (fromArraysCC)
import Categorifier.C.Codegen.FFI.Spec (SBVFunCall)
import Categorifier.C.KTypes.C (C)
import Categorifier.C.KTypes.CExpr.Generate (generateCExprFunction)
import Categorifier.C.PolyVec (PolyVec, pdevectorize, pvectorize, pvlengths)
import Categorifier.C.Prim (ArrayCount, Arrays)
import qualified Categorifier.Common.IO.Exception as Exception
import Control.Monad ((<=<))
import Data.Functor.Compose (Compose (..))
import Data.Proxy (Proxy (..))
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Prettyprint.Doc.Render.Text as Prettyprint
import Data.Typeable (Typeable)
import Data.Vector (Vector)
import Language.Haskell.TH.Syntax
( Body (NormalB),
Callconv (..),
Clause (..),
Dec (ForeignD, FunD, SigD),
Exp (..),
Foreign (ImportF),
ForeignSrcLang (LangC),
Pat (VarP),
Q,
Safety (Safe),
Type (..),
)
import qualified Language.Haskell.TH.Syntax as TH
import qualified Type.Reflection as TR

arraysFun ::
forall i o.
(PolyVec CExpr (TargetOb i), PolyVec CExpr (TargetOb o)) =>
(i `C.Cat` o) ->
Arrays (Compose Vector CExpr) ->
IO (Arrays (Compose Vector CExpr))
arraysFun f =
Exception.throwIOLeft . pvectorize . C.lowerCat f <=< Exception.throwIOLeft . pdevectorize

inputDims :: forall a. PolyVec C a => Proxy a -> Arrays ArrayCount
inputDims = pvlengths (Proxy @C)

getTypeName :: forall t. (Typeable t) => Proxy t -> String
getTypeName p =
let tRep = TR.someTypeRep p -- (Proxy @t)
tCon = TR.someTypeRepTyCon tRep
in TR.tyConName tCon

embedFunction ::
forall i o.
(Typeable i, Typeable o, PolyVec CExpr (TargetOb i), PolyVec CExpr (TargetOb o), PolyVec C i) =>
Text ->
(i `C.Cat` o) ->
Q [Dec]
embedFunction name f = do
-- generate C FFI
let cname = "c_" <> name
cnameName = TH.mkName (T.unpack cname)
codeC <-
TH.runIO $ do
x <- generateCExprFunction name (inputDims $ Proxy @i) (arraysFun f)
case x of
Left err -> Exception.impureThrow err
Right (CExpr.FunctionText _ srcText) ->
pure $ Prettyprint.renderStrict $ CExpr.layoutOptions srcText
TH.addForeignSource LangC (T.unpack codeC)
cfunFfi <-
ForeignD . ImportF CCall Safe (T.unpack name) cnameName <$> [t|SBVFunCall|]
-- generate high-level haskell
let inputTy = ConT (TH.mkName (getTypeName (Proxy @i)))
outputTy = ConT (TH.mkName (getTypeName (Proxy @o)))
funName = TH.mkName (T.unpack ("hs_" <> name))
hsfunSig <-
SigD funName <$> [t|$(pure inputTy) -> IO $(pure outputTy)|]
body <-
[|fromArraysCC (Proxy @($(pure inputTy) -> $(pure outputTy))) $(pure (VarE cnameName)) input|]
let hsfunDef = FunD funName [Clause [VarP (TH.mkName "input")] (NormalB body) []]
--
pure [cfunFfi, hsfunSig, hsfunDef]
2 changes: 2 additions & 0 deletions test-lib/categorifier-c-test-lib.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ library
Categorifier.C.Codegen.FFI.Call
Categorifier.C.Codegen.FFI.Foreign
Categorifier.C.Codegen.FFI.Spec
Categorifier.C.Codegen.FFI.TH
Categorifier.C.KGen.KGen
Categorifier.C.KGen.TH
Categorifier.C.KGenGenerate.FFI.Bindings
Expand Down Expand Up @@ -51,6 +52,7 @@ library
, haskell-src-exts
, hedgehog
, lens
, prettyprinter
, process
, sbv
, serialise
Expand Down

0 comments on commit 60b73bc

Please sign in to comment.