# batch_norm_model.py — web-scrape artifacts (file caption and line-number gutter) removed.
  1. """Example of using a custom model with batch norm."""
  2. import argparse
  3. import os
  4. import ray
  5. from ray import tune
  6. from ray.rllib.examples.models.batch_norm_model import BatchNormModel, \
  7. KerasBatchNormModel, TorchBatchNormModel
  8. from ray.rllib.models import ModelCatalog
  9. from ray.rllib.utils.framework import try_import_tf
  10. from ray.rllib.utils.test_utils import check_learning_achieved
  11. tf1, tf, tfv = try_import_tf()
  12. parser = argparse.ArgumentParser()
  13. parser.add_argument(
  14. "--run",
  15. type=str,
  16. default="PPO",
  17. help="The RLlib-registered algorithm to use.")
  18. parser.add_argument(
  19. "--framework",
  20. choices=["tf", "tf2", "tfe", "torch"],
  21. default="tf",
  22. help="The DL framework specifier.")
  23. parser.add_argument(
  24. "--as-test",
  25. action="store_true",
  26. help="Whether this script should be run as a test: --stop-reward must "
  27. "be achieved within --stop-timesteps AND --stop-iters.")
  28. parser.add_argument(
  29. "--stop-iters",
  30. type=int,
  31. default=200,
  32. help="Number of iterations to train.")
  33. parser.add_argument(
  34. "--stop-timesteps",
  35. type=int,
  36. default=100000,
  37. help="Number of timesteps to train.")
  38. parser.add_argument(
  39. "--stop-reward",
  40. type=float,
  41. default=150.0,
  42. help="Reward at which we stop training.")
  43. if __name__ == "__main__":
  44. args = parser.parse_args()
  45. ray.init()
  46. ModelCatalog.register_custom_model(
  47. "bn_model", TorchBatchNormModel
  48. if args.framework == "torch" else KerasBatchNormModel
  49. if args.run != "PPO" else BatchNormModel)
  50. config = {
  51. "env": "Pendulum-v1" if args.run in ["DDPG", "SAC"] else "CartPole-v0",
  52. "model": {
  53. "custom_model": "bn_model",
  54. },
  55. "lr": 0.0003,
  56. # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
  57. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
  58. "num_workers": 0,
  59. "framework": args.framework,
  60. }
  61. stop = {
  62. "training_iteration": args.stop_iters,
  63. "timesteps_total": args.stop_timesteps,
  64. "episode_reward_mean": args.stop_reward,
  65. }
  66. results = tune.run(args.run, stop=stop, config=config, verbose=2)
  67. if args.as_test:
  68. check_learning_achieved(results, args.stop_reward)
  69. ray.shutdown()