Created
April 18, 2014 18:52
-
-
Save pxqr/11058960 to your computer and use it in GitHub Desktop.
GPU k-means
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE RecordWildCards #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE ViewPatterns #-} | |
{-# LANGUAGE TypeOperators #-} | |
module Data.Clustering.KMeans | |
( -- * Types | |
Points | |
, Centroids | |
, ClusterIx | |
-- * Routines | |
, kmeans | |
, vq | |
, loss | |
-- * Options | |
, K | |
, InitialCentroids (..) | |
, Options (..) | |
) where | |
import Control.Monad.ST | |
import Data.Array.Accelerate as A | |
import Data.Default | |
import Data.Vector as V hiding (Vector) | |
import qualified Data.Vector as V (Vector) | |
import System.Random.MWC as MWC | |
-- | A data set: one point per row, one feature per column.
type Points a = Array DIM2 a

-- | Cluster centers, laid out exactly like 'Points': one centroid per row.
type Centroids a = Points a

-- | Index of a cluster, i.e. of a row of 'Centroids'.
type ClusterIx = Int

-- | Point-to-centroid matrix: row per point, column per centroid.
type Distances a = Array DIM2 a

-- | (value, cluster) pairs.  NOTE(review): not used or exported by this
-- module.
type Codebook a = [(a, Int)]
{----------------------------------------------------------------------- | |
-- Initialization | |
-----------------------------------------------------------------------} | |
-- TODO implement kmeans++ | |
-- TODO implement kmeans|| | |
-- | Number of clusters (centroids) to look for.
type K = Int
-- | Initialization algorithms for the starting centroids.
data InitialCentroids a
  = CustomCenters (Acc (Centroids a))
    -- ^ Use the given centroids as they are.
  | RandomPoints K
    -- ^ Pick 'K' (not necessarily distinct) rows of the data set
    --   pseudo-randomly.
  | KMeansPlusPlus K
    -- ^ k-means++ seeding (not implemented yet).
  | KMeansPar K
    -- ^ k-means|| seeding (not implemented yet).
-- | By default, five points are picked from the data set at random
-- ('RandomPoints' @5@).
instance Default (InitialCentroids a) where
  def = RandomPoints 5
-- | Select rows of a matrix by index: row @u@ of the result is row
-- @disp ! u@ of @rs@.
--
-- The result has as many rows as @disp@ has elements, and as many columns
-- as @rs@.  The same row may be selected more than once.
--
-- Precondition: every element of @disp@ is a valid row index of @rs@,
-- i.e. lies in @[0, rows rs)@.  The length of @disp@ itself is not
-- restricted.
gatherRows :: Elt a
           => Acc (Vector Int) -> Acc (Array DIM2 a) -> Acc (Array DIM2 a)
gatherRows disp rs = A.backpermute sh f rs
  where
    Z :. rows      = unlift (A.shape disp) :: Z :. Exp Int
    Z :. _ :. cols = unlift (A.shape rs)   :: Z :. Exp Int :. Exp Int
    sh = lift (Z :. rows :. cols)
    -- Element (u, v) of the result is read from element (disp ! u, v) of rs.
    f (unlift -> (Z :. (u :: Exp Int) :. (v :: Exp Int))) =
      index2 (disp A.! index1 u) v
-- | Copy a boxed "Data.Vector" vector into a one-dimensional Accelerate
-- array of the same length and element order.
fromVector :: Elt a => V.Vector a -> Vector a
fromVector v = A.fromList (Z :. V.length v) (V.toList v)
-- | Draw @k@ pseudo-random non-negative 'Int's on the host.
--
-- These are not yet row indices.  The number of points is only known
-- inside the 'Acc' computation, so 'initialCenters' reduces them modulo
-- the row count there.  Sampling is with replacement, so the same row can
-- be drawn more than once.
--
-- TODO filter duplicates
-- TODO external seed ('MWC.create' always starts from the same fixed seed)
genDistinct :: K -> Vector Int
genDistinct k = fromVector $ runST $ do
  g <- MWC.create
  V.replicateM k $ uniformR (0, maxBound) g

-- | Choose the starting centroids for 'kmeans'.
--
-- For 'RandomPoints' the data set must contain at least one point.
-- 'KMeansPlusPlus' and 'KMeansPar' are not implemented and raise an error.
initialCenters :: (Elt a, IsNum a)
               => InitialCentroids a -> Acc (Points a) -> Acc (Centroids a)
initialCenters (CustomCenters cs) _  = cs
initialCenters (RandomPoints k)   ps = gatherRows rowIxs ps
  where
    -- Points are the rows, i.e. the outer dimension, of the matrix.
    Z :. nPoints :. _ = unlift (A.shape ps) :: Z :. Exp Int :. Exp Int
    -- Map the host-side random numbers onto valid rows [0, nPoints), so
    -- that any point of the data set can be chosen (not only the first k).
    rowIxs = A.map (`mod` nPoints) (use (genDistinct k))
initialCenters (KMeansPlusPlus _) _ =
  error "Data.Clustering.KMeans.initialCenters: KMeansPlusPlus is not implemented"
initialCenters (KMeansPar _) _ =
  error "Data.Clustering.KMeans.initialCenters: KMeansPar is not implemented"
{----------------------------------------------------------------------- | |
-- Optimization | |
-----------------------------------------------------------------------} | |
-- TODO remove non-updatable centroids | |
-- TODO remove duplicate centroids | |
-- | Square of a scalar expression.
sqr :: (Elt a, IsNum a) => Exp a -> Exp a
sqr = \v -> v * v
-- | Squared Euclidean distance between every point and every centroid.
--
-- Entry @(p, c)@ of the result is the sum over features @f@ of
-- @(ps ! (p, f) - cs ! (c, f))^2@.  Both inputs are first replicated into
-- a @points × centers × features@ cube, which is then reduced along the
-- feature axis.
distances :: (Elt a, IsNum a)
          => Acc (Points a) -> Acc (Centroids a) -> Acc (Distances a)
distances ps cs = A.fold (+) 0 squared
  where
    Z :. nPoints  :. _ = unlift (A.shape ps) :: Z :. Exp Int :. Exp Int
    Z :. nCenters :. _ = unlift (A.shape cs) :: Z :. Exp Int :. Exp Int
    -- points × centers × features: each point repeated for every centroid
    cubeP   = A.replicate (lift (Z :. All :. nCenters :. All)) ps
    -- points × centers × features: each centroid repeated for every point
    cubeC   = A.replicate (lift (Z :. nPoints :. All :. All)) cs
    squared = A.zipWith (\x y -> sqr (x - y)) cubeP cubeC
-- | Pair every element with its index along the innermost dimension.
--
-- Note: despite the name, the tag is 'indexHead' of the element's index,
-- which is the innermost (last) component.  For a @points × centers@
-- distance matrix each distance is therefore tagged with its centroid
-- number.  This is what 'nearest' needs, since it folds along that
-- innermost dimension.
indexOuter :: (Elt a, Shape sh, Slice sh)
           => Acc (Array (sh :. Int) a) -> Acc (Array (sh :. Int) (Int, a))
indexOuter v = A.zipWith (\ix x -> lift (ix, x)) innerIx v
  where
    innerIx = A.generate (A.shape v) indexHead
-- | For every point (row of the distance matrix), the
-- (centroid index, distance) pair with the smallest distance.
--
-- On equal distances the comparison keeps its right operand, so which of
-- several equally near centroids wins depends on the reduction order.
nearest :: forall a. (Elt a, IsNum a, IsScalar a)
        => Acc (Distances a) -> Acc (Vector (Int, a))
nearest ds = A.fold1 closer (indexOuter ds)
  where
    closer l r = A.snd l <* A.snd r ? (l, r)
-- | Index of the nearest centroid for every point.
nearestIdx :: (Elt a, IsNum a, IsScalar a) => Acc (Distances a) -> Acc (Vector Int)
nearestIdx ds = A.map A.fst (nearest ds)
-- | Distance from every point to its nearest centroid, in the units of
-- the input matrix (squared distances when fed from 'distances').
nearestDistance :: (Elt a, IsNum a, IsScalar a) => Acc (Distances a) -> Acc (Vector a)
nearestDistance ds = A.map A.snd (nearest ds)
-- | Keep the elements of a matrix that satisfy a predicate, and reshape
-- the survivors into rows with the original number of columns.
--
-- Elements are filtered one by one after flattening the matrix.  This
-- therefore only drops whole rows when the predicate gives the same
-- answer for every element of a row.  The caller ('centroids') satisfies
-- this, because every element of a cluster's row carries that cluster's
-- point count.
filterRow :: Elt a => (Exp a -> Exp Bool) -> Acc (Array DIM2 a) -> Acc (Array DIM2 a)
filterRow f xss = A.reshape (lift (Z :. nRows :. nCols)) kept
  where
    Z :. _ :. nCols = unlift (A.shape xss) :: Z :. Exp Int :. Exp Int
    flat  = A.reshape (index1 (A.size xss)) xss
    kept  = A.filter f flat
    nRows = A.size kept `div` nCols
-- | Mean of the points assigned to each cluster.
--
-- @k@ is the number of clusters and @idx ! p@ the cluster of point @p@
-- (assumed to lie in @[0, k)@).
--
-- * Every feature of every point, paired with a count of one, is
--   scattered into a @k × features@ accumulator with 'A.permute'.
-- * Rows whose count is 0 (clusters that received no points) are removed.
-- * The remaining sums are divided by their counts.
--
-- The result can therefore have fewer than @k@ rows.
centroids :: (Elt a, IsNum a, IsFloating a)
          => Exp K -> Acc (Points a) -> Acc (Vector Int) -> Acc (Centroids a)
centroids k ps idx = A.map (A.uncurry (/)) nonEmpty
  where
    Z :. _ :. features = unlift (A.shape ps) :: Z :. Exp Int :. Exp Int
    -- (sum, count) accumulator, one row per cluster
    zeros = A.generate (lift (Z :. k :. features))
                       (const (lift (constant 0, constant 0)))
    -- feature v of point u goes to feature v of cluster (idx ! u)
    toCluster (unlift -> (Z :. (u :: Exp Int) :. (v :: Exp Int))) =
      index2 (idx A.! index1 u) v
    sums     = A.permute addPair zeros toCluster (A.map consOne ps)
    nonEmpty = filterRow ((/=*) 0 . A.snd) sums
-- | Pair a value with a count of one, so that summing such pairs yields
-- both a total and the number of summands.
consOne :: forall a. (Elt a, IsNum a) => Exp a -> Exp (a, a)
consOne x = lift (x, one)
  where
    one = 1 :: Exp a
-- | Component-wise sum of two (total, count) pairs.
addPair :: forall a. (Elt a, IsNum a) => Exp (a, a) -> Exp (a, a) -> Exp (a, a)
addPair l r = lift (A.fst l + A.fst r, A.snd l + A.snd r)
-- | One Lloyd step:
--
-- * assign every point to its nearest centroid;
-- * move each centroid to the mean of the points assigned to it.
--
-- Centroids that receive no points are dropped by 'centroids', so the
-- result may have fewer rows than @cs@.
iteration :: (Elt a, IsNum a, IsScalar a, IsFloating a)
          => Acc (Points a) -> Acc (Centroids a) -> Acc (Centroids a)
iteration ps cs = centroids k ps (nearestIdx (distances ps cs))
  where
    -- The number of clusters is the number of rows (outer dimension) of
    -- cs.  'indexHead' would give the innermost dimension instead, i.e.
    -- the number of features.  That sizes the accumulator in 'centroids'
    -- wrongly and sends cluster indices out of range whenever
    -- k /= features.
    Z :. k :. _ = unlift (A.shape cs) :: Z :. Exp Int :. Exp Int
-- | Within-cluster sum of squares: the total, over all points, of the
-- squared Euclidean distance to the nearest centroid.
loss :: (Elt a, IsNum a) => Acc (Points a) -> Acc (Centroids a) -> Acc (Scalar a)
loss ps cs = A.sum nearestSq
  where
    nearestSq = nearestDistance (distances ps cs)
------------------------------------------------------------------------ | |
-- TODO multiple runs with increasing K | |
-- | Tuning parameters for 'kmeans'.
data Options a = Options
  { optCentering :: InitialCentroids a -- ^ How the starting centroids are chosen.
  , optLimit     :: Int                -- ^ Max number of Lloyd iterations.
  , optThresh    :: a                  -- ^ Minimum movement of the centroids
                                       --   between iterations (convergence
                                       --   threshold).
  }

-- | Default centering ('RandomPoints' @5@), at most 100 iterations, and a
-- threshold of 0.
instance Num a => Default (Options a) where
  def = Options { optCentering = def, optLimit = 100, optThresh = 0 }
-- | Loop state of 'kmeans', in order:
--
-- * the number of Lloyd steps taken so far;
-- * the centroids before the latest step;
-- * the current centroids.
type KState a = (Scalar Int, Centroids a, Centroids a)

-- | Split the embedded loop state into its three components.
viewState :: Elt a => Acc (KState a)
          -> (Acc (Scalar Int), Acc (Centroids a), Acc (Centroids a))
viewState st = unlift st
-- | Placeholder for feature whitening.  Presumably it is meant to rescale
-- each feature column to unit variance, as scipy's @whiten@ does (TODO
-- confirm).  Not implemented and not exported.  It fails with a
-- descriptive message instead of a bare 'undefined'.
whiten :: Acc (Points a) -> Acc (Points a)
whiten = error "Data.Clustering.KMeans.whiten: not implemented"
-- | Find 'K' cluster centers by minimizing the Euclidean distance
-- between points and centroids (Lloyd's algorithm).
--
-- Iteration stops after 'optLimit' Lloyd steps, or earlier once the
-- centroids have converged.  Converged means that both of these hold:
--
-- * the number of centroids did not change in the last step;
-- * the Euclidean distance between the previous and current centroid
--   matrices, viewed as flat vectors, is at most 'optThresh'.
--
-- With the default threshold of @0@ this only stops once a step leaves
-- the centroids unchanged.  After that, further steps would recompute
-- the same centroids.
kmeans :: forall a. (Elt a, IsNum a, IsScalar a, IsFloating a)
       => Options a -> Acc (Points a) -> Acc (Centroids a)
kmeans Options {..} points = result (A.awhile needContinue step initialState)
  where
    initialState = let cs = initialCenters optCentering points
                   in lift (unit (constant 0), cs, cs)

    step (viewState -> (i, _, c)) = lift (A.map (+1) i, c, iteration points c)

    -- The initial state has previous == current, so the convergence test
    -- is skipped on the very first check (i == 0).
    needContinue (viewState -> (i, p, c)) =
      unit (the i <* constant optLimit &&* (the i ==* 0 ||* moving p c))

    -- 'centroids' may drop empty clusters, which changes the number of
    -- rows.  A changed size always counts as "still moving".
    moving :: Acc (Centroids a) -> Acc (Centroids a) -> Exp Bool
    moving p c = A.size p /=* A.size c ||* the (movement p c) >* thresh2

    -- Sum of squared differences between corresponding centroid entries.
    movement :: Acc (Centroids a) -> Acc (Centroids a) -> Acc (Scalar a)
    movement p c = A.sum (A.sum (A.map sqr (A.zipWith (-) p c)))

    thresh2 :: Exp a
    thresh2 = sqr (constant optThresh)

    result (viewState -> (_, _, c)) = c
-- | Vector quantization: compare every point with all centroids and
-- return, per point, the index (row) of the closest centroid.
vq :: (Elt a, IsNum a, IsScalar a)
   => Acc (Points a) -> Acc (Centroids a) -> Acc (Vector ClusterIx)
vq ps cs = nearestIdx dists
  where
    dists = distances ps cs
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment