@yejingxin
Created June 1, 2023 22:42
import ray
from ray.cluster_utils import Cluster
from collections import deque

# Starts a head node for the cluster.
cluster = Cluster(
    initialize_head=True,
    head_node_args={
        "resources": {"host": 1},
    })
cluster.add_node(resources={"trainer": 1})
cluster.add_node(resources={"evaluator": 1})
cluster.add_node(resources={"actor": 1})
cluster.add_node(resources={"actor": 1})
cluster.add_node(resources={"replay_buffer": 1})
ray.init(address=cluster.address)
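# Optional check: the custom resource labels attached to each node above should
# now be visible to the Ray scheduler.
print(ray.cluster_resources())  # includes host, trainer, evaluator, actor, replay_buffer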
class Model:
    def __init__(self):
        self.weight = 'weight'

    def predict(self, state):
        # Toy stand-in for a policy/value network.
        return 'action', 'value'

    def train(self, data):
        return None
@ray.remote(resources={"trainer": 1})
class Trainer:
def __init__(self, model):
self.step = 0
self.model = model
def train(self, data):
self.step += 1
return 'train_step'
def get_model(self):
return self.model
#f"weight{self.step}"
@ray.remote(resources={"evaluator": 1})
def evaluate(state, model):
return model.predict(data)
#f"{weight}->action"
class Environment:
    def __init__(self):
        pass

    def observe(self):
        return 'state'

    def act(self, act):
        pass
@ray.remote(resources={"actor": 1})
class Actor:
def __init__(self, env):
self.env = env
def observe(self):
return self.env.observe()
def act(self,action):
return 'trajactory'
@ray.remote(resources={"replay_buffer": 1})
class ReplayBuffer:
def __init__(self, batch_size):
self.buffer = deque([])
self.batch_size = batch_size
self.max_size = 2 * self.batch_size
def ready(self):
return len(self.buffer) > self.batch_size
def get_data(self):
return self.buffer[:batch_size]
def add(self, trajactories):
self.buffer.extend(trajactory)
while len(self.buffer) > self.max_size:
self.buffer.popleft()
model = Model()
model_id = ray.put(model)
trainer = Trainer.remote(model_id)
replay_buffer = ReplayBuffer.remote(batch_size=2)
env = Environment()
env_id = ray.put(env)
actor = Actor.remote(env_id)
# pure function
# def predict(model, observation) -> policy, value
for step in range(2):
    print(f'start step={step}')
    model_id = trainer.get_model.remote()
    observation_id = actor.observe.remote()
    trajectory = []
    # Roll out a fixed number of environment steps (stand-in for a real
    # termination check on the environment).
    for _ in range(3):
        action_id, value_id = evaluate.remote(observation_id, model_id)
        trajectory.append((action_id, observation_id))
        observation_id = actor.act.remote(action_id)
    replay_buffer.add.remote(trajectory)
    #trajectories_id = [actor.act.remote(action_id) for _ in range(5)]
    data_id = replay_buffer.get_data.remote()
    trainer.train.remote(data_id)
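# Every .remote() call above returns an ObjectRef immediately without waiting
# for the work to run. One way to force the toy pipeline to execute end to end
# is to block with ray.get on a final training call.
print(ray.get(trainer.train.remote(replay_buffer.get_data.remote())))  # -> 'train_step'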
"""
one trainer and multiple actors
1. trainer send model to actors
2. actors send replay buffer
3. relay buffer send to trainer
"""
#resources={"evaluator": 1}
#result = []
#for _ in range(10):
# result.append(ray.get([evaluate.remote(_, 'w') for _ in range(5)]))
#print(result)
# Separate sketch: on a cluster of two CPU VMs, label one node "learner" and
# the other "actor" (the toy cluster above does not define a "learner"
# resource, so this part assumes such a cluster).
import socket

@ray.remote(resources={"learner": 1})
def get_learner_host():
    return socket.gethostname()

@ray.remote(resources={"actor": 1})
def do_some_action(model_from_learner):
    return "some thing"

learner_id = get_learner_host.remote()
action_return_id = do_some_action.remote(learner_id)
print(ray.get(action_return_id))
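# `learner_id` is an ObjectRef: Ray resolves it to the learner task's return
# value (the hostname string) before do_some_action runs on the "actor" node,
# so data moves between the two machines without an explicit ray.get in between.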