custom_torch_policy.py

import argparse
import os

import ray
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch

parser = argparse.ArgumentParser()
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--num-cpus", type=int, default=0)


def policy_gradient_loss(policy, model, dist_class, train_batch):
    # Vanilla (REINFORCE-style) policy gradient loss: weight the
    # log-probabilities of the taken actions by the observed rewards
    # and minimize the negative of that sum.
    logits, _ = model({SampleBatch.CUR_OBS: train_batch[SampleBatch.CUR_OBS]})
    action_dist = dist_class(logits, model)
    log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
    return -train_batch[SampleBatch.REWARDS].dot(log_probs)


# <class 'ray.rllib.policy.torch_policy_template.MyTorchPolicy'>
MyTorchPolicy = build_policy_class(
    name="MyTorchPolicy", framework="torch", loss_fn=policy_gradient_loss)

# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
MyTrainer = build_trainer(
    name="MyCustomTrainer",
    default_policy=MyTorchPolicy,
)

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None)
    tune.run(
        MyTrainer,
        stop={"training_iteration": args.stop_iters},
        config={
            "env": "CartPole-v0",
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "num_workers": 2,
            "framework": "torch",
        })
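
Run as, e.g., `python custom_torch_policy.py --stop-iters 50 --num-cpus 4`; Tune then drives the training loop with the stop condition above. The sketch below is not part of the original file: it shows one way the generated trainer class could be used without Tune, assuming the same legacy RLlib API (`build_trainer` / `build_policy_class`) imported above and that this script is importable as a module.

# Minimal standalone sketch (assumption, not from the original script):
# drive MyTrainer manually instead of via tune.run().
import ray
from custom_torch_policy import MyTrainer  # assumes this file is on the path

ray.init()
trainer = MyTrainer(
    env="CartPole-v0",
    config={"framework": "torch", "num_workers": 0},
)
for i in range(5):
    result = trainer.train()  # one training iteration
    print(i, result["episode_reward_mean"])
ray.shutdown()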