Skip to content

Instantly share code, notes, and snippets.

@pxqr
Created March 3, 2014 14:12
Show Gist options
  • Save pxqr/9325692 to your computer and use it in GitHub Desktop.
Save pxqr/9325692 to your computer and use it in GitHub Desktop.
{-# LANGUAGE ExistentialQuantification, DefaultSignatures #-}
module Control.Parallel.CL.Kernel.Arg where
import Control.Parallel.OpenCL
import Control.Monad.CL
import Foreign
-- | Wraps a kernel argument together with its 'Arg' instance; the
-- dictionary is recovered by pattern matching on 'MkArg'.
--
-- NOTE(review): the type parameter @a@ is not hidden, so a list
-- @[ArgW a]@ still holds arguments of one single type only.
data ArgW a
  = (Arg a) => MkArg a
-- | Types that can be bound as an argument of an OpenCL kernel.
--
-- For 'Storable' types the default method hands the value directly to
-- 'clSetKernelArgSto'.
class Arg a where
  -- | Bind a value to the kernel argument slot with the given index.
  setArg :: CLKernel -> (CLuint, a) -> CL ()
  default setArg :: Storable a => CLKernel -> (CLuint, a) -> CL ()
  setArg kern (slot, val) = liftIO (clSetKernelArgSto kern slot val)
-- | Unwraps the argument and delegates to the wrapped value's own 'Arg'
-- instance, whose dictionary is captured by 'MkArg'.
instance Arg (ArgW a) where
  setArg kern (slot, wrapped) =
    case wrapped of
      MkArg val -> setArg kern (slot, val)
module Control.Parallel.CL.Buffer
( Buffer
, createBuffer, releaseBuffer
, writeBuffer, readBuffer
, useBuffer
, bufferSize
) where
import Control.Parallel.OpenCL
import Control.Monad.CL
import Control.Parallel.CL.Kernel.Arg
import Data.Vector.Storable as V
import Foreign
{-
newtype Buffer = Buffer { getVector :: Vector Int8 }
deriving Arg
-- | Consider vector as buffer of bytes.
fromVector :: Storable a => Vector a -> Buffer
fromVector = Buffer . unsafeCastVector
-- | Convert buffer back to vector.
-- Note that this function is completely unsafe:
-- if the element type stored in the buffer does not match the requested
-- type, the behaviour is undefined.
--
unsafeBufferToVector :: Storable a => Buffer -> Vector a
unsafeBufferToVector = unsafeCastVector . getVector
-}
-- | A count of bytes or elements, depending on context.
type Size = Int

-- | Handle to a device buffer, plus the shape information needed to move
-- data between host and device.
data Buffer = Buffer
  { bufHandle   :: {-# UNPACK #-} !CLMem -- ^ OpenCL memory object.
  , bufLen      :: {-# UNPACK #-} !Size  -- ^ Number of elements.
  , bufElemSize :: {-# UNPACK #-} !Size  -- ^ Bytes per element (mostly for debugging).
  } deriving (Show, Eq, Ord)
-- | A buffer is passed to a kernel as its underlying memory-object handle.
instance Arg Buffer where
  setArg kern (slot, b) = liftIO (clSetKernelArgSto kern slot (bufHandle b))
-- | Total size of the buffer in bytes: element count times element size.
bufferSize :: Buffer -> Size
bufferSize b = bufLen b * bufElemSize b
-- | Size of a host vector in bytes: element size times element count.
vectorSize :: Storable a => Vector a -> Size
vectorSize v = vectorElemSize v * V.length v
-- | Size in bytes of one element of a host vector. The vector is never
-- inspected; it only fixes the element type.
vectorElemSize :: Storable a => Vector a -> Size
vectorElemSize v = sizeOf (elemOf v)
  where
    elemOf :: Vector b -> b
    elemOf _ = undefined
-- | Create a device buffer for @len@ elements of @elemSize@ bytes each in
-- the current context, with no flags and no host pointer.
createBuffer :: Size -> Size -> CL Buffer
createBuffer len elemSize = do
  ctx <- context
  mem <- liftIO $ clCreateBuffer ctx [] (len * elemSize, nullPtr)
  return (Buffer mem len elemSize)
-- | Copy a host vector into an existing device buffer using a blocking
-- write on the current command queue.
--
-- Calls @error "writeBuffer"@ when the byte size of the vector differs from
-- 'bufferSize' of the buffer.
writeBuffer :: Storable a => Vector a -> Buffer -> CL ()
writeBuffer vec buf = do
  queue <- commandQ
  let bytes = vectorSize vec
  if bytes == bufferSize buf
    then return ()
    else error "writeBuffer" -- <ASSERT>: sizes must agree
  _ <- liftIO $ unsafeWith vec $ \src ->
    clEnqueueWriteBuffer queue (bufHandle buf) True 0 bytes (castPtr src) []
  return ()
-- | Create a device buffer shaped like the host vector (same element count
-- and element size) and upload the vector into it.
useBuffer :: Storable a => Vector a -> CL Buffer
useBuffer vec =
  createBuffer (V.length vec) (vectorElemSize vec) >>= \buf ->
    writeBuffer vec buf >> return buf
-- | Download the whole contents of a device buffer into a fresh host vector
-- using a blocking read on the current command queue.
--
-- The host allocation is 'bufferSize' bytes, exactly what the read writes.
-- The element count of the result is derived from the requested element
-- type. A buffer created with the same element size therefore yields
-- 'bufLen' elements. A buffer created with a different element size (e.g.
-- one-byte elements) is reinterpreted rather than over-read.
readBuffer :: Storable a => Buffer -> CL (Vector a)
readBuffer buf = do
  q <- commandQ
  let size = bufferSize buf
  liftIO $ do
    -- Allocate 'size' bytes: the read below copies exactly that many.
    fptr <- mallocForeignPtrBytes size
    _ <- withForeignPtr fptr $ \ptr ->
      clEnqueueReadBuffer q (bufHandle buf) True 0 size (castPtr ptr) []
    return $ V.unsafeFromForeignPtr fptr 0 (elemCount size fptr)
  where
    -- Number of whole elements of the pointer's element type in 'bytes'
    -- (0 for zero-sized element types, to avoid dividing by zero).
    elemCount :: Storable b => Int -> ForeignPtr b -> Int
    elemCount bytes fp
      | width == 0 = 0
      | otherwise  = bytes `div` width
      where
        width = sizeOf (elemOf fp)
    elemOf :: ForeignPtr b -> b
    elemOf _ = undefined
-- | Release the buffer's OpenCL memory object, returning the 'Bool' result
-- of 'clReleaseMemObject'.
releaseBuffer :: Buffer -> CL Bool
releaseBuffer buf = liftIO (clReleaseMemObject (bufHandle buf))
module Control.Parallel.CL
( module Control.Monad.CL
, module Control.Parallel.CL.Machine
) where
import Control.Monad.CL
import Control.Parallel.CL.Machine
module Control.Parallel.CL.Kernel
( Kernel, Source, Grid, Arg(..), ArgW(..)
, createProgram, releaseProgram
, createKernel, setKernel, runKernel , releaseKernel
, kernel
) where
import Control.Monad.CL
import Control.Parallel.OpenCL
import Control.Parallel.CL.Kernel.Arg
import Control.Monad.Trans
-- | OpenCL C source text of a program.
type Source = String
-- | An OpenCL kernel object.
type Kernel = CLKernel
-- | An OpenCL program object.
type Program = CLProgram
-- | Name of a kernel function inside a program.
type Name = String
-- | Create a program from source in the current context and build it for
-- the current device with an empty options string.
--
-- NOTE(review): the build log is not retrieved here; a failed build is
-- reported only however 'clBuildProgram' reports it.
createProgram :: Source -> CL Program
createProgram src = do
  ctx <- context
  dev <- device
  liftIO $ do
    prog <- clCreateProgramWithSource ctx src
    _ <- clBuildProgram prog [dev] ""
    return prog
-- | Release an OpenCL program object.
releaseProgram :: Program -> CL Bool
releaseProgram prog = liftIO (clReleaseProgram prog)
-- | Create a kernel object for the function with the given name in a
-- program.
createKernel :: Program -> Name -> CL Kernel
createKernel prog name = liftIO (clCreateKernel prog name)
-- | Release an OpenCL kernel object.
releaseKernel :: Kernel -> CL Bool
releaseKernel krn = liftIO (clReleaseKernel krn)
-- | Bind the arguments to the kernel in list order: the first element goes
-- to argument index 0, the next to index 1, and so on.
setKernel :: Arg a => Kernel -> [ArgW a] -> CL ()
setKernel krn args =
  sequence_ [ setArg krn (ix, arg) | (ix, arg) <- zip [0 ..] args ]
-- | Work sizes for 'runKernel', one entry per dimension. The first list is
-- passed as the first work-size argument of 'clEnqueueNDRangeKernel' and
-- the second list as the second.
--
-- NOTE(review): OpenCL takes the global work size first and the local
-- (work-group) size second. Despite the @(groups, items)@ naming in
-- 'runKernel', the first list is presumably the total work-item count per
-- dimension rather than the group count. Confirm.
type Grid = ([Int], [Int])
-- | Enqueue the kernel over the grid on the current command queue, then
-- flush the queue. No finish/wait is issued, so the kernel may still be
-- running when this returns.
runKernel :: Kernel -> Grid -> CL ()
runKernel krn (groups, items) = do
  queue <- commandQ
  liftIO $ do
    _ <- clEnqueueNDRangeKernel queue krn groups items []
    clFlush queue
  return ()
-- | One-shot helper. It builds @src@, creates its kernel named @"main"@,
-- binds @args@ and enqueues the kernel over @grid@. It then releases both
-- the kernel and the program it created.
kernel :: Arg a => Source -> Grid -> [ArgW a] -> CL ()
kernel src grid args = do
  prg <- createProgram src
  krn <- createKernel prg "main"
  setKernel krn args
  runKernel krn grid
  _ <- releaseKernel krn
  -- The program was created here, so it must be released here as well;
  -- otherwise every call leaks one program object. Per the OpenCL spec,
  -- objects still used by enqueued commands are kept alive by the runtime.
  _ <- releaseProgram prg
  return ()
--(<.>) :: Arg a => CL (Kernel (a -> b)) -> a -> CL ()
--(<.>) = undefined
--kernel source <.> input <.> output
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
{-# LANGUAGE ExistentialQuantification #-}
module Control.Parallel.CL.Machine
( Instruction(..), Usable(..)
, VProgram(..), Blob
, execProgram, readReg
, evalMachine
) where
import Control.Monad.State
import Control.Monad.CL
import Control.Parallel.CL.Buffer
import Control.Parallel.CL.Kernel
import Control.Parallel.OpenCL
import Data.Foldable as F
import Data.Traversable as T
import Data.IntMap as M
import qualified Data.Vector.Storable as V
import Data.Maybe
import Foreign
-- | Index of a virtual register that holds a device buffer.
type Reg = Int

-- | Any storable host vector, with its element type hidden.
data Usable = forall a. Storable a => MkUsable (V.Vector a)

-- | Placeholder instance: every value is shown as the literal @usable@.
instance Show Usable where
  show _ = "usable"
-- | One step of a virtual program, parameterised over the kernel
-- representation @krn@.
data Instruction krn
  = Use Reg Usable            -- ^ Upload a host vector into a register.
  | Alloc Reg Int             -- ^ Allocate a buffer of the given byte count.
  | Release Reg               -- ^ Free the buffer bound to a register.
  | Kernel krn Grid [Reg] Reg -- ^ Run a kernel: input registers, then output.
  deriving (Show, Functor, Foldable, Traversable)
-- | The register machine: register-to-buffer bindings as state over 'CL'.
type Machine = StateT (IntMap Buffer) CL

-- | Run a machine computation starting with no registers bound. The final
-- bindings are discarded; buffers still bound are not released here.
evalMachine :: Machine a -> CL a
evalMachine = flip evalStateT M.empty
-- | Bind a buffer to a register.
--
-- If the register already holds a different buffer, that buffer is
-- released first. Once it is replaced in the register map nothing in this
-- module keeps a handle to it, so it could otherwise never be freed.
bindReg :: Reg -> Buffer -> Machine ()
bindReg r b = do
  previous <- gets (M.lookup r)
  case previous of
    Just old | old /= b -> lift (releaseBuffer old) >> return ()
    _                   -> return ()
  modify (M.insert r b)
-- | Download the contents of a register's buffer to the host.
-- Calls 'error' when the register is unbound.
readReg :: Storable a => Reg -> Machine (V.Vector a)
readReg reg = do
  mbuf <- gets (M.lookup reg)
  lift (readBuffer (fromMaybe unbound mbuf))
  where
    unbound = error "readReg: register not binded."
-- | Resolve registers to their buffers, preserving order. An unbound
-- register becomes an 'error' thunk that is raised only when that buffer
-- is actually used.
lookupRegs :: [Reg] -> Machine [Buffer]
lookupRegs regs = gets resolveAll
  where
    resolveAll bindings =
      [ fromMaybe unbound (M.lookup r bindings) | r <- regs ]
    unbound = error "execInstruction: register not binded."
-- | Execute a single instruction against the register map.
--
-- * 'Use' uploads a host vector into a new buffer bound to the register.
-- * 'Alloc' creates a buffer of @s@ one-byte elements.
-- * 'Release' frees and unbinds the register's buffer; an unbound register
--   is ignored.
-- * 'Kernel' passes the input buffers followed by the output buffer as the
--   kernel's arguments, then enqueues it over the grid.
execInstruction :: Instruction Kernel -> Machine ()
execInstruction instr =
  case instr of
    Use r (MkUsable vec) -> bindReg r =<< lift (useBuffer vec)
    Alloc r s            -> bindReg r =<< lift (createBuffer s 1)
    Release r            -> gets (M.lookup r) >>= maybe (return ()) (freeReg r)
    Kernel krn grid inputs output -> do
      bufs <- lookupRegs (inputs ++ [output])
      lift $ setKernel krn (fmap MkArg bufs) >> runKernel krn grid
  where
    -- Release the device buffer, then drop the register binding.
    freeReg :: Reg -> Buffer -> Machine ()
    freeReg r buf = do
      _ <- lift (releaseBuffer buf)
      modify (M.delete r)
-- | A straight-line program whose instructions run in list order.
newtype VProgram krn = VProgram
  { getVProgram :: [Instruction krn]
  } deriving (Show, Functor, Foldable, Traversable)
-- | Kernel function name.
type Name = String
-- | OpenCL C source text.
type Src = String
-- | Kernel source generator: given the name to use, produce the source.
type Blob = Name -> Src

-- | Prefix of generated kernel names; an index is appended to it.
kernelPref :: Name
kernelPref = "kernel_pref_"
-- | Number the kernel generators in traversal order (0, 1, ...). Each one
-- is named @kernelPref ++ show index@ and its source is generated under
-- that name, yielding (source, name) pairs.
assignIndices :: VProgram (Name -> Src) -> VProgram (Src, Name)
assignIndices prog = snd (T.mapAccumL step (0 :: Integer) prog)
  where
    step n gen =
      let nm = kernelPref ++ show n
      in (n + 1, (gen nm, nm))
-- | Concatenate all kernel sources into a single program and build it. Then
-- create one kernel object per slot, looked up by its assigned name.
compileVProgram :: VProgram (Src, Name) -> CL (VProgram Kernel)
compileVProgram vprg = do
  prg <- createProgram (F.foldMap fst vprg)
  T.traverse (createKernel prg . snd) vprg
-- | Release every kernel of a compiled program, then the program object
-- they were created from.
--
-- The program handle is queried from the first kernel before any kernel is
-- released. A program with no kernels is a no-op instead of crashing.
--
-- NOTE(review): for an empty program, the program object built by
-- 'compileVProgram' cannot be found from here and is not released.
releaseVProgram :: VProgram Kernel -> CL ()
releaseVProgram vprg =
  case F.toList vprg of
    [] -> return ()
    (firstKrn : _) -> do
      hprg <- liftIO $ clGetKernelProgram firstKrn
      F.mapM_ releaseKernel vprg
      _ <- releaseProgram hprg
      return ()
-- | Compile a program and pass its kernels to @f@. Afterwards the kernels
-- and their program are released, and @f@'s result is returned.
--
-- NOTE(review): not exception-safe; if @f@ throws, nothing is released.
withVProgram :: VProgram Blob -> (VProgram Kernel -> Machine a) -> Machine a
withVProgram prgBlob f =
  lift (compileVProgram (assignIndices prgBlob)) >>= \compiled ->
    f compiled >>= \result ->
      lift (releaseVProgram compiled) >> return result
-- | Compile a program and execute its instructions in order. The compiled
-- kernels are released afterwards.
execProgram :: VProgram Blob -> Machine ()
execProgram vprg = withVProgram vprg $ \compiled ->
  F.mapM_ execInstruction (getVProgram compiled)
name:                ocl
version:             0.1.0.0
-- synopsis:
-- description:
license:             MIT
license-file:        LICENSE
author:              Sam T.
maintainer:          Sam T. <[email protected]>
-- copyright:
category:            Parallel
build-type:          Simple
cabal-version:       >=1.8

library
  -- Control.Parallel.CL.Machine is re-exported by Control.Parallel.CL,
  -- so it must be part of the library.
  exposed-modules:     Control.Monad.CL
                     , Control.Parallel.CL
                     , Control.Parallel.CL.Vector
                     , Control.Parallel.CL.Buffer
                     , Control.Parallel.CL.Kernel
                     , Control.Parallel.CL.Kernel.Arg
                     , Control.Parallel.CL.Machine
  -- other-modules:
  -- OpenCL provides Control.Parallel.OpenCL; vector provides
  -- Data.Vector.Storable. Both are imported by the library modules.
  build-depends:       base == 4.5.*
                     , containers == 0.4.*
                     , mtl == 2.1.*
                     , OpenCL
                     , vector
module Control.Parallel.CL.Vector
( Vector, V
, length
, store, load, release
, unsafeCastVector
) where
import Prelude hiding (length)
import Control.Monad.CL
import Control.Monad.Trans
import Control.Parallel.OpenCL
import Control.Parallel.CL.Kernel.Arg
import qualified Data.Vector.Storable as V
import Foreign
-- | Handle to an OpenCL memory object.
type Mem = CLMem
-- | A count of bytes or elements, depending on context.
type Size = Int
-- | Host-side descriptor of a device memory object whose elements have
-- type /a/.
--
-- /a/ is a phantom type parameter recording the element type; no value of
-- type /a/ is stored here.
data Vector a = Vector
  { bufHandle :: {-# UNPACK #-} !Mem  -- ^ Handle to the device buffer.
  , bufLen    :: {-# UNPACK #-} !Size -- ^ Number of elements in the buffer.
  }
-- | Short alias for 'Vector'.
type V = Vector

-- | A device vector is passed to a kernel as its memory-object handle; the
-- element count is not passed.
instance Arg (Vector a) where
  setArg kern (slot, Vector { bufHandle = mem }) =
    liftIO (clSetKernelArgSto kern slot mem)
-- | Create a device buffer of 'sizeOfHostVec' bytes and enqueue a write of
-- the host vector into it. Whether the write blocks is taken from
-- 'blockingWrite'.
--
-- NOTE(review): when 'blockingWrite' is False the device may read host
-- memory after 'V.unsafeWith' has returned, and nothing here keeps @vec@
-- alive until the write completes. Confirm callers handle this.
load :: Storable a => V.Vector a -> CL (Vector a)
load vec = do
  let bytes = sizeOfHostVec vec
  mem <- createBuffer bytes
  queue <- commandQ
  blocking <- blockingWrite
  _ <- liftIO $ V.unsafeWith vec $ \src ->
    clEnqueueWriteBuffer queue mem blocking 0 bytes (castPtr src) []
  return (Vector mem (V.length vec))
-- | Read a device vector back into a newly allocated host vector. Whether
-- the read blocks is taken from 'blockingRead'.
store :: Storable a => Vector a -> CL (V.Vector a)
store vec = do
  queue <- commandQ
  blocking <- blockingRead
  let bytes = sizeOfDeviceVec vec
  liftIO $ do
    fp <- mallocForeignPtrBytes bytes
    _ <- withForeignPtr fp $ \dst ->
      clEnqueueReadBuffer queue (bufHandle vec) blocking 0 bytes (castPtr dst) []
    return (V.unsafeFromForeignPtr fp 0 (length vec))
-- | Release the device memory object behind a vector.
release :: Vector a -> CL Bool
release v = releaseBuffer (bufHandle v)
--copy :: Vector a -> Vector a -> CL ()
-- | Number of elements in the device vector.
length :: Vector a -> Int
length (Vector _ n) = n
-- | Intended to reinterpret a vector with a different element type,
-- rescaling its length by the ratio of the element sizes.
--
-- NOTE(review): unimplemented. Both element sizes are @error "castVector"@
-- placeholders and 'bufLen' is a strict field, so evaluating the result
-- always fails. A real implementation needs 'Storable' constraints on @a@
-- and @b@.
unsafeCastVector :: Vector a -> Vector b
unsafeCastVector (Vector h n) = Vector h scaled
  where
    scaled   = (n * fromSize) `div` toSize
    fromSize = error "castVector"
    toSize   = error "castVector"
----------------------------------- mem object ---------------------------------
-- | Create a device memory object of @size@ bytes in the current context,
-- with no flags and no host pointer.
createBuffer :: Size -> CL Mem
createBuffer size =
  context >>= \ctx ->
    liftIO (clCreateBuffer ctx [] (size, nullPtr))
-- | Release a device memory object.
releaseBuffer :: Mem -> CL Bool
releaseBuffer = liftIO . clReleaseMemObject
----------------------------------- sizes --------------------------------------
-- | Size in bytes of ONE element of a host or device vector. Only the
-- element type matters; the container value is never inspected.
sizeOfElem :: Storable a => v a -> Size
sizeOfElem container = sizeOf (elemProxy container)
  where
    elemProxy :: f b -> b
    elemProxy _ = undefined
-- | Size of a device vector in bytes: element size times element count.
sizeOfDeviceVec :: Storable a => Vector a -> Int
sizeOfDeviceVec v = sizeOfElem v * length v
-- | Size of a host vector in bytes: element count times element size.
--
-- Uses 'V.length' directly. The third component of 'V.unsafeToForeignPtr'
-- is already the element count, not an end index, so subtracting the
-- offset from it under-counts sliced vectors.
sizeOfHostVec :: Storable a => V.Vector a -> Int
sizeOfHostVec vec = V.length vec * sizeOf (elemOf vec)
  where
    elemOf :: V.Vector b -> b
    elemOf _ = undefined
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment