Skip to content

Instantly share code, notes, and snippets.

@pxqr
Created April 18, 2014 18:52
Show Gist options
  • Save pxqr/11058960 to your computer and use it in GitHub Desktop.
Save pxqr/11058960 to your computer and use it in GitHub Desktop.
gpu k-means
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeOperators #-}
module Data.Clustering.KMeans
( -- * Types
Points
, Centroids
, ClusterIx
-- * Routines
, kmeans
, vq
, loss
-- * Options
, K
, InitialCentroids (..)
, Options (..)
) where
import Control.Monad.ST
import Data.Array.Accelerate as A
import Data.Default
import Data.Vector as V hiding (Vector)
import qualified Data.Vector as V (Vector)
import System.Random.MWC as MWC
-- | Data set: a two-dimensional array with one point per row and one
-- feature per column.
type Points a = Array DIM2 a
-- | Cluster centers: one centroid per row, one feature per column.
type Centroids a = Array DIM2 a
-- | Index of a centroid, i.e. a row number of a 'Centroids' array.
type ClusterIx = Int
-- | Point-to-centroid distance table as produced by 'distances':
-- one row per point, one column per centroid.
type Distances a = Array DIM2 a
-- | List of values paired with an 'Int'.
-- NOTE(review): not referenced by any code visible in this file.
type Codebook a = [(a, Int)]
{-----------------------------------------------------------------------
-- Initialization
-----------------------------------------------------------------------}
-- TODO implement kmeans++
-- TODO implement kmeans||
-- | Number of centroids the k-means algorithm is asked to produce.
type K = Int
-- | Initialization algorithms: how the starting centroids are obtained.
data InitialCentroids a
= CustomCenters (Acc (Centroids a)) -- ^ Use the supplied centroids as they are.
| RandomPoints K -- ^ Use 'K' rows of the data set picked at pseudo-random indices.
| KMeansPlusPlus K -- ^ k-means++ seeding; not implemented yet.
| KMeansPar K -- ^ Presumably k-means|| seeding; not implemented yet.
-- | Defaults to 'RandomPoints' with five centroids.
instance Default (InitialCentroids a) where
def = RandomPoints 5
-- | Build a matrix whose @i@-th row is row @disp ! i@ of @rs@.
--
-- The result has as many rows as @disp@ has elements and as many columns
-- as @rs@.
--
-- Precondition: the length of the index vector must be no more than the
-- number of rows in the matrix.
gatherRows :: (Elt a, IsNum a)
           => Acc (Vector Int) -> Acc (Array DIM2 a) -> Acc (Array DIM2 a)
gatherRows disp rs = A.backpermute target pick rs
  where
    -- One output row per requested index, full width of the source.
    target = lift (Z :. A.size disp :. width)
    width  = indexHead (A.shape rs)

    -- Output cell (r, c) reads source cell (disp ! r, c).
    pick ix = index2 (disp A.! index1 r) c
      where
        Z :. r :. c = unlift ix :: Z :. Exp Int :. Exp Int
-- | Copy a boxed "Data.Vector" vector into a one-dimensional Accelerate
-- array of the same length, element by element.
fromVector :: Elt a => V.Vector a -> Vector a
fromVector v = fromFunction (Z :. V.length v) (\(Z :. i) -> v V.! i)
-- TODO filter duplicates
-- | Draw @k@ pseudo-random, non-negative 'Int's.
--
-- The values are raw draws from the whole non-negative 'Int' range; the
-- caller is expected to reduce them to valid row numbers (see
-- 'initialCenters'). Despite the name, duplicates are /not/ removed yet.
--
-- The generator is created with 'MWC.create', i.e. with a fixed seed, so
-- every call with the same @k@ yields the same vector.
genDistinct :: K -> Vector Int
genDistinct k = fromVector $ runST $ do
    g <- MWC.create -- TODO external seed
    V.replicateM k $ uniformR (0, maxBound) g

-- | Produce the starting centroids for the given strategy.
--
-- * 'CustomCenters': the supplied centroids, unchanged.
--
-- * 'RandomPoints': @k@ rows of the data set at pseudo-random positions.
--   The row numbers are reduced modulo the number of points on the
--   Accelerate side (the row count is not known on the host), so they
--   cover the whole data set and can never run out of bounds. The same
--   row may still be picked more than once.
--
-- * 'KMeansPlusPlus', 'KMeansPar': not implemented; evaluating the result
--   raises an 'error'.
initialCenters :: (Elt a, IsNum a)
               => InitialCentroids a -> Acc (Points a) -> Acc (Centroids a)
initialCenters (CustomCenters cs) _  = cs
initialCenters (RandomPoints k)   ps = gatherRows rowIxs ps
  where
    rowIxs = A.map (`mod` nPoints) (use (genDistinct k))
    Z :. nPoints :. _ = unlift (A.shape ps) :: Z :. Exp Int :. Exp Int
initialCenters (KMeansPlusPlus _) _  =
    error "Data.Clustering.KMeans.initialCenters: KMeansPlusPlus is not implemented"
initialCenters (KMeansPar _)      _  =
    error "Data.Clustering.KMeans.initialCenters: KMeansPar is not implemented"
{-----------------------------------------------------------------------
-- Optimization
-----------------------------------------------------------------------}
-- TODO remove non-updatable centroids
-- TODO remove duplicate centroids
-- | Multiply an embedded scalar by itself.
sqr :: (Elt a, IsNum a) => Exp a -> Exp a
sqr v = v * v
-- | Table of /squared/ Euclidean distances between every point and every
-- centroid.
--
-- Both inputs are replicated to a common three-dimensional shape
-- (points, centroids, features); the per-feature squared differences are
-- then summed along the last dimension, leaving one row per point and one
-- column per centroid.
distances :: (Elt a, IsNum a)
          => Acc (Points a) -> Acc (Centroids a) -> Acc (Distances a)
distances ps cs = A.fold (+) 0 sqDiffs
  where
    sqDiffs = A.zipWith (\p c -> sqr (p - c)) psCube csCube

    -- Points repeated once per centroid, centroids once per point.
    psCube = A.replicate (lift (Z :. All :. nCenters :. All)) ps
    csCube = A.replicate (lift (Z :. nPoints :. All :. All)) cs

    Z :. nCenters :. _ = unlift (A.shape cs) :: Z :. Exp Int :. Exp Int
    Z :. nPoints  :. _ = unlift (A.shape ps) :: Z :. Exp Int :. Exp Int
-- | Tag every element with its position along the last (right-most)
-- dimension of the array, i.e. with the 'indexHead' of its index.
indexOuter :: (Elt a, Shape sh, Slice sh)
           => Acc (Array (sh :. Int) a) -> Acc (Array (sh :. Int) (Int, a))
indexOuter v = A.zipWith (\i x -> lift (i, x)) positions v
  where
    positions = A.generate (A.shape v) indexHead
-- | For each data point (row of the distance table) find the nearest
-- centroid, returned as an (index, distance) pair.
--
-- When two candidates are equally distant, the right-hand operand of the
-- reduction is kept.
nearest :: forall a. (Elt a, IsNum a, IsScalar a)
        => Acc (Distances a) -> Acc (Vector (Int, a))
nearest ds = A.fold1 closer (indexOuter ds)
  where
    closer :: Exp (Int, a) -> Exp (Int, a) -> Exp (Int, a)
    closer l r = A.snd l <* A.snd r ? (l, r)
-- | Index of the nearest centroid for each data point.
nearestIdx :: (Elt a, IsNum a, IsScalar a) => Acc (Distances a) -> Acc (Vector Int)
nearestIdx ds = A.map A.fst (nearest ds)
-- | Distance-table entry of the nearest centroid for each data point.
nearestDistance :: (Elt a, IsNum a, IsScalar a) => Acc (Distances a) -> Acc (Vector a)
nearestDistance ds = A.map A.snd (nearest ds)
-- | Drop the elements of a matrix that fail the predicate and rebuild a
-- matrix of the original width from the survivors.
--
-- The matrix is flattened, filtered element by element and reshaped, so
-- rows are only removed cleanly when the predicate gives the same answer
-- for every element of a row.
filterRow :: Elt a => (Exp a -> Exp Bool) -> Acc (Array DIM2 a) -> Acc (Array DIM2 a)
filterRow f xss = A.reshape (lift (Z :. keptRows :. width)) kept
  where
    width    = indexHead (A.shape xss)
    flat     = A.reshape (index1 (A.size xss)) xss
    kept     = A.filter f flat
    keptRows = A.size kept `div` width
-- | Recompute the centroids from a cluster assignment.
--
-- Arguments: the number of accumulator rows @k@, the data points, and for
-- every point the index of the cluster it belongs to.
--
-- Every feature value is paired with a count of one and scattered into
-- row @idx ! point@ of a zero-initialised @k@-row accumulator, giving
-- per-cluster (sum, count) pairs. Rows whose count is zero (clusters that
-- received no point) are removed, and the remaining sums are divided by
-- their counts.
centroids :: (Elt a, IsNum a, IsFloating a)
          => Exp K -> Acc (Points a) -> Acc (Vector Int) -> Acc (Centroids a)
centroids k ps idx = A.map (A.uncurry (/)) occupied
  where
    occupied = filterRow (\cell -> 0 /=* A.snd cell) totals
    totals   = A.permute addPair blank route (A.map consOne ps)

    -- Accumulator of (sum, count) pairs, all zero to begin with.
    blank    = A.generate (lift (Z :. k :. features))
                          (const (lift (constant 0, constant 0)))
    features = indexHead (A.shape ps)

    -- Feature j of point p is added to feature j of that point's cluster.
    route ix = index2 (idx A.! index1 p) j
      where
        Z :. p :. j = unlift ix :: Z :. Exp Int :. Exp Int
-- | Pair a value with the constant one, ready to be summed as a
-- (sum, count) pair.
consOne :: forall a. (Elt a, IsNum a) => Exp a -> Exp (a, a)
consOne v = lift (v, one)
  where
    one = 1 :: Exp a
-- | Component-wise addition of two embedded pairs.
addPair :: forall a. (Elt a, IsNum a) => Exp (a, a) -> Exp (a, a) -> Exp (a, a)
addPair l r = lift (A.fst l + A.fst r, A.snd l + A.snd r)
-- | One Lloyd step: assign every point to its nearest centroid, then
-- recompute the centroids from that assignment.
--
-- The result may have fewer rows than @cs@, because 'centroids' drops
-- clusters that received no point.
iteration :: (Elt a, IsNum a, IsScalar a, IsFloating a)
          => Acc (Points a) -> Acc (Centroids a) -> Acc (Centroids a)
iteration ps cs = centroids k ps (nearestIdx (distances ps cs))
  where
    -- Centroids are stored one per row, so the cluster count is the first
    -- extent of the shape. 'indexHead' would return the last extent (the
    -- number of features) and size the accumulator in 'centroids' wrongly.
    Z :. k :. _ = unlift (A.shape cs) :: Z :. Exp Int :. Exp Int
-- | Within-cluster sum of squares: the squared distance from every point
-- to its nearest centroid, summed over all points.
loss :: (Elt a, IsNum a) => Acc (Points a) -> Acc (Centroids a) -> Acc (Scalar a)
loss ps = A.sum . nearestDistance . distances ps
------------------------------------------------------------------------
-- TODO multiple runs with increasing K
-- | Settings accepted by 'kmeans'.
data Options a = Options
{ optCentering :: InitialCentroids a -- ^ How the starting centroids are obtained.
, optLimit :: Int -- ^ Max number of Lloyd iterations.
, optThresh :: a -- ^ Minimum distance between iterations.
-- NOTE(review): 'kmeans' as written in this file does not read this
-- field yet (see its "TODO use threshold").
}
-- | Default seeding strategy, at most one hundred iterations and a zero
-- threshold.
instance Num a => Default (Options a) where
    def = Options
        { optCentering = def
        , optLimit     = 100
        , optThresh    = 0
        }
-- | Loop state of 'kmeans': the current step number, the centroids of the
-- previous step and the current centroids.
type KState a = (Scalar Int, Centroids a, Centroids a)

-- | Split the embedded loop state into its three components.
viewState :: Elt a => Acc (KState a)
          -> (Acc (Scalar Int), Acc (Centroids a), Acc (Centroids a))
viewState st = unlift st
-- | Placeholder for a data-whitening step.
--
-- Not implemented: evaluating it raises an 'error' that names this
-- function instead of the anonymous failure of a bare 'undefined'.
whiten :: Acc (Points a) -> Acc (Points a)
whiten = error "Data.Clustering.KMeans.whiten: not implemented"
-- TODO use threshold
-- | Find 'K' cluster centers by minimizing the Euclidean distance
-- between points and centroids.
--
-- Starting from the centroids selected by 'optCentering', 'iteration' is
-- applied repeatedly until the step counter reaches 'optLimit'. The
-- threshold in 'optThresh' is not consulted yet, and the loop does not
-- stop early when the centroids stop changing.
kmeans :: (Elt a, IsNum a, IsScalar a, IsFloating a)
       => Options a -> Acc (Points a) -> Acc (Centroids a)
kmeans Options { optCentering = centering, optLimit = limit } points =
    finalCenters (A.awhile continue advance start)
  where
    -- Step 0; both centroid slots hold the initial centers.
    start = lift (unit (constant 0), seeds, seeds)
      where
        seeds = initialCenters centering points

    -- Bump the counter, shift the current centers into the "previous"
    -- slot and compute the next ones.
    advance st = lift (A.map (+ 1) n, cur, iteration points cur)
      where
        (n, _, cur) = viewState st

    -- Keep going while the step limit has not been reached.
    continue st = unit (the n <* constant limit)
      -- &&* A.not (the (A.and (A.zipWith (==*) p c))))
      where
        (n, _, _) = viewState st

    finalCenters st = cur
      where
        (_, _, cur) = viewState st
-- | Vector quantization: every point is compared against all rows of the
-- centroid array and mapped to the index of the closest one.
vq :: (Elt a, IsNum a, IsScalar a)
   => Acc (Points a) -> Acc (Centroids a) -> Acc (Vector ClusterIx)
vq ps = nearestIdx . distances ps
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment