rollout_worker_custom_workflow.py

  1. """Example of using rollout worker classes directly to implement training.
  2. Instead of using the built-in Algorithm classes provided by RLlib, here we define
  3. a custom Policy class and manually coordinate distributed sample
  4. collection and policy optimization.
  5. """
import argparse

import gymnasium as gym
import numpy as np

import ray
from ray import train, tune
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, concat_samples
from ray.tune.execution.placement_groups import PlacementGroupFactory

parser = argparse.ArgumentParser()
parser.add_argument("--gpu", action="store_true")
parser.add_argument("--num-iters", type=int, default=20)
parser.add_argument("--num-workers", type=int, default=2)
parser.add_argument("--num-cpus", type=int, default=0)
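
# Note (added for clarity, not in the original file): --num-workers controls how many
# remote RolloutWorker actors training_workflow() creates, --num-iters how many
# sample/train iterations it runs, --gpu reserves a GPU in the trial's resource bundle,
# and --num-cpus is forwarded to ray.init() (0 means "let Ray auto-detect"; see the
# __main__ block below).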


class CustomPolicy(Policy):
    """Example of a custom policy written from scratch.

    You might find it more convenient to extend TF/TorchPolicy instead
    for a real policy.
    """

    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)
        self.config["framework"] = None
        # example parameter
        self.w = 1.0

    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        **kwargs
    ):
        # return random actions
        return np.array([self.action_space.sample() for _ in obs_batch]), [], {}

    def learn_on_batch(self, samples):
        # implement your learning code here
        return {}

    def update_some_value(self, w):
        # can also call other methods on policies
        self.w = w

    def get_weights(self):
        return {"w": self.w}

    def set_weights(self, weights):
        self.w = weights["w"]
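

# Illustrative sketch (added for this write-up, not part of the original example):
# exercising CustomPolicy locally, without any remote workers. The helper name
# `_smoke_test_policy` is hypothetical; it only uses the methods defined above.
def _smoke_test_policy():
    env = gym.make("CartPole-v1")
    policy = CustomPolicy(env.observation_space, env.action_space, {})
    obs, _ = env.reset()
    # compute_actions() takes a batch of observations and returns
    # (actions, rnn_state_outs, extra_info), as defined in the class above.
    actions, _, _ = policy.compute_actions(np.expand_dims(obs, 0))
    # The weights round-trip mirrors what training_workflow() broadcasts to workers.
    policy.set_weights(policy.get_weights())
    return actions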


def training_workflow(config):
    # Setup policy and policy evaluation actors
    env = gym.make("CartPole-v1")
    policy = CustomPolicy(env.observation_space, env.action_space, {})
    workers = [
        ray.remote(num_cpus=1)(RolloutWorker).remote(
            env_creator=lambda c: gym.make("CartPole-v1"),
            default_policy_class=CustomPolicy,
        )
        for _ in range(config["num_workers"])
    ]

    for _ in range(config["num_iters"]):
        # Broadcast weights to the policy evaluation workers
        weights = ray.put({DEFAULT_POLICY_ID: policy.get_weights()})
        for w in workers:
            w.set_weights.remote(weights)

        # Gather a batch of samples
        T1 = concat_samples(ray.get([w.sample.remote() for w in workers]))

        # Update the remote policy replicas before gathering another batch
        new_value = policy.w * 2.0
        for w in workers:
            w.for_policy.remote(lambda p: p.update_some_value(new_value))

        # Gather another batch of samples
        T2 = concat_samples(ray.get([w.sample.remote() for w in workers]))

        # Improve the policy using the T1 batch
        policy.learn_on_batch(T1)

        # Do some arbitrary updates based on the T2 batch
        policy.update_some_value(sum(T2["rewards"]))

        train.report(collect_metrics(remote_workers=workers))
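

# Usage note (added, not in the original file): training_workflow() can also be
# invoked directly for quick debugging, e.g.
#
#   ray.init()
#   training_workflow({"num_workers": 2, "num_iters": 1})
#
# Keep in mind that train.report() expects to run inside a Tune/Train session, so
# outside of one it may warn or fail depending on the Ray version; the Tune-based
# entry point below is the intended way to run this example.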


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None)

    tune.Tuner(
        tune.with_resources(
            training_workflow,
            resources=PlacementGroupFactory(
                (
                    [{"CPU": 1, "GPU": 1 if args.gpu else 0}]
                    + [{"CPU": 1}] * args.num_workers
                )
            ),
        ),
        param_space={
            "num_workers": args.num_workers,
            "num_iters": args.num_iters,
        },
        run_config=train.RunConfig(
            verbose=1,
        ),
    ).fit()
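
# Resource note (added, not in the original file): with --num-workers 2 and no --gpu,
# the PlacementGroupFactory above requests the bundles
#   [{"CPU": 1, "GPU": 0}, {"CPU": 1}, {"CPU": 1}]
# i.e. one bundle for the training_workflow trial itself plus one CPU bundle per
# remote RolloutWorker it spawns.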