-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'wavewave-flakes-clang-TH' into wavewave-flakes-clang-2
- Loading branch information
Showing
7 changed files
with
223 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
{-# LANGUAGE DeriveGeneric #-} | ||
{-# LANGUAGE MultiParamTypeClasses #-} | ||
{-# LANGUAGE TemplateHaskell #-} | ||
{-# LANGUAGE TypeFamilies #-} | ||
{-# LANGUAGE TypeOperators #-} | ||
{-# LANGUAGE UndecidableInstances #-} | ||
|
||
module F | ||
( Input (..), | ||
Output (..), | ||
wrap_f, | ||
) | ||
where | ||
|
||
import qualified Categorifier.C.CExpr.Cat as C | ||
import Categorifier.C.CExpr.Cat.TargetOb (TargetOb) | ||
import Categorifier.C.CTypes.CGeneric (CGeneric) | ||
import qualified Categorifier.C.CTypes.CGeneric as CG | ||
import Categorifier.C.CTypes.GArrays (GArrays) | ||
import Categorifier.C.KTypes.C (C) | ||
import qualified Categorifier.Categorify as Categorify | ||
import Categorifier.Client (deriveHasRep) | ||
import Data.Int (Int32) | ||
import Data.Word (Word64) | ||
import GHC.Generics (Generic) | ||
|
||
data Input = Input | ||
{ iInt32 :: C Int32, | ||
iDouble :: C Double | ||
} | ||
deriving (Generic, Show) | ||
|
||
deriveHasRep ''Input | ||
|
||
instance CGeneric Input | ||
|
||
instance GArrays C Input | ||
|
||
type instance TargetOb Input = TargetOb (CG.Rep Input ()) | ||
|
||
data Output = Output | ||
{ oWord64 :: C Word64, | ||
oFloat :: C Float, | ||
oBool :: Bool | ||
} | ||
deriving (Generic, Show) | ||
|
||
deriveHasRep ''Output | ||
|
||
instance CGeneric Output | ||
|
||
instance GArrays C Output | ||
|
||
type instance TargetOb Output = TargetOb (CG.Rep Output ()) | ||
|
||
f :: Input -> Output | ||
f inp = | ||
Output | ||
{ oWord64 = if odd (iInt32 inp) then fromIntegral (iInt32 inp) + 5 else 42, | ||
oFloat = realToFrac $ min (iDouble inp) 3.14, | ||
oBool = iDouble inp > 0 | ||
} | ||
|
||
$(Categorify.function 'f [t|C.Cat|] []) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
{-# LANGUAGE OverloadedStrings #-} | ||
{-# LANGUAGE TemplateHaskell #-} | ||
{-# LANGUAGE TypeApplications #-} | ||
|
||
module Main (main) where | ||
|
||
import Categorifier.C.Codegen.FFI.TH (embedFunction) | ||
import F (Input (..), Output (..), wrap_f) | ||
|
||
$(embedFunction "simple_example" wrap_f) | ||
|
||
main :: IO () | ||
main = do | ||
let input = Input 1 5.0 | ||
output <- hs_simple_example input | ||
print output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
{-# LANGUAGE FlexibleContexts #-} | ||
{-# LANGUAGE OverloadedStrings #-} | ||
{-# LANGUAGE RankNTypes #-} | ||
{-# LANGUAGE ScopedTypeVariables #-} | ||
{-# LANGUAGE TemplateHaskell #-} | ||
{-# LANGUAGE TypeApplications #-} | ||
{-# LANGUAGE TypeOperators #-} | ||
-- Data.Text.Prettyprint.Doc.Render.Text is deprecated. | ||
{-# OPTIONS_GHC -fno-warn-deprecations #-} | ||
|
||
module Categorifier.C.Codegen.FFI.TH (embedFunction) where | ||
|
||
import qualified Categorifier.C.CExpr.Cat as C | ||
import Categorifier.C.CExpr.Cat.TargetOb (TargetOb) | ||
import qualified Categorifier.C.CExpr.File as CExpr (FunctionText (..)) | ||
import qualified Categorifier.C.CExpr.IO as CExpr (layoutOptions) | ||
import Categorifier.C.CExpr.Types.Core (CExpr) | ||
import Categorifier.C.Codegen.FFI.ArraysCC (fromArraysCC) | ||
import Categorifier.C.Codegen.FFI.Spec (SBVFunCall) | ||
import Categorifier.C.KTypes.C (C) | ||
import Categorifier.C.KTypes.CExpr.Generate (generateCExprFunction) | ||
import Categorifier.C.PolyVec (PolyVec, pdevectorize, pvectorize, pvlengths) | ||
import Categorifier.C.Prim (ArrayCount, Arrays) | ||
import qualified Categorifier.Common.IO.Exception as Exception | ||
import Control.Monad ((<=<)) | ||
import Data.Functor.Compose (Compose (..)) | ||
import Data.Proxy (Proxy (..)) | ||
import Data.Text (Text) | ||
import qualified Data.Text as T | ||
import qualified Data.Text.Prettyprint.Doc.Render.Text as Prettyprint | ||
import Data.Typeable (Typeable) | ||
import Data.Vector (Vector) | ||
import Language.Haskell.TH.Syntax | ||
( Body (NormalB), | ||
Callconv (..), | ||
Clause (..), | ||
Dec (ForeignD, FunD, SigD), | ||
Exp (..), | ||
Foreign (ImportF), | ||
ForeignSrcLang (LangC), | ||
Pat (VarP), | ||
Q, | ||
Safety (Safe), | ||
Type (..), | ||
) | ||
import qualified Language.Haskell.TH.Syntax as TH | ||
import qualified Type.Reflection as TR | ||
|
||
arraysFun :: | ||
forall i o. | ||
(PolyVec CExpr (TargetOb i), PolyVec CExpr (TargetOb o)) => | ||
(i `C.Cat` o) -> | ||
Arrays (Compose Vector CExpr) -> | ||
IO (Arrays (Compose Vector CExpr)) | ||
arraysFun f = | ||
Exception.throwIOLeft . pvectorize . C.lowerCat f <=< Exception.throwIOLeft . pdevectorize | ||
|
||
inputDims :: forall a. PolyVec C a => Proxy a -> Arrays ArrayCount | ||
inputDims = pvlengths (Proxy @C) | ||
|
||
getTypeName :: forall t. (Typeable t) => Proxy t -> String | ||
getTypeName p = | ||
let tRep = TR.someTypeRep p -- (Proxy @t) | ||
tCon = TR.someTypeRepTyCon tRep | ||
in TR.tyConName tCon | ||
|
||
embedFunction :: | ||
forall i o. | ||
(Typeable i, Typeable o, PolyVec CExpr (TargetOb i), PolyVec CExpr (TargetOb o), PolyVec C i) => | ||
Text -> | ||
(i `C.Cat` o) -> | ||
Q [Dec] | ||
embedFunction name f = do | ||
-- generate C FFI | ||
let cname = "c_" <> name | ||
cnameName = TH.mkName (T.unpack cname) | ||
codeC <- | ||
TH.runIO $ do | ||
x <- generateCExprFunction name (inputDims $ Proxy @i) (arraysFun f) | ||
case x of | ||
Left err -> Exception.impureThrow err | ||
Right (CExpr.FunctionText _ srcText) -> | ||
pure $ Prettyprint.renderStrict $ CExpr.layoutOptions srcText | ||
TH.addForeignSource LangC (T.unpack codeC) | ||
cfunFfi <- | ||
ForeignD . ImportF CCall Safe (T.unpack name) cnameName <$> [t|SBVFunCall|] | ||
-- generate high-level haskell | ||
let inputTy = ConT (TH.mkName (getTypeName (Proxy @i))) | ||
outputTy = ConT (TH.mkName (getTypeName (Proxy @o))) | ||
funName = TH.mkName (T.unpack ("hs_" <> name)) | ||
hsfunSig <- | ||
SigD funName <$> [t|$(pure inputTy) -> IO $(pure outputTy)|] | ||
body <- | ||
[|fromArraysCC (Proxy @($(pure inputTy) -> $(pure outputTy))) $(pure (VarE cnameName)) input|] | ||
let hsfunDef = FunD funName [Clause [VarP (TH.mkName "input")] (NormalB body) []] | ||
-- | ||
pure [cfunFfi, hsfunSig, hsfunDef] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters