@myxyy
Last active August 13, 2023 13:59
CartPole-ActorCritic
# References:
# https://www.dskomei.com/entry/2022/03/13/114756
# "Deep Learning from Scratch 4: Reinforcement Learning" (ゼロから作るDeep Learning 4 強化学習編)
import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
from torch.distributions import Categorical

class ActorCritic(nn.Module):
    def __init__(self, dim_state, num_actions):
        super().__init__()
        self.num_actions = num_actions
        self.fc1 = nn.Linear(dim_state, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, 64)
        self.actor_head = nn.Linear(64, num_actions)
        self.critic_head = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.relu(x)
        actor_out = self.actor_head(x)
        state_value = self.critic_head(x)
        action_prob = self.softmax(actor_out)
        return action_prob, state_value


def select_action(model, state):
    action_prob, state_value = model(state)
    m = Categorical(action_prob)
    action = m.sample()
    return action.item(), m.log_prob(action), state_value
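

# One-step (TD(0)) actor-critic update used in the training loop below:
#   target      = r + gamma * V(s')     (0 replaces V(s') at episode end)
#   delta       = target - V(s)         (TD error, used as the advantage)
#   critic loss = (target - V(s))^2
#   actor loss  = -delta * log pi(a|s)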
gamma = 0.9
model = ActorCritic(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

env = gym.make('CartPole-v1', render_mode='human')
observation, info = env.reset()

for _ in range(1000000):
    state = torch.from_numpy(observation).reshape(1, -1)
    action, action_log_prob, state_value = select_action(model, state)
    observation, reward, terminated, truncated, info = env.step(action)
    # Reward shaping: penalize the cart's distance from the center (observation[0] is the cart position).
    reward -= observation[0] * observation[0]

    state_next = torch.from_numpy(observation).reshape(1, -1)
    _, _, state_next_value = select_action(model, state_next)
    with torch.no_grad():
        state_value_target = reward + gamma * state_next_value * (0 if terminated or truncated else 1)
    loss_critic = torch.pow(state_value_target - state_value, 2)
    with torch.no_grad():
        delta = state_value_target - state_value
    loss_actor = -action_log_prob * delta
    loss = loss_actor + loss_critic

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print('\rloss:{:.5f}'.format(loss.item()), end='')

    if terminated or truncated:
        observation, info = env.reset()

env.close()
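

# ---- Optional greedy evaluation (sketch, not part of the original gist) ----
# Re-creates an environment and runs the trained policy once more, choosing the
# argmax action instead of sampling, to visually check the learned behaviour.
# `eval_env` and the 500-step cap are names/values introduced here for illustration.
eval_env = gym.make('CartPole-v1', render_mode='human')
observation, info = eval_env.reset()
model.eval()
with torch.no_grad():
    for _ in range(500):
        state = torch.from_numpy(observation).reshape(1, -1)
        action_prob, _ = model(state)
        action = torch.argmax(action_prob, dim=-1).item()
        observation, reward, terminated, truncated, info = eval_env.step(action)
        if terminated or truncated:
            break
eval_env.close()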