custom_torch_policy.py

import argparse
import os

import ray
from ray import air, tune
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch

parser = argparse.ArgumentParser()
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--num-cpus", type=int, default=0)
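
# A simple REINFORCE-style loss: the negative, reward-weighted sum of
# action log-probabilities over the sampled batch.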
def policy_gradient_loss(policy, model, dist_class, train_batch):
    logits, _ = model({SampleBatch.CUR_OBS: train_batch[SampleBatch.CUR_OBS]})
    action_dist = dist_class(logits, model)
    log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
    return -train_batch[SampleBatch.REWARDS].dot(log_probs)
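
# Build a torch Policy class that plugs the loss function above into RLlib.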
# <class 'ray.rllib.policy.torch_policy_template.MyTorchPolicy'>
MyTorchPolicy = build_policy_class(
    name="MyTorchPolicy", framework="torch", loss_fn=policy_gradient_loss
)

# Create a new Algorithm using the Policy defined above.
class MyAlgorithm(Algorithm):
    @classmethod
    def get_default_policy_class(cls, config):
        return MyTorchPolicy
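
# When run as a script: parse the CLI args, start Ray, and train with Tune.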
if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None)
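    # Run MyAlgorithm on CartPole until the configured iteration limit.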
    tuner = tune.Tuner(
        MyAlgorithm,
        run_config=air.RunConfig(
            stop={"training_iteration": args.stop_iters},
        ),
        param_space={
            "env": "CartPole-v1",
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "num_workers": 2,
            "framework": "torch",
        },
    )
    tuner.fit()
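
# Example invocation (assumes Ray RLlib and torch are installed); the flag
# values below are illustrative:
#   python custom_torch_policy.py --stop-iters 100 --num-cpus 4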