import ray
from ray.cluster_utils import Cluster
from collections import deque

# Start a head node for the local test cluster, then attach worker nodes with
# custom resources so each role (trainer, evaluator, actor, replay buffer)
# is pinned to its own node.
cluster = Cluster(
    initialize_head=True,
    head_node_args={
        "resources": {"host": 1},
    })
cluster.add_node(resources={"trainer": 1})
cluster.add_node(resources={"evaluator": 1})
cluster.add_node(resources={"actor": 1})
cluster.add_node(resources={"actor": 1})
cluster.add_node(resources={"replay_buffer": 1})
ray.init(address=cluster.address)
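
# Optional sanity check (not part of the original gist): confirm the custom
# resources above registered with the cluster before scheduling onto them.
print(ray.cluster_resources())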

class Model:
    def __init__(self):
        self.weight = 'weight'

    def predict(self, state):
        return 'action'

    def train(self, data):
        return None
@ray.remote(resources={"trainer": 1}) | |
class Trainer: | |
def __init__(self, model): | |
self.step = 0 | |
self.model = model | |
def train(self, data): | |
self.step += 1 | |
return 'train_step' | |
def get_model(self): | |
return self.model | |
#f"weight{self.step}" | |
@ray.remote(resources={"evaluator": 1}) | |
def evaluate(state, model): | |
return model.predict(data) | |
#f"{weight}->action" | |

class Environment:
    def __init__(self):
        pass

    def observe(self):
        return 'state'

    def act(self, action):
        pass
@ray.remote(resources={"actor": 1}) | |
class Actor: | |
def __init__(self, env): | |
self.env = env | |
def observe(self): | |
return self.env.observe() | |
def act(self,action): | |
return 'trajactory' | |
@ray.remote(resources={"replay_buffer": 1}) | |
class ReplayBuffer: | |
def __init__(self, batch_size): | |
self.buffer = deque([]) | |
self.batch_size = batch_size | |
self.max_size = 2 * self.batch_size | |
def ready(self): | |
return len(self.buffer) > self.batch_size | |
def get_data(self): | |
return self.buffer[:batch_size] | |
def add(self, trajactories): | |
self.buffer.extend(trajactory) | |
while len(self.buffer) > self.max_size: | |
self.buffer.popleft() | |

model = Model()
model_id = ray.put(model)
trainer = Trainer.remote(model_id)
replay_buffer = ReplayBuffer.remote(batch_size=2)
env = Environment()
env_id = ray.put(env)
actor = Actor.remote(env_id)

# pure function
# def predict(model, observation) -> policy, value
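
# A minimal sketch of the pure-function variant noted in the comment above: a
# stateless Ray task mapping (model, observation) -> (policy, value), so each
# output gets its own ObjectRef. The names `policy` and `value` and the
# placeholder bodies are assumptions, not part of the original gist.
@ray.remote(num_returns=2, resources={"evaluator": 1})
def predict(model, observation):
    policy = model.predict(observation)  # toy Model just returns 'action'
    value = 0.0                          # placeholder value estimate
    return policy, value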

for step in range(2):
    print(f'start step={step}')
    model_id = trainer.get_model.remote()
    observation_id = actor.observe.remote()
    trajectory = []
    # Roll out a short, fixed-length trajectory; the toy Environment never
    # signals termination on its own.
    for _ in range(5):
        action_id = evaluate.remote(observation_id, model_id)
        trajectory.append((action_id, observation_id))
        observation_id = actor.act.remote(action_id)
    replay_buffer.add.remote(trajectory)
    # trajectories_id = [actor.act.remote(action_id) for _ in range(5)]
    data_id = replay_buffer.get_data.remote()
    trainer.train.remote(data_id)
""" | |
one trainer and multiple actors | |
1. trainer send model to actors | |
2. actors send replay buffer | |
3. relay buffer send to trainer | |
""" | |
#resources={"evaluator": 1} | |
#result = [] | |
#for _ in range(10): | |
# result.append(ray.get([evaluate.remote(_, 'w') for _ in range(5)])) | |
#print(result) | |

# Create 2 CPU VMs; set one as the learner, the other as the actor.
# Note: this snippet assumes a cluster that advertises a "learner" custom
# resource (the test cluster above only defines "trainer").
@ray.remote(resources={"learner": 1})
def get_learner_host():
    import socket
    return socket.gethostname()


@ray.remote(resources={"actor": 1})
def do_some_action(model_from_learner):
    return "some thing"


learner_id = get_learner_host.remote()
action_return_id = do_some_action.remote(learner_id)
print(ray.get(action_return_id))