custom_tf_policy.py

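"""Example of using RLlib's ``build_tf_policy`` and ``build_trainer``
templates to assemble a custom TensorFlow policy and trainer.

The policy defined below uses a plain policy-gradient (REINFORCE-style)
loss over discounted episode returns and is trained on ``CartPole-v0``
with Tune. Example invocation (using the flags defined below):

    python custom_tf_policy.py --stop-iters=200 --num-cpus=4
"""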
import argparse
import os

import ray
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.postprocessing import discount_cumsum
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--num-cpus", type=int, default=0)

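# Loss function used by the custom policy: the plain policy-gradient
# objective -E[log pi(a|s) * R], where R is the discounted return that
# calculate_advantages() stores in the batch under the "returns" key.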
def policy_gradient_loss(policy, model, dist_class, train_batch):
    logits, _ = model(train_batch)
    action_dist = dist_class(logits, model)
    return -tf.reduce_mean(
        action_dist.logp(train_batch["actions"]) * train_batch["returns"])

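# Postprocessing function: runs on each collected sample batch before it
# reaches the loss. Here it computes the discounted cumulative reward
# (gamma = 0.99) for every timestep and stores it under "returns".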
def calculate_advantages(policy,
                         sample_batch,
                         other_agent_batches=None,
                         episode=None):
    sample_batch["returns"] = discount_cumsum(sample_batch["rewards"], 0.99)
    return sample_batch

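# Build a TF policy class from the two functions above. Anything not
# specified here (action sampling, optimizer, etc.) falls back to the
# template's defaults.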
# <class 'ray.rllib.policy.tf_policy_template.MyTFPolicy'>
MyTFPolicy = build_tf_policy(
    name="MyTFPolicy",
    loss_fn=policy_gradient_loss,
    postprocess_fn=calculate_advantages,
)

# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
MyTrainer = build_trainer(
    name="MyCustomTrainer",
    default_policy=MyTFPolicy,
)

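# Command-line entry point: start Ray and kick off a Tune run that trains
# MyTrainer on CartPole-v0 until the iteration limit is reached.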
if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None)
    tune.run(
        MyTrainer,
        stop={"training_iteration": args.stop_iters},
        config={
            "env": "CartPole-v0",
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "num_workers": 2,
            "framework": "tf",
        })