Skip to content

Instantly share code, notes, and snippets.

@carlosgmartin
Last active November 12, 2024 18:13
Show Gist options
Save carlosgmartin/7e451d124a87f8415ae20016fe1caeb3 to your computer and use it in GitHub Desktop.
import argparse
import ale_py
import cv2
import gymnasium as gym
import minari
import minigrid
import numpy as np
import sb3_contrib
import shortuuid
import stable_baselines3 as sb3
from stable_baselines3.common.callbacks import BaseCallback, EveryNTimesteps
from stable_baselines3.common.env_util import make_vec_env
# Register the ALE (Atari) environments with gymnasium so gym.make("ALE/...") works.
gym.register_envs(ale_py)
def parse_args(argv=None):
    """Parse command-line options for the training/recording script.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``
            (the previous behavior), but accepting an explicit list makes the
            function usable from tests and other callers.

    Returns:
        argparse.Namespace with the parsed options.
    """
    p = argparse.ArgumentParser(allow_abbrev=False)
    p.add_argument("--env", default="ALE/Breakout-v5")
    p.add_argument("--model", default="PPO", choices=["PPO", "TRPO", "A2C"])
    p.add_argument("--timesteps", type=int, default=10**5)
    # How often (in timesteps) the recording callback fires during training.
    p.add_argument("--period", type=int, default=500)
    p.add_argument("--record_episodes", type=int, default=1)
    p.add_argument("--num_envs", type=int, default=4)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--verbose", type=int, default=1)
    # If omitted, a unique id is generated (see get_dataset_id).
    p.add_argument("--dataset_id")
    p.add_argument("--render", type=int, default=1)
    return p.parse_args(argv)
def make_env(args):
    """Create the environment named by ``args.env`` with family-specific wrappers.

    All environments are created with ``render_mode="rgb_array"``. Atari
    (``ALE/``) envs get the standard preprocessing plus a 4-frame stack;
    MiniGrid envs get RGB partial observations unwrapped to a plain image.
    """
    env = gym.make(args.env, render_mode="rgb_array")
    if args.env.startswith("ALE/"):
        # NOTE(review): frame_skip=1 presumably because ALE v5 envs already
        # skip frames internally — confirm against the gymnasium docs.
        env = gym.wrappers.AtariPreprocessing(env, frame_skip=1)
        env = gym.wrappers.FrameStackObservation(env, stack_size=4)
    elif args.env.startswith("MiniGrid-"):
        # https://github.com/Farama-Foundation/Minigrid/issues/455
        env = minigrid.wrappers.RGBImgPartialObsWrapper(env)
        env = minigrid.wrappers.ImgObsWrapper(env)
    return env
def make_policy(env, args):
    """Select an SB3 policy name from the environment's observation space.

    Args:
        env: Environment whose ``observation_space`` is inspected.
        args: Parsed CLI options (currently unused; kept for interface symmetry
            with the other ``make_*`` factories).

    Returns:
        ``"MlpPolicy"`` for 1-D Box spaces, ``"CnnPolicy"`` for 3-D
        (image-like) Box spaces.

    Raises:
        NotImplementedError: for any other observation space; the message now
            names the unsupported space instead of being empty.
    """
    space = env.observation_space
    if isinstance(space, gym.spaces.Box):
        if len(space.shape) == 1:
            return "MlpPolicy"
        if len(space.shape) == 3:
            return "CnnPolicy"
    raise NotImplementedError(f"unsupported observation space: {space}")
def make_model(policy, args):
    """Build the requested SB3 model over a vectorized copy of the environment.

    Args:
        policy: Policy name (e.g. "MlpPolicy" or "CnnPolicy").
        args: Parsed CLI options; ``args.model`` selects the algorithm and
            ``args.num_envs`` the number of parallel environments.

    Raises:
        NotImplementedError: if ``args.model`` is not one of A2C/PPO/TRPO.
    """
    vec_env = make_vec_env(
        make_env,
        env_kwargs={"args": args},
        n_envs=args.num_envs,
    )
    if args.model == "A2C":
        algo = sb3.A2C
    elif args.model == "PPO":
        algo = sb3.PPO
    elif args.model == "TRPO":
        algo = sb3_contrib.TRPO
    else:
        raise NotImplementedError
    return algo(policy, vec_env, verbose=args.verbose)
def get_dataset_id(args):
    """Return the user-supplied dataset id, or generate a unique one.

    A ``None`` ``args.dataset_id`` yields an auto-generated id embedding the
    env name, model name, and a short random uuid.
    """
    if args.dataset_id is not None:
        return args.dataset_id
    return f"{args.env}-{args.model}-training-{shortuuid.uuid()}-v0"
class Callback(BaseCallback):
    """SB3 callback that records one episode with the current model each
    time it fires (intended to be wrapped in ``EveryNTimesteps``).
    """

    def __init__(self, env, verbose=0):
        super().__init__(verbose)
        # Environment used for the recorded rollouts — in this script a
        # minari DataCollector, so each rollout is captured as an episode.
        self.env = env
        # Count of episodes recorded so far.
        self.total_episodes = 0

    def _on_step(self):
        # ``self.model`` is injected by SB3 when training starts.
        # render=False: recording only, no on-screen window.
        run_episode(self.env, self.model, render=False)
        self.total_episodes += 1
        return True  # returning True tells SB3 to continue training
def run_episode(env, model, render):
    """Play a single episode of ``env`` using ``model`` to pick actions.

    Args:
        env: Gymnasium-style environment (``reset``/``step`` API).
        model: Object with an SB3-style ``predict(observation)`` method.
        render: If truthy, display each frame in an OpenCV window.

    Returns:
        Dict with the episode's total ``score``, number of ``steps``, and the
        final ``terminated``/``truncated`` flags.
    """
    obs, _info = env.reset()
    step_count = 0
    total_reward = 0
    terminated = truncated = False
    while not (terminated or truncated):
        if render:
            frame = env.render()
            # OpenCV expects BGR, so reverse the RGB channel order.
            cv2.imshow("window", frame[:, :, ::-1])  # type: ignore
            cv2.waitKey(1)
        action, _ = model.predict(obs)
        obs, reward, terminated, truncated, _info = env.step(action)
        step_count += 1
        total_reward += reward  # type: ignore
    return {
        "score": total_reward,
        "steps": step_count,
        "terminated": terminated,
        "truncated": truncated,
    }
def main():
    """Train an SB3 model while periodically recording episodes to a Minari
    dataset, then optionally replay the trained model on screen forever."""
    args = parse_args()

    # Wrap the env in a DataCollector so rollouts through it are recorded.
    env = make_env(args)
    collector = minari.DataCollector(env)
    if args.verbose >= 1:
        print(f"env: {env}")

    policy = make_policy(env, args)
    if args.verbose >= 1:
        print(f"policy: {policy}")

    model = make_model(policy, args)
    if args.verbose >= 1:
        print(f"model: {model}")

    # Every `period` timesteps, the callback plays (and records) one episode.
    model.learn(
        total_timesteps=args.timesteps,
        callback=EveryNTimesteps(n_steps=args.period, callback=Callback(collector)),
    )

    dataset_id = get_dataset_id(args)
    print(f"dataset id: {dataset_id}")
    dataset = collector.create_dataset(
        dataset_id=dataset_id,
        eval_env=env,
        algorithm_name=args.model,
        author="author",
        author_email="[email protected]",
        code_permalink="example.com",
        description=f"Episodes from training {args.model} on {args.env}.",
    )
    if args.verbose >= 1:
        print(f"total episodes: {dataset.total_episodes}")
        print(f"total steps: {dataset.total_steps}")
        episodes = dataset.iterate_episodes()
        scores = np.array([episode.rewards.sum() for episode in episodes])
        print(f"mean score: {scores.mean()}")

    if args.render:
        # Replay the trained model indefinitely; stop with Ctrl-C.
        while True:
            print(run_episode(env, model, render=True))
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment