Example of a Decision Tree Built with the ID3 Algorithm
Created
August 29, 2017 09:05
-
-
Save bonaxcrimo/f690cc471ab1e0f9419fc900c64835cc to your computer and use it in GitHub Desktop.
Decision Tree ID3 Algorithm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!-- Page shell: the tree chart is drawn first, the result tables follow. -->
<div id="main" class="container">
  <h1>ID3 Algorithm</h1>
  <div id="data-container">
    <!-- Target element handed to drawGraph() by the page script -->
    <div id="canvas"></div>
    <div>
      <h3>Sample Predictions</h3>
      <!-- Rows are appended by renderSamples() -->
      <table id="samples" class="table"></table>
      <h3>Training Data</h3>
      <!-- Rows are appended by renderTrainingData() -->
      <table id="training" class="table"></table>
    </div>
  </div>
</div>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Example training data and hold-out samples for the ID3 demo.
// Data set source:
// http://www.cise.ufl.edu/~ddd/cap6635/Fall-97/Short-papers/2.htm

// Training rows, wrapped with Underscore so that pluck/filter/size/each
// can be called directly on the collection.
var examples = _([
  {day: 'D1', outlook: 'Sunny', temp: 'Hot', humidity: 'High', wind: 'Weak', play: 'No'},
  {day: 'D2', outlook: 'Sunny', temp: 'Hot', humidity: 'High', wind: 'Strong', play: 'No'},
  {day: 'D3', outlook: 'Overcast', temp: 'Hot', humidity: 'High', wind: 'Weak', play: 'Yes'},
  {day: 'D4', outlook: 'Rain', temp: 'Mild', humidity: 'High', wind: 'Weak', play: 'Yes'},
  {day: 'D5', outlook: 'Rain', temp: 'Cool', humidity: 'Normal', wind: 'Weak', play: 'Yes'},
  {day: 'D6', outlook: 'Rain', temp: 'Cool', humidity: 'Normal', wind: 'Strong', play: 'No'},
  {day: 'D7', outlook: 'Overcast', temp: 'Cool', humidity: 'Normal', wind: 'Strong', play: 'Yes'},
  {day: 'D8', outlook: 'Sunny', temp: 'Mild', humidity: 'High', wind: 'Weak', play: 'No'},
  {day: 'D9', outlook: 'Sunny', temp: 'Cool', humidity: 'Normal', wind: 'Weak', play: 'Yes'},
  {day: 'D10', outlook: 'Rain', temp: 'Mild', humidity: 'Normal', wind: 'Weak', play: 'Yes'},
  {day: 'D11', outlook: 'Sunny', temp: 'Mild', humidity: 'Normal', wind: 'Strong', play: 'Yes'},
  {day: 'D12', outlook: 'Overcast', temp: 'Mild', humidity: 'High', wind: 'Strong', play: 'Yes'},
  {day: 'D13', outlook: 'Overcast', temp: 'Hot', humidity: 'Normal', wind: 'Weak', play: 'Yes'},
  {day: 'D14', outlook: 'Rain', temp: 'Mild', humidity: 'High', wind: 'Strong', play: 'No'}
]);

// Attributes the tree is allowed to split on ('day' and 'play' are excluded).
var features = ['outlook', 'temp', 'humidity', 'wind'];

// Rows to classify; 'play' holds the known answer shown next to the prediction.
var samples = [
  {outlook: 'Overcast', temp: 'Mild', humidity: 'High', wind: 'Strong', play: 'Yes'},
  {outlook: 'Rain', temp: 'Mild', humidity: 'High', wind: 'Strong', play: 'No'},
  {outlook: 'Sunny', temp: 'Cool', humidity: 'Normal', wind: 'Weak', play: 'Yes'}
];
//ID3 Decision Tree Algorithm | |
//main algorithm and prediction functions | |
/**
 * Recursively builds a decision tree using the ID3 algorithm.
 *
 * @param {Object} _s Underscore-wrapped collection of training rows
 *   (pluck and filter are called on it directly).
 * @param {string} target Key of the property to predict.
 * @param {string[]} features Keys that may still be used for splitting.
 * @returns {Object} Either a leaf
 *   {type:"result", val, name, alias} or a split node
 *   {type:"feature", name, alias, vals:[{type:"feature_value", name, alias, child}]}.
 *   Every alias is the node name with a tag from randomTag() appended.
 */
var id3 = function(_s,target,features){
  var labels = _s.pluck(target);
  var targets = _.unique(labels);

  // All rows agree on the label: emit a leaf.
  if (targets.length === 1){
    console.log("end node! "+targets[0]);
    return {type:"result", val: targets[0], name: targets[0],alias:targets[0]+randomTag() };
  }

  // Nothing left to split on: fall back to the most frequent label.
  if (features.length === 0){
    console.log("returning the most dominate feature!!!");
    var topTarget = mostCommon(labels);
    return {type:"result", val: topTarget, name: topTarget, alias: topTarget+randomTag()};
  }

  var bestFeature = maxGain(_s,target,features);
  var remainingFeatures = _.without(features,bestFeature);
  var possibleValues = _.unique(_s.pluck(bestFeature));
  console.log("node for "+bestFeature);

  var node = {name: bestFeature,alias: bestFeature+randomTag()};
  node.type = "feature";
  node.vals = _.map(possibleValues,function(v){
    console.log("creating a branch for "+v);
    // Strict comparison keeps the partition consistent with _.unique above
    // and with gain(), which both distinguish values strictly; a loose
    // comparison would put e.g. 1 and '1' rows into both branches.
    var _newS = _(_s.filter(function(x) {return x[bestFeature] === v;}));
    var child_node = {name:v,alias:v+randomTag(),type: "feature_value"};
    child_node.child = id3(_newS,target,remainingFeatures);
    return child_node;
  });
  return node;
};
/**
 * Classifies a sample by walking a tree produced by id3().
 *
 * @param {Object} id3Model Root node of the tree.
 * @param {Object} sample Row whose properties are keyed by feature name.
 * @returns {*} The `val` of the leaf that is reached, or `undefined` when the
 *   sample carries a feature value for which the tree has no branch.
 */
var predict = function(id3Model,sample) {
  var node = id3Model;
  while (node.type !== "result") {
    var sampleVal = sample[node.name];
    var branch = null;
    for (var i = 0; i < node.vals.length; i++) {
      // Loose comparison is kept on purpose so that samples which matched a
      // branch before (e.g. 1 against '1') still match.
      if (node.vals[i].name == sampleVal) {
        branch = node.vals[i];
        break;
      }
    }
    if (branch === null) {
      // Value never seen during training: there is no branch to follow, so
      // report "no prediction" instead of throwing a TypeError.
      return undefined;
    }
    node = branch.child;
  }
  return node.val;
};
//necessary math functions | |
/**
 * Shannon entropy (base 2) of a list of values.
 *
 * @param {Array} vals Observed values.
 * @returns {number} Sum over each distinct value of -p * log2(p), where p is
 *   that value's relative frequency in `vals`; 0 for an empty list.
 */
var entropy = function(vals){
  var distinct = _.unique(vals);
  var total = 0;
  for (var i = 0; i < distinct.length; i++) {
    var p = prob(distinct[i],vals);
    total = total + (-p*log2(p));
  }
  return total;
};
/**
 * Information gain obtained by splitting the set on one feature.
 *
 * @param {Object} _s Underscore-wrapped collection of rows.
 * @param {string} target Key of the label property.
 * @param {string} feature Key of the property to split on.
 * @returns {number} Entropy of the labels of the whole set minus the
 *   size-weighted entropies of the labels in each per-value subset.
 */
var gain = function(_s,target,feature){
  var rowCount = _s.size();
  var baseEntropy = entropy(_s.pluck(target));
  var distinctValues = _.unique(_s.pluck(feature));

  // Weighted label entropy of the rows that carry one particular value.
  var weightedEntropyFor = function(value){
    var rows = _s.filter(function(row){ return row[feature] === value; });
    var labels = _.pluck(rows,target);
    return (rows.length/rowCount)*entropy(labels);
  };

  var remainder = 0;
  for (var i = 0; i < distinctValues.length; i++) {
    remainder = remainder + weightedEntropyFor(distinctValues[i]);
  }
  return baseEntropy - remainder;
};
/**
 * Picks the feature with the highest information gain.
 *
 * @param {Object} _s Underscore-wrapped collection of rows.
 * @param {string} target Key of the label property.
 * @param {string[]} features Candidate feature keys.
 * @returns {string} The element of `features` for which gain() is largest,
 *   as selected by _.max.
 */
var maxGain = function(_s,target,features){
  var scoreOf = function(candidate){
    return gain(_s,target,candidate);
  };
  return _.max(features,scoreOf);
};
/**
 * Relative frequency of a value within a list.
 *
 * @param {*} val Value to look for (compared with ===).
 * @param {Array} vals List to search.
 * @returns {number} Occurrences of `val` divided by the list length;
 *   0 for an empty list (instead of the NaN that 0/0 would produce).
 */
var prob = function(val,vals){
  var total = vals.length;
  if (total === 0) {
    return 0;
  }
  var instances = 0;
  for (var i = 0; i < total; i++) {
    if (vals[i] === val) {
      instances++;
    }
  }
  return instances/total;
};
/**
 * Base-2 logarithm, computed as the ratio of natural logarithms.
 *
 * @param {number} n Input value.
 * @returns {number} Math.log(n) divided by Math.log(2).
 */
var log2 = function(n){
  var naturalLog = Math.log(n);
  var naturalLogOfTwo = Math.log(2);
  return naturalLog/naturalLogOfTwo;
};
/**
 * Finds the most frequent element of a list.
 *
 * @param {Array} l List of values.
 * @returns {*} The element that ends up last after sorting the list by
 *   ascending occurrence count; `undefined` for an empty list.
 */
var mostCommon = function(l){
  var byFrequency = _.sortBy(l,function(item){
    return count(item,l);
  });
  // Ascending order, so the winner is the final element.
  return byFrequency[byFrequency.length - 1];
};
/**
 * Counts how many elements of a list are strictly equal to a value.
 *
 * @param {*} a Value to count.
 * @param {Array} l List to scan.
 * @returns {number} Number of elements `b` in `l` with `b === a`.
 */
var count = function(a,l){
  return _.reduce(l,function(matches,b){
    return b === a ? matches + 1 : matches;
  },0);
};
/**
 * Produces a suffix that makes a node alias unique.
 *
 * The aliases built from these tags are used as node identifiers when the
 * tree is turned into chart rows, so two nodes with the same name must never
 * receive the same tag. A random number could repeat; a monotonically
 * increasing counter cannot. The "_r<digits>" shape is unchanged, so the
 * /_r[0-9]+/ clean-up applied to the rendered chart still matches.
 *
 * @returns {string} "_r" followed by a decimal number, different on every call.
 */
var randomTag = (function(){
  var issued = 0;
  return function(){
    issued += 1;
    return "_r"+issued.toString();
  };
})();
//Display logic | |
/**
 * Renders a tree produced by id3() as a Google OrgChart.
 *
 * @param {Object} id3Model Root node of the tree.
 * @param {string} divId Id of the DOM element that receives the chart.
 */
var drawGraph = function(id3Model,divId){
  // Rows are [nodeId, parentId, tooltip]; addEdges appends them depth-first.
  var g = addEdges(id3Model,[]).reverse();
  if (g.length === 0) {
    // A model that is a single leaf has no edges; emit it as a lone root row
    // so the chart still has something to show.
    g.push([id3Model.alias,'','']);
  }
  // Kept from the original: the row list is also published on window.
  window.g = g;

  // firstRowIsData=true: the array has no header row. Without the flag the
  // first edge is consumed as column labels, which the previous version
  // worked around by passing every row twice.
  var data = google.visualization.arrayToDataTable(g,true);
  var chart = new google.visualization.OrgChart(document.getElementById(divId));

  // Once drawn, strip the uniqueness tag ("_r<digits>") from each node label.
  google.visualization.events.addListener(chart, 'ready',function(){
    _.each($('.google-visualization-orgchart-node'),function(x){
      var oldVal = $(x).html();
      if(oldVal){
        var cleanVal = oldVal.replace(/_r[0-9]+/,'');
        $(x).html(cleanVal);
      }
    });
  });
  chart.draw(data, {allowHtml: true});
};
/**
 * Flattens a tree into [childAlias, parentAlias, ''] rows, depth-first.
 *
 * @param {Object} node Tree node of type 'feature', 'feature_value' or 'result'.
 * @param {Array} g Row list to append to; it is mutated in place.
 * @returns {Array} The same row list `g`.
 */
var addEdges = function(node,g){
  var kind = node.type;

  if (kind == 'feature_value') {
    // A value node has exactly one child: the sub-tree or leaf it leads to.
    var below = node.child;
    g.push([below.alias,node.alias,'']);
    if (below.type != 'result') {
      g = addEdges(below,g);
    }
  } else if (kind == 'feature') {
    // A split node has one value node per observed value.
    _.each(node.vals,function(branch){
      g.push([branch.alias,node.alias,'']);
      g = addEdges(branch,g);
    });
  }

  return g;
};
/**
 * Appends one table row per sample: its feature values, the predicted label
 * in bold, and the actual label.
 *
 * @param {Array} samples Rows to classify.
 * @param {Object} $el jQuery object of the table to append to.
 * @param {Object} model Tree produced by id3().
 * @param {string} target Key of the actual label in each sample.
 * @param {string[]} features Keys of the columns to print, in order.
 */
var renderSamples = function(samples,$el,model,target,features){
  _.each(samples,function(sample){
    var cells = _.map(features,function(key){ return sample[key]; }).join('</td><td>');
    var predicted = predict(model,sample);
    var rowHtml = "<tr><td>"+cells+"</td><td><b>"+predicted+"</b></td><td>actual: "+sample[target]+"</td></tr>";
    $el.append(rowHtml);
  });
};
/**
 * Appends one table row per training example: its feature values followed
 * by its label.
 *
 * @param {Object} _training Underscore-wrapped collection of training rows.
 * @param {Object} $el jQuery object of the table to append to.
 * @param {string} target Key of the label property.
 * @param {string[]} features Keys of the columns to print, in order.
 */
var renderTrainingData = function(_training,$el,target,features){
  _training.each(function(example){
    var cells = _.map(features,function(key){ return example[key]; }).join('</td><td>');
    var rowHtml = "<tr><td>"+cells+"</td><td>"+example[target]+"</td></tr>";
    $el.append(rowHtml);
  });
};
/**
 * Scores a model against labelled samples.
 *
 * Note: despite the name, the value returned is the fraction of samples
 * predicted correctly (correct / total), not the error rate.
 *
 * @param {Array} samples Rows that carry the actual label under `target`.
 * @param {Object} model Tree produced by id3().
 * @param {string} target Key of the actual label.
 * @returns {number} correct / total; NaN when no samples are given (0/0).
 */
var calcError = function(samples,model,target){
  var tally = _.reduce(samples,function(acc,sample){
    acc.total += 1;
    // Loose comparison, as in the original scoring.
    if (predict(model,sample) == sample[target]) {
      acc.correct += 1;
    }
    return acc;
  },{total: 0, correct: 0});
  return tally.correct/tally.total;
};
// Page bootstrap: train on the example data, draw the tree, fill both tables
// and log the score of the model on the samples.
$(document).ready(function(){
  var targetKey = 'play';
  var tree = id3(examples,targetKey,features);

  drawGraph(tree,'canvas');
  renderSamples(samples,$("#samples"),tree,targetKey,features);
  renderTrainingData(examples,$("#training"),targetKey,features);

  // The label says "error", but calcError returns correct/total.
  console.log("error");
  console.log(calcError(samples,tree,targetKey));
});
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.2.1/jquery.min.js"></script> | |
<script src="https://cdnjs.cloudflare.com/ajax/libs/underscore.js/1.3.3/underscore-min.js"></script> | |
<script src="https://d3js.org/d3.v2.js"></script> | |
<script src='https://www.google.com/jsapi?autoload={"modules":[{"name":"visualization","version":"1","packages":["orgchart"]}]}'></script>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<link href="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/3.3.7/css/bootstrap.min.css" rel="stylesheet" /> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment