import sys
from collections import defaultdict

import numpy as np


def mc_prediction_q(env, num_episodes, generate_episode, gamma=1.0):
    # Dictionary for the sum of returns per state-action pair
    returns_sum = defaultdict(lambda: np.zeros(env.action_space.n))
    # Dictionary for the number of visits per state-action pair
    N = defaultdict(lambda: np.zeros(env.action_space.n))
    # Action-value estimates for each state-action pair
    Q = defaultdict(lambda: np.zeros(env.action_space.n))
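    # Added note (illustration, not in the original gist): the defaultdict
    # factory means the first access to an unseen state returns a fresh zero
    # array instead of raising a KeyError. For Blackjack
    # (env.action_space.n == 2), Q[(14, 9, False)] starts out as array([0., 0.]).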
    # Loop over episodes
    for i_episode in range(1, num_episodes + 1):
        # Monitor progress
        if i_episode % 1000 == 0:
            print("\rEpisode {}/{}.".format(i_episode, num_episodes), end="")
            sys.stdout.flush()
        # Generate an episode: a list of (state, action, reward) tuples
        episode = generate_episode(env)
        # Unpack the episode into separate tuples of states, actions, and
        # rewards; zip(*episode) returns an iterator (worked example below)
        states, actions, rewards = zip(*episode)
        # Discount multipliers: an array of gamma raised to the i-th power,
        # with len(rewards) + 1 entries
        discounts = np.array([gamma ** i for i in range(len(rewards) + 1)])
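        # Added illustration: with gamma = 0.9 and a three-step episode,
        # discounts == array([1.0, 0.9, 0.81, 0.729])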
        # Update the sum of returns, the number of visits, and the
        # action-value estimate for each state-action pair in the episode.
        # The loop index i below is the TIMESTEP within the episode.
        """
        Worked example: a single episode with three timesteps, stored as a
        list of (state, action, reward) tuples:
        [((14, 9, False), 1, 0), ((16, 9, False), 1, 0), ((17, 9, False), 1, -1)]
        After states, actions, rewards = zip(*episode):
        states:  ((14, 9, False), (16, 9, False), (17, 9, False))
        actions: (1, 1, 1)
        rewards: (0, 0, -1)
        """
        # Index states by timestep, and use the same index to access the
        # action taken at that timestep
        for i in range(len(states)):
            # The return from timestep i is the sum of all rewards from i
            # onwards, each discounted by gamma relative to timestep i.
            # rewards[i:] and discounts[:-(1 + i)] both have
            # len(rewards) - i elements, so they line up element-wise.
            returns_sum[states[i]][actions[i]] += sum(rewards[i:] * discounts[:-(1 + i)])
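            # Added worked example: for the three-step episode above with
            # gamma = 1.0 and i = 1, this adds sum((0, -1) * [1.0, 1.0]) = -1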
            # Increment the number of visits for the state-action pair
            # visited at this timestep
            N[states[i]][actions[i]] += 1.0
            # The Q value for a state-action pair is simply its total return
            # divided by its number of visits, so update the running estimate
            Q[states[i]][actions[i]] = returns_sum[states[i]][actions[i]] / N[states[i]][actions[i]]
    # My own version that isn't confusing
    return Q
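

# --------------------------------------------------------------------------
# Usage sketch (not part of the original gist): a minimal illustration of
# calling mc_prediction_q on the classic Gym 'Blackjack-v0' environment.
# generate_episode_stochastic is a hypothetical rollout helper, and the
# code assumes the pre-0.26 Gym step API (state, reward, done, info).

if __name__ == "__main__":
    import gym

    env = gym.make('Blackjack-v0')

    def generate_episode_stochastic(bj_env):
        # Simple stochastic policy: mostly STICK (action 0) when the player
        # sum exceeds 18, mostly HIT (action 1) otherwise
        episode = []
        state = bj_env.reset()
        while True:
            probs = [0.8, 0.2] if state[0] > 18 else [0.2, 0.8]
            action = np.random.choice(np.arange(2), p=probs)
            next_state, reward, done, info = bj_env.step(action)
            episode.append((state, action, reward))
            state = next_state
            if done:
                break
        return episode

    Q = mc_prediction_q(env, 500000, generate_episode_stochastic)
    # Inspect the estimate for the state from the worked example above
    print("\nQ[(14, 9, False)] =", Q[(14, 9, False)])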