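"""Collect a Minari dataset of episodes sampled periodically while training a
Stable-Baselines3 agent (PPO, TRPO, or A2C) on a Gymnasium environment.

Example invocation (the script filename here is illustrative):

    python collect_dataset.py --env ALE/Breakout-v5 --model PPO --timesteps 100000
"""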

import argparse

import ale_py
import cv2
import gymnasium as gym
import minari
import minigrid
import numpy as np
import sb3_contrib
import shortuuid
import stable_baselines3 as sb3
from stable_baselines3.common.callbacks import BaseCallback, EveryNTimesteps
from stable_baselines3.common.env_util import make_vec_env

# Make the ALE (Atari) environments available in Gymnasium's registry.
gym.register_envs(ale_py)


def parse_args():
    p = argparse.ArgumentParser(allow_abbrev=False)
    p.add_argument("--env", default="ALE/Breakout-v5")
    p.add_argument("--model", default="PPO", choices=["PPO", "TRPO", "A2C"])
    p.add_argument("--timesteps", type=int, default=10**5)
    p.add_argument("--period", type=int, default=500)
    p.add_argument("--record_episodes", type=int, default=1)
    p.add_argument("--num_envs", type=int, default=4)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--verbose", type=int, default=1)
    p.add_argument("--dataset_id")
    p.add_argument("--render", type=int, default=1)
    return p.parse_args()


def make_env(args):
    if args.env.startswith("ALE/"):
        env = gym.make(args.env, render_mode="rgb_array")
        # ALE/*-v5 environments already apply frame-skipping internally, so
        # disable it in AtariPreprocessing to avoid skipping frames twice.
        env = gym.wrappers.AtariPreprocessing(env, frame_skip=1)
        env = gym.wrappers.FrameStackObservation(env, stack_size=4)
    elif args.env.startswith("MiniGrid-"):
        # Replace MiniGrid's dict observations with plain RGB images of the
        # agent's partial view, so an image policy can consume them; see
        # https://github.com/Farama-Foundation/Minigrid/issues/455
        env = gym.make(args.env, render_mode="rgb_array")
        env = minigrid.wrappers.RGBImgPartialObsWrapper(env)
        env = minigrid.wrappers.ImgObsWrapper(env)
    else:
        env = gym.make(args.env, render_mode="rgb_array")
    return env


def make_policy(env, args):
    space = env.observation_space
    if isinstance(space, gym.spaces.Box) and len(space.shape) == 1:
        # Flat vector observations: use a multilayer perceptron.
        return "MlpPolicy"
    elif isinstance(space, gym.spaces.Box) and len(space.shape) == 3:
        # Image observations: use a convolutional network.
        return "CnnPolicy"
    else:
        raise NotImplementedError


def make_model(policy, args):
    vec_env = make_vec_env(
        make_env,
        env_kwargs={"args": args},
        n_envs=args.num_envs,
        seed=args.seed,
    )
    match args.model:
        case "A2C":
            return sb3.A2C(policy, vec_env, seed=args.seed, verbose=args.verbose)
        case "PPO":
            return sb3.PPO(policy, vec_env, seed=args.seed, verbose=args.verbose)
        case "TRPO":
            return sb3_contrib.TRPO(policy, vec_env, seed=args.seed, verbose=args.verbose)
        case _:
            raise NotImplementedError


def get_dataset_id(args):
    if args.dataset_id is None:
        return f"{args.env}-{args.model}-training-{shortuuid.uuid()}-v0"
    else:
        return args.dataset_id
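

# Triggered by EveryNTimesteps during training: each trigger plays full
# episodes in the DataCollector-wrapped environment with the current policy,
# so Minari records snapshots of the agent's behavior as it improves.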
class Callback(BaseCallback):
    def __init__(self, env, num_episodes=1, verbose=0):
        super().__init__(verbose)
        self.env = env
        self.num_episodes = num_episodes
        self.total_episodes = 0

    def _on_step(self):
        for _ in range(self.num_episodes):
            run_episode(self.env, self.model, render=False)
            self.total_episodes += 1
        return True
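

# Plays a single episode to termination or truncation and returns summary
# statistics; optionally displays frames in an OpenCV window.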
def run_episode(env, model, render):
    steps = 0
    score = 0
    observation, info = env.reset()
    while True:
        if render:
            im = env.render()
            # Gymnasium renders RGB; OpenCV expects BGR, so reverse channels.
            cv2.imshow("window", im[:, :, ::-1])  # type: ignore
            cv2.waitKey(1)
        action, _ = model.predict(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        steps += 1
        score += reward  # type: ignore
        if terminated or truncated:
            break
    return {
        "score": score,
        "steps": steps,
        "terminated": terminated,
        "truncated": truncated,
    }


def main():
    args = parse_args()
    env = make_env(args)
    # Wrap the evaluation env so Minari records every episode stepped in it.
    collector = minari.DataCollector(env)
    if args.verbose >= 1:
        print(f"env: {env}")
    policy = make_policy(env, args)
    if args.verbose >= 1:
        print(f"policy: {policy}")
    model = make_model(policy, args)
    if args.verbose >= 1:
        print(f"model: {model}")
    model.learn(
        total_timesteps=args.timesteps,
        callback=EveryNTimesteps(
            n_steps=args.period,
            callback=Callback(collector, num_episodes=args.record_episodes),
        ),
    )
    dataset_id = get_dataset_id(args)
    print(f"dataset id: {dataset_id}")
    dataset = collector.create_dataset(
        dataset_id=dataset_id,
        eval_env=env,
        algorithm_name=args.model,
        author="author",
        author_email="[email protected]",
        code_permalink="example.com",
        description=f"Episodes from training {args.model} on {args.env}.",
    )
    if args.verbose >= 1:
        print(f"total episodes: {dataset.total_episodes}")
        print(f"total steps: {dataset.total_steps}")
        episodes = dataset.iterate_episodes()
        scores = np.array([episode.rewards.sum() for episode in episodes])
        print(f"mean score: {scores.mean()}")
    if args.render:
        # Show episodes with the trained policy until interrupted.
        while True:
            print(run_episode(env, model, render=True))


if __name__ == "__main__":
    main()
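
# The saved dataset can be reloaded later for offline RL or analysis, e.g.
# (a minimal sketch; substitute the dataset id that the script printed):
#
#     import minari
#     dataset = minari.load_dataset("<printed dataset id>")
#     episodes = dataset.sample_episodes(10)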