rollout_worker_custom_workflow.py

  1. """Example of using rollout worker classes directly to implement training.
  2. Instead of using the built-in Trainer classes provided by RLlib, here we define
  3. a custom Policy class and manually coordinate distributed sample
  4. collection and policy optimization.
  5. """
import argparse
import gym
import numpy as np

import ray
from ray import tune
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.tune.utils.placement_groups import PlacementGroupFactory

parser = argparse.ArgumentParser()
parser.add_argument("--gpu", action="store_true")
parser.add_argument("--num-iters", type=int, default=20)
parser.add_argument("--num-workers", type=int, default=2)
parser.add_argument("--num-cpus", type=int, default=0)


class CustomPolicy(Policy):
    """Example of a custom policy written from scratch.

    You might find it more convenient to extend TF/TorchPolicy instead
    for a real policy.
    """

    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)
        self.config["framework"] = None
        # example parameter
        self.w = 1.0

    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        # return random actions
        return np.array(
            [self.action_space.sample() for _ in obs_batch]), [], {}

    def learn_on_batch(self, samples):
        # implement your learning code here
        return {}

    def update_some_value(self, w):
        # can also call other methods on policies
        self.w = w

    def get_weights(self):
        return {"w": self.w}

    def set_weights(self, weights):
        self.w = weights["w"]
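
# For quick local debugging, a RolloutWorker can also be constructed directly
# in the driver process instead of as a Ray actor (a sketch only; depending on
# the RLlib version, the keyword may be `policy_spec` rather than `policy`):
#
#     worker = RolloutWorker(
#         env_creator=lambda c: gym.make("CartPole-v0"), policy=CustomPolicy)
#     local_batch = worker.sample()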


def training_workflow(config, reporter):
    # Set up the policy and the policy evaluation actors.
    env = gym.make("CartPole-v0")
    policy = CustomPolicy(env.observation_space, env.action_space, {})
    workers = [
        RolloutWorker.as_remote().remote(
            env_creator=lambda c: gym.make("CartPole-v0"), policy=CustomPolicy)
        for _ in range(config["num_workers"])
    ]

    for _ in range(config["num_iters"]):
        # Broadcast weights to the policy evaluation workers.
        weights = ray.put({DEFAULT_POLICY_ID: policy.get_weights()})
        for w in workers:
            w.set_weights.remote(weights)

        # Gather a batch of samples.
        T1 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))

        # Update the remote policy replicas before gathering another batch.
        new_value = policy.w * 2.0
        for w in workers:
            w.for_policy.remote(lambda p: p.update_some_value(new_value))

        # Gather another batch of samples.
        T2 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))

        # Improve the policy using the T1 batch.
        policy.learn_on_batch(T1)

        # Do some arbitrary updates based on the T2 batch.
        policy.update_some_value(sum(T2["rewards"]))

        reporter(**collect_metrics(remote_workers=workers))
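
# Note: the two sampling phases above (T1 and T2) are purely illustrative; the
# "arbitrary" T2 update just shows that extra work can be coordinated between
# sampling rounds. A minimal loop would broadcast weights, sample once, and
# call learn_on_batch() on the result.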


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None)

    tune.run(
        training_workflow,
        resources_per_trial=PlacementGroupFactory([{
            "CPU": 1,
            "GPU": 1 if args.gpu else 0
        }] + [{
            "CPU": 1
        }] * args.num_workers),
        config={
            "num_workers": args.num_workers,
            "num_iters": args.num_iters,
        },
        verbose=1,
    )
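
# Example invocation (flags defined by the argparse parser above; a Ray
# installation with RLlib and Tune is assumed):
#
#     python rollout_worker_custom_workflow.py --num-workers=2 --num-iters=20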