Last active
February 10, 2017 00:06
-
-
Save hyponymous/32b23a355904e7a25a78205d34ded1c8 to your computer and use it in GitHub Desktop.
Gridworld I — SARSA(λ)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* global d3, _ */
// Tunable parameters for the simulation loop and the learning update.
const CONFIG = {
  // duration handed to d3 transitions when bars and circles are updated
  transitionDuration: 0,
  // delay passed to setTimeout between simulation steps
  stepDelay: 10,
  // learning rate
  alpha: 0.1,
  // deviation from greediness
  epsilon: 1e-2,
  // reward discount factor
  gamma: 0.9,
  // eligibility discount factor
  lambda: 0.9,
};

// Movement deltas as [dx, dy], indexed by action: up, right, down, left
// (y grows downwards on screen).
const ACTIONS = [
  [0, -1],
  [1, 0],
  [0, 1],
  [-1, 0],
];
/**
 * Adds up the entries of an array.
 *
 * @param {number[]} arr - values to add.
 * @returns {number} the total; 0 for an empty array.
 */
function sum(arr) {
  let total = 0;
  for (const value of arr) {
    total += value;
  }
  return total;
}
/**
 * Finds the largest entry of an array.
 *
 * Folds with Math.max rather than spreading the array into a single call, so
 * the array length is not limited by the engine's argument count.
 *
 * @param {number[]} arr - values to scan.
 * @returns {number} the largest value; Number.NEGATIVE_INFINITY for an empty
 *   array.
 */
function max(arr) {
  let largest = Number.NEGATIVE_INFINITY;
  for (const value of arr) {
    largest = Math.max(largest, value);
  }
  return largest;
}
/**
 * Picks an index into `weights` epsilon-greedily.
 *
 * With probability `epsilon` every index is a candidate; otherwise only the
 * indices holding the largest weight are. One candidate is then drawn with
 * lodash's `_.sample`, so ties between equally large weights are broken at
 * random.
 *
 * @param {number[]} weights - one weight per choice.
 * @param {number} epsilon - probability of ignoring the weights.
 * @returns {number|undefined} the chosen index; whatever `_.sample` yields
 *   when there is no candidate.
 */
function eGreedy(weights, epsilon) {
  const allIndices = _.range(weights.length);
  if (Math.random() < epsilon) {
    return _.sample(allIndices);
  }
  const largest = max(weights);
  const bestIndices = allIndices.filter((i) => weights[i] === largest);
  return _.sample(bestIndices);
}
/**
 * Creates the SVG canvas that the grid is drawn on.
 *
 * Sets `this.CELL_SIZE` (32) and `this.svg`, the d3 selection of an <svg>
 * appended to `el` and sized to hold `dims[0]` x `dims[1]` cells plus 2 units
 * of slack in each direction.
 *
 * @param {*} el - target handed to `d3.select`.
 * @param {number[]} dims - grid size as [columns, rows].
 */
function GridView(el, dims) {
  this.CELL_SIZE = 32;
  const [columns, rows] = dims;
  const width = this.CELL_SIZE * columns + 2;
  const height = this.CELL_SIZE * rows + 2;
  this.svg = d3.select(el)
    .append('svg')
    .attr('width', width)
    .attr('height', height);
}
Object.assign(GridView.prototype, {
  /**
   * Redraws the scene: grid and wall squares, one eligibility bar and one
   * Q-value bar per cell edge, and the agent and treasure circles.
   *
   * @param {Object} model - read for `grid` and `walls` (arrays of [x, y]
   *   cells) and for `treasure.position` and `agent.position` ([x, y]).
   * @param {Object} qValues - map from an "x,y" key to an array of per-action
   *   values (index 0..3 = top, right, bottom, left edge of the cell).
   * @param {Object} eligibility - map with the same layout holding the
   *   eligibility traces.
   */
  render: function({grid, walls, treasure, agent}, qValues, eligibility) {
    const cellSize = this.CELL_SIZE;
    const halfInnerCell = (cellSize - 4.0) / 2.0;
    const svg = this.svg;
    // Accessor factories: pixel coordinate of a cell datum plus a fixed offset.
    const cellX = (offset) => (d) => offset + cellSize * d[0];
    const cellY = (offset) => (d) => offset + cellSize * d[1];

    // Static squares: created on enter, removed on exit, never updated.
    function renderGridRects(data, klass) {
      const boxes = svg.selectAll(`rect.${klass}`).data(data);
      boxes.enter()
        .append('rect')
        .attr('class', klass)
        .attr('width', cellSize - 2)
        .attr('height', cellSize - 2)
        .attr('x', cellX(2))
        .attr('y', cellY(2));
      boxes.exit().remove();
    }

    // Circles centred in a cell; datum is [x, y, radiusFraction].
    function renderCircles(data, klass) {
      const r = (d) => d[2] * halfInnerCell;
      const cx = cellX(3 + halfInnerCell);
      const cy = cellY(3 + halfInnerCell);
      const circle = svg.selectAll(`circle.${klass}`).data(data);
      circle.enter()
        .append('circle')
        .attr('class', klass)
        .attr('r', r)
        .attr('cx', cx)
        .attr('cy', cy);
      circle
        .transition()
        .duration(CONFIG.transitionDuration)
        .attr('r', r)
        .attr('cx', cx)
        .attr('cy', cy);
      circle.exit().remove();
    }

    // Bars whose datum already carries pixel geometry {x, y, width, height}.
    function renderRects(data, klass) {
      const bar = svg.selectAll(`rect.${klass}`).data(data);
      bar.enter()
        .append('rect')
        .attr('class', klass)
        .attr('width', (d) => d.width)
        .attr('height', (d) => d.height)
        .attr('x', (d) => d.x)
        .attr('y', (d) => d.y);
      bar
        .transition()
        .duration(CONFIG.transitionDuration)
        .attr('width', (d) => d.width)
        .attr('height', (d) => d.height)
        .attr('x', (d) => d.x)
        .attr('y', (d) => d.y);
      bar.exit().remove();
    }

    renderGridRects(grid, 'grid-square');
    renderGridRects(walls, 'wall-square');

    // eligibility and qValues are in kind of an ungainly format, so unroll
    // them into an array of rects
    function toRects({ mapping, filter, getLongDim, inset }) {
      return _.chain(mapping)
        .toPairs()
        .filter(filter)
        .map(([key, values]) => {
          const [col, row] = key.split(',').map((part) => Number.parseInt(part, 10));
          return values.map((value, i) => {
            const isHoriz = (i === 0 || i === 2);
            // Keep the bar inside its own cell: accumulating traces can
            // exceed 1, which would otherwise give a bar longer than the cell
            // and a negative centring offset.
            const long = Math.min(cellSize, Math.max(0.0, getLongDim(value)));
            const offset = 1.0 + (cellSize - long) / 2.0;
            return {
              width: isHoriz ? long : 2.0,
              height: isHoriz ? 2.0 : long,
              x: cellSize * col + [offset, cellSize - inset, offset, inset][i],
              y: cellSize * row + [inset, offset, cellSize - inset, offset][i],
            };
          });
        })
        .flatten()
        .value();
    }

    renderRects(
      toRects({
        mapping: eligibility,
        // skip cells whose traces have all but decayed away
        filter: (pair) => sum(pair[1]) > 0.01,
        getLongDim: (value) => cellSize * value,
        inset: 4.0,
      }),
      'eligibility');
    renderRects(
      toRects({
        mapping: qValues,
        filter: _.identity,
        // logistic squash of the action value into (0, cellSize)
        getLongDim: (value) => Math.max(
          0.0,
          cellSize * (1.0 / (1.0 + Math.exp(-1.9 * (value - 1.6))))),
        inset: 2.0,
      }),
      'q-value');

    renderCircles([agent.position].map((pos) => [pos[0], pos[1], 1.0]), 'agent');
    renderCircles([treasure.position].map((pos) => [pos[0], pos[1], 1.0]), 'treasure');
  },
});
/**
 * Builds a predicate that tests whether a position equals `posA`.
 *
 * @param {number[]} posA - reference position as [x, y].
 * @returns {function(number[]): boolean} predicate that is true when its
 *   argument has the same x and y as `posA`.
 */
function collidesWith(posA) {
  const [ax, ay] = posA;
  return function(posB) {
    return ax === posB[0] && ay === posB[1];
  };
}
/**
 * Runs the SARSA(lambda) learning loop forever, one move per timer tick.
 *
 * Each tick renders the model, then either teleports the agent back to
 * `startPosition` (when it stands on the treasure) or takes the pending
 * action, observes the reward and applies the eligibility-trace update to
 * every state/action pair. The next tick is scheduled with `setTimeout`.
 *
 * @param {Object} args
 * @param {number[]} args.dims - grid size as [columns, rows]; moves are
 *   clamped to it.
 * @param {Object} args.view - object whose `render(model, qValues,
 *   eligibility)` is called once per tick.
 * @param {Object} args.model - world state; `agent.position`,
 *   `agent.satisfaction`, `treasure.position` and `walls` are read, the agent
 *   fields are written.
 * @param {number[]} args.startPosition - where each episode begins.
 * @param {Object} args.qValues - action values keyed by position; updated in
 *   place.
 * @param {Object} args.eligibility - eligibility traces keyed by position;
 *   updated in place.
 */
function spin({ dims, view, model, startPosition, qValues, eligibility }) {
  function samplePolicy(state) {
    return eGreedy(qValues[state], CONFIG.epsilon);
  }
  function Q(state, actionIndex) { return qValues[state][actionIndex]; }
  function setQ(state, actionIndex, value) { qValues[state][actionIndex] = value; }
  function Z(state, actionIndex) { return eligibility[state][actionIndex]; }
  function setZ(state, actionIndex, value) { eligibility[state][actionIndex] = value; }

  // Moves one cell along `action`, staying inside the grid.
  function clampedMove(position, action) {
    return [
      Math.min(Math.max(position[0] + action[0], 0), dims[0] - 1),
      Math.min(Math.max(position[1] + action[1], 0), dims[1] - 1),
    ];
  }

  // action to execute on the next move, chosen epsilon-greedily
  let actionIndex = samplePolicy(model.agent.position);
  let episodeCount = 0;
  let moveCount = 0;

  (function step() {
    view.render(model, qValues, eligibility);

    if (collidesWith(model.agent.position)(model.treasure.position)) {
      // episode over: teleport back to start
      model.agent.position = startPosition;
      console.log(episodeCount++, moveCount);
      moveCount = 0;
      // Traces belong to the finished episode; clear them in place so they
      // do not leak credit into the next one.
      _.each(eligibility, (traces) => {
        traces.fill(0.0);
      });
      // The pending action was sampled for the treasure cell; pick one that
      // matches the state the agent is actually in now.
      actionIndex = samplePolicy(startPosition);
    } else {
      const currentState = model.agent.position;
      let nextState = clampedMove(currentState, ACTIONS[actionIndex]);

      // commit move if no collision (TODO: query environment -- might include wind)
      if (_.some(model.walls, collidesWith(nextState))) {
        nextState = currentState;
      }

      // rewards (TODO: query environment)
      const reachedTreasure = collidesWith(nextState)(model.treasure.position);
      const reward = reachedTreasure ? 100 : -1;
      model.agent.satisfaction += reward;

      const nextActionIndex = samplePolicy(nextState);

      // SARSA(lambda) update. The treasure is terminal, so nothing is
      // bootstrapped from it.
      const bootstrap = reachedTreasure ? 0.0 : Q(nextState, nextActionIndex);
      const delta = reward + CONFIG.gamma * bootstrap - Q(currentState, actionIndex);
      // accumulating trace for the pair just taken
      setZ(currentState, actionIndex, Z(currentState, actionIndex) + 1.0);
      _.each(qValues, (actionValues, state) => {
        actionValues.forEach((value, i) => {
          setQ(state, i, Q(state, i) + CONFIG.alpha * delta * Z(state, i));
          setZ(state, i, CONFIG.gamma * CONFIG.lambda * Z(state, i));
        });
      });

      actionIndex = nextActionIndex;
      model.agent.position = nextState;
      moveCount++;
    }

    // repeat
    setTimeout(step, CONFIG.stepDelay);
  })();
}
// Entry point: builds a 15x15 world with random walls, start and treasure,
// creates the view and the learning tables, and starts the loop.
(function() {
  // initialize
  const dims = [15, 15];
  const view = new GridView('body', dims);
  // every cell as [x, y], row by row
  const allCoords = _.flatten(
    _.range(dims[1]).map((row) => _.range(dims[0]).map((col) => [col, row])));

  // Walls cover a tenth of the cells, rounded down explicitly rather than
  // leaving a fractional count for _.sampleSize to truncate. The two extra
  // samples become the treasure and the start, so neither sits on a wall.
  const wallCount = Math.floor(0.1 * dims[0] * dims[1]);
  const sampledPositions = _.sampleSize(allCoords, 2 + wallCount);
  const treasurePosition = sampledPositions.pop();
  const startPosition = sampledPositions.pop();

  const model = {
    dims,
    grid: allCoords,
    walls: sampledPositions,
    treasure: {
      position: treasurePosition,
    },
    agent: {
      position: startPosition,
      // I can't get no
      satisfaction: 0,
    },
  };

  // Builds a table keyed by cell (the [x, y] array coerces to an "x,y" key)
  // holding one fresh four-entry array per cell.
  function makeTable(initialValue) {
    return allCoords.reduce((table, coords) => {
      table[coords] = [initialValue, initialValue, initialValue, initialValue];
      return table;
    }, {});
  }

  // start with uniform value for all actions
  const qValues = makeTable(1.0);
  const eligibility = makeTable(0.0);

  spin({ dims, view, model, startPosition, qValues, eligibility });
})();
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<head>
  <meta charset="utf-8">
  <style>
    .grid-square {
      fill: #fff;
      stroke: #aaa;
    }
    .wall-square {
      fill: #666;
      stroke: #222;
    }
    .treasure {
      fill: #0c0;
      stroke: #060;
    }
    .agent {
      fill: #000;
    }
    .eligibility {
      fill: #48e;
    }
    .q-value {
      fill: #4c9;
    }
  </style>
</head>
<body>
  <script src="https://cdnjs.cloudflare.com/ajax/libs/d3/4.5.0/d3.min.js"></script>
  <script src="https://cdnjs.cloudflare.com/ajax/libs/lodash.js/4.17.4/lodash.min.js"></script>
  <script src="gridworld.js"></script>
</body>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment