two_step_game.py
  1. """The two-step game from QMIX: https://arxiv.org/pdf/1803.11485.pdf
  2. Configurations you can try:
  3. - normal policy gradients (PG)
  4. - contrib/MADDPG
  5. - QMIX
  6. See also: centralized_critic.py for centralized critic PPO on this game.
  7. """

import argparse
from gym.spaces import Dict, Discrete, Tuple, MultiDiscrete
import os

import ray
from ray import tune
from ray.tune import register_env, grid_search
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.test_utils import check_learning_achieved

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="PG",
    help="The RLlib-registered algorithm to use.")
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=200,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=50000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=7.0,
    help="Reward at which we stop training.")
parser.add_argument(
    "--local-mode",
    action="store_true",
    help="Init Ray in local mode for easier debugging.")

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)
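
    # Group both agents of the two-step game into a single "super-agent".
    # QMIX in RLlib trains on such grouped agents, whose joint observation
    # combines each agent's own "obs" with the shared ENV_STATE.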
    grouping = {
        "group_1": [0, 1],
    }
    obs_space = Tuple([
        Dict({
            "obs": MultiDiscrete([2, 2, 2, 3]),
            ENV_STATE: MultiDiscrete([2, 2, 2])
        }),
        Dict({
            "obs": MultiDiscrete([2, 2, 2, 3]),
            ENV_STATE: MultiDiscrete([2, 2, 2])
        }),
    ])
    act_space = Tuple([
        TwoStepGame.action_space,
        TwoStepGame.action_space,
    ])
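
    # Register the grouped version of the TwoStepGame under the name
    # "grouped_twostep"; with_agent_groups() maps the two agents into the
    # single group defined above.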
    register_env(
        "grouped_twostep",
        lambda config: TwoStepGame(config).with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space))
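
    # contrib/MADDPG runs on the flat (ungrouped) env: each agent gets its
    # own policy over a simple Discrete(6) observation space, and actions
    # are passed to the env as logits.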
    if args.run == "contrib/MADDPG":
        obs_space = Discrete(6)
        act_space = TwoStepGame.action_space
        config = {
            "learning_starts": 100,
            "env_config": {
                "actions_are_logits": True,
            },
            "multiagent": {
                "policies": {
                    "pol1": PolicySpec(
                        observation_space=obs_space,
                        action_space=act_space,
                        config={"agent_id": 0}),
                    "pol2": PolicySpec(
                        observation_space=obs_space,
                        action_space=act_space,
                        config={"agent_id": 1}),
                },
                "policy_mapping_fn": (
                    lambda aid, **kwargs: "pol2" if aid else "pol1"),
            },
            "framework": args.framework,
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        }
        group = False
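
    # QMIX trains on the grouped env; the grid search compares training
    # without a mixing network (None) against the QMIX mixer.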
    elif args.run == "QMIX":
        config = {
            "rollout_fragment_length": 4,
            "train_batch_size": 32,
            "exploration_config": {
                "epsilon_timesteps": 5000,
                "final_epsilon": 0.05,
            },
            "num_workers": 0,
            "mixer": grid_search([None, "qmix"]),
            "env_config": {
                "separate_state_space": True,
                "one_hot_state_encoding": True
            },
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        }
        group = True
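
    # Any other registered algorithm (e.g. plain PG) trains directly on the
    # ungrouped TwoStepGame with its default settings.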
    else:
        config = {
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "framework": args.framework,
        }
        group = False

    stop = {
        "episode_reward_mean": args.stop_reward,
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }

    config = dict(config, **{
        "env": "grouped_twostep" if group else TwoStepGame,
    })
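
    # When running as a test, fix the seed so the run is reproducible.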
    if args.as_test:
        config["seed"] = 1234

    results = tune.run(args.run, stop=stop, config=config, verbose=2)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)

    ray.shutdown()