Created
September 7, 2019 00:51
-
-
Save mstksg/bfa4a1bcb0ca883eb1bd08dd4324737a to your computer and use it in GitHub Desktop.
ID3 homework assignment from 2015
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Instructions for running
--
-- If GHC Haskell is installed:
--
--     $ runhaskell id3.hs
--
-- Have fun!
--
-- On some occasions, the 'containers' package might need to be installed;
-- in that case, install it with
--
--     $ cabal install containers
--
module Main where | |
import Data.Function | |
import Data.List | |
import Data.Maybe | |
import Data.Ord | |
import Data.Tree | |
import Prelude | |
import qualified Data.Map as M | |
-- | One observation from the training set: four categorical weather
-- attributes plus the class label the tree learns to predict.
data DataPoint = DP
    { outlook     :: String  -- ^ sky condition, e.g. @"sunny"@
    , temperature :: String  -- ^ temperature band, e.g. @"hot"@
    , humidity    :: String  -- ^ humidity band, e.g. @"high"@
    , wind        :: String  -- ^ wind strength, e.g. @"weak"@
    , toPlay      :: Bool    -- ^ class label: 'True' means "play"
    }
    deriving (Show)
-- | Accessor that reads the value of one attribute out of a 'DataPoint'.
type GetAttribute = DataPoint -> String

-- | A named attribute the tree may split on, paired with the accessor that
-- extracts its value from a data point.
data Attribute = Atr
    { atrName :: String        -- ^ name shown in the rendered tree
    , atrGet  :: GetAttribute  -- ^ how to read this attribute's value
    }

-- | Attributes are compared by name only; the accessor is a function and
-- therefore has no equality of its own.
instance Eq Attribute where
    a == b = atrName a == atrName b
-- | Every attribute the tree is allowed to split on, in declaration order.
attributes :: [Attribute]
attributes = map (uncurry Atr) table
  where
    -- display name paired with the record accessor that reads it
    table =
        [ ("outlook"    , outlook)
        , ("temperature", temperature)
        , ("humidity"   , humidity)
        , ("wind"       , wind)
        ]
-- | Parse one row of raw fields into a 'DataPoint', if possible.
--
-- A row must supply, in order: outlook, temperature, humidity, wind and the
-- class label.  Any fields after the fifth are ignored.  The class label
-- must be @"p"@ (play) or @"n"@ (no play).
--
-- Returns 'Nothing' when the row has fewer than five fields or when the
-- class label is anything other than @"p"@ or @"n"@, so a malformed label
-- is rejected instead of being silently counted as a "no play" example.
toDataPoint :: [String] -> Maybe DataPoint
toDataPoint (o:t:h:w:c:_) = fmap (DP o t h w) (parseClass c)
  where
    parseClass "p" = Just True
    parseClass "n" = Just False
    parseClass _   = Nothing
toDataPoint _ = Nothing
-- | Entropy, in bits, of the 'toPlay' class label over a data set.
--
-- The result is 0 when every point has the same label and 1 when the two
-- labels are evenly split.  An empty data set is defined to have entropy 0;
-- without that case the play fraction would be @0 / 0@ and the result NaN.
dpEntropy :: [DataPoint] -> Double
dpEntropy []  = 0
dpEntropy dps = entropy playFraction
  where
    -- fraction of points labelled "play"
    playFraction = genericLength (filter toPlay dps) / genericLength dps
    entropy p = term p + term (1 - p)
    -- a probability of 0 contributes nothing; evaluating the formula
    -- directly would multiply 0 by an infinite logarithm and give NaN
    term 0 = 0
    term p = negate (logBase 2 p * p)
-- | Expected reduction in class-label entropy obtained by splitting a data
-- set on the given attribute.
--
-- This is the entropy of the whole set minus the size-weighted average
-- entropy of the subsets produced by 'splitOnAtr'.  An empty data set
-- yields a gain of 0; without that case the weighted average would divide
-- by a set size of zero and the result would be NaN.
entropyGain :: [DataPoint] -> Attribute -> Double
entropyGain []  _   = 0
entropyGain dps atr = initialEntropy - expFinalEntropy
  where
    initialEntropy  = dpEntropy dps
    -- each subset's entropy weighted by its share of the data set
    expFinalEntropy = weightedSum / genericLength dps
    weightedSum     = sum [ genericLength vdps * dpEntropy vdps
                          | (_, vdps) <- splitOnAtr atr dps
                          ]
-- | Group a data set by the value each point takes for an attribute.
--
-- The result holds one @(value, points)@ pair per distinct attribute value,
-- ordered by value as produced by 'M.toList'.
splitOnAtr :: Attribute -> [DataPoint] -> [(String, [DataPoint])]
splitOnAtr atr dps = M.toList grouped
  where
    -- every point starts as a singleton list under its attribute value;
    -- lists that share a key are concatenated
    grouped = M.fromListWith (++) [ (atrGet atr dp, [dp]) | dp <- dps ]
-- | Build the decision tree from the attributes still available, a label
-- for this node (the attribute value that led here), and the data points
-- that reached it.
buildTree :: [Attribute] -> String -> [DataPoint] -> Tree String
-- No attributes left to split on: emit a leaf that votes with the majority
-- label, ties going to "Play!".
buildTree [] labl dps = Node (labl ++ ": " ++ verdict) []
  where
    (plays, noPlays) = partition toPlay dps
    verdict
      | length plays >= length noPlays = "Play!"
      | otherwise                      = "No Play"
-- Otherwise split on the attribute with the highest entropy gain and build
-- one subtree per value of that attribute.
buildTree atrs labl dps = Node (labl ++ ": check " ++ atrName chosen) (map grow branches)
  where
    chosen    = maximumBy (compare `on` entropyGain dps) atrs
    remaining = filter (/= chosen) atrs
    branches  = splitOnAtr chosen dps
    grow (value, subset)
      -- every point agrees on the label: jump straight to a leaf
      | unanimous subset = buildTree [] value subset
      -- mixed labels: keep splitting on the remaining attributes
      | otherwise        = buildTree remaining value subset
    unanimous xs = all toPlay xs || not (any toPlay xs)
-- | The raw training rows.  Each row lists outlook, temperature, humidity,
-- wind and the class label (@"p"@ for play, @"n"@ for no play), in that
-- order.  Rows are written as whitespace-separated text and split into
-- fields with 'words'.
dataset :: [[String]]
dataset = map words
    [ "sunny    hot  high   weak   n"
    , "sunny    hot  high   strong n"
    , "overcast hot  high   weak   p"
    , "rain     mild high   weak   p"
    , "rain     cool normal weak   p"
    , "rain     cool normal strong n"
    , "overcast cool normal strong p"
    , "sunny    mild high   weak   n"
    , "sunny    cool normal weak   p"
    , "rain     mild normal weak   p"
    , "sunny    mild normal strong p"
    , "overcast mild high   strong p"
    , "overcast hot  normal weak   p"
    , "rain     mild high   strong n"
    ]
-- | Parse the built-in data set, build the decision tree over all
-- attributes, and print the rendered tree to standard output.
main :: IO ()
main = do
    let points = mapMaybe toDataPoint dataset
        tree   = buildTree attributes "top" points
    putStrLn (drawTree tree)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment