nested_action_spaces.py

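# This example trains PPO (by default) on an env whose action space is an
# arbitrarily nested gym space (Dicts, Tuples, Boxes, Discretes), showing
# that RLlib handles complex/nested spaces out of the box.
#
# Example invocation (all flags are defined below):
#   python nested_action_spaces.py --framework=torch --stop-iters=50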
import argparse
from gym.spaces import Dict, Tuple, Box, Discrete
import os

import ray
import ray.tune as tune
from ray.tune.registry import register_env
from ray.rllib.examples.env.nested_space_repeat_after_me_env import \
    NestedSpaceRepeatAfterMeEnv
from ray.rllib.utils.test_utils import check_learning_achieved

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="PPO",
    help="The RLlib-registered algorithm to use.")
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
    "--local-mode",
    action="store_true",
    help="Init Ray in local mode for easier debugging.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=100,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=0.0,
    help="Reward at which we stop training.")

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)
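
    # Register the example env under a fixed name so the config below can
    # refer to it via "env". The lambda receives the "env_config" dict and
    # passes it on to the env constructor.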
    register_env("NestedSpaceRepeatAfterMeEnv",
                 lambda c: NestedSpaceRepeatAfterMeEnv(c))
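
    # The "space" key defines the nested structure (a Dict holding a Tuple of
    # Dicts, a Box, and a Discrete) that the env uses for its observations
    # and actions: the agent is rewarded for "repeating" back what it just
    # observed, so the policy must handle the nesting end-to-end.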
    config = {
        "env": "NestedSpaceRepeatAfterMeEnv",
        "env_config": {
            "space": Dict({
                "a": Tuple(
                    [Dict({
                        "d": Box(-10.0, 10.0, ()),
                        "e": Discrete(2)
                    })]),
                "b": Box(-10.0, 10.0, (2, )),
                "c": Discrete(4)
            }),
        },
        "entropy_coeff": 0.00005,  # We don't want high entropy in this Env.
        "gamma": 0.0,  # No history in Env (bandit problem).
        "lr": 0.0005,
        "num_envs_per_worker": 20,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_sgd_iter": 4,
        "num_workers": 0,
        "vf_loss_coeff": 0.01,
        "framework": args.framework,
    }
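
    # Tune stops the run as soon as any one of these criteria is met.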
    stop = {
        "training_iteration": args.stop_iters,
        "episode_reward_mean": args.stop_reward,
        "timesteps_total": args.stop_timesteps,
    }
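
    # Launch the training run via Tune; `--run` selects the RLlib algorithm
    # (PPO by default).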
    results = tune.run(args.run, config=config, stop=stop, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
    ray.shutdown()