# eager_execution.py

import argparse
import os
import random

import ray
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.examples.models.eager_model import EagerModel
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved
from ray import tune

# Always import tensorflow using this utility function:
tf1, tf, tfv = try_import_tf()
# tf1: The installed tf1.x package OR the tf.compat.v1 module within
#   a 2.x tf installation.
# tf: The installed tf package (whatever tf version was installed).
# tfv: The tf version int (either 1 or 2).

# To enable eager mode, do:
# >> tf1.enable_eager_execution()
# >> x = tf.Variable(0.0)
# >> x.numpy()
# 0.0
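# You can verify at runtime whether eager mode is actually active:
# >> tf.executing_eagerly()
# True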

# RLlib will automatically enable eager mode, if you specify your "framework"
# config key to be either "tfe" or "tf2".

# If you would like to remain in tf static-graph mode, but still use tf2.x's
# new APIs (some of which are not supported by tf1.x), specify your "framework"
# as "tf" and check for the version (tfv) to be 2:
# Example:
# >> def dense(x, W, b):
# ..     return tf.nn.sigmoid(tf.matmul(x, W) + b)
#
# >> @tf.function
# >> def multilayer_perceptron(x, w0, b0):
# ..     return dense(x, w0, b0)
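#
# A minimal sketch of such a version check (reusing the illustrative names
# from the example above):
# >> if tfv == 2:
# ..     y = multilayer_perceptron(x, w0, b0)  # tf2.x @tf.function API is safe
# .. else:
# ..     y = dense(x, w0, b0)  # plain graph-mode ops under tf1.x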

# Also be careful to distinguish between tf1 and tf in your code. For example,
# to create a placeholder:
# >> tf1.placeholder(tf.float32, (2, ))  # <- must use `tf1` here

parser = argparse.ArgumentParser()
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=200,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=150.0,
    help="Reward at which we stop training.")


def policy_gradient_loss(policy, model, dist_class, train_batch):
    """Example of using embedded eager execution in a custom loss.

    Here `compute_penalty` prints the actions and rewards for debugging, and
    also computes a (dummy) penalty term to add to the loss.
    """

    def compute_penalty(actions, rewards):
        assert tf.executing_eagerly()
        penalty = tf.reduce_mean(tf.cast(actions, tf.float32))
        if random.random() > 0.9:
            print("The eagerly computed penalty is", penalty, actions, rewards)
        return penalty

    logits, _ = model(train_batch)
    action_dist = dist_class(logits, model)

    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS]
    penalty = tf.py_function(
        compute_penalty, [actions, rewards], Tout=tf.float32)

    return penalty - tf.reduce_mean(action_dist.logp(actions) * rewards)
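
# `tf.py_function` is what lets `compute_penalty` above run eagerly even when
# the surrounding loss is built in graph mode. A standalone sketch of the same
# pattern (the `double` helper is illustrative only):
# >> def double(x):
# ..     assert tf.executing_eagerly()
# ..     return x * 2.0
# >> y = tf.py_function(double, [tf.constant(1.5)], Tout=tf.float32)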

# <class 'ray.rllib.policy.tf_policy_template.MyTFPolicy'>
MyTFPolicy = build_tf_policy(
    name="MyTFPolicy",
    loss_fn=policy_gradient_loss,
)

# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
MyTrainer = build_trainer(
    name="MyCustomTrainer",
    default_policy=MyTFPolicy,
)
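
# As an alternative to the `tune.run()` call below, the generated Trainer
# class could also be instantiated and trained directly (a minimal sketch,
# not executed by this script):
# >> trainer = MyTrainer(config={"framework": "tfe"}, env="CartPole-v0")
# >> print(trainer.train())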

if __name__ == "__main__":
    ray.init()
    args = parser.parse_args()
    ModelCatalog.register_custom_model("eager_model", EagerModel)

    config = {
        "env": "CartPole-v0",
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_workers": 0,
        "model": {
            "custom_model": "eager_model",
        },
        # Alternatively, use "tf2" here for enforcing TF version 2.x.
        "framework": "tfe",
    }

    # Training stops as soon as ANY of these criteria is met.
    stop = {
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(MyTrainer, stop=stop, config=config, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)

    ray.shutdown()