coin_game_env.py

##########
# Contribution by the Center on Long-Term Risk:
# https://github.com/longtermrisk/marltoolbox
##########
import argparse
import os

import ray
from ray import air, tune
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.examples.env.coin_game_non_vectorized_env import CoinGame, AsymCoinGame
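
# This example trains two independent PPO policies (one per player) on the
# two-player Coin Game gridworld contributed by the Center on Long-Term Risk.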

parser = argparse.ArgumentParser()
parser.add_argument("--tf", action="store_true")
parser.add_argument("--stop-iters", type=int, default=2000)


def main(debug, stop_iters=2000, tf=False, asymmetric_env=False):
    train_n_replicates = 1
    seeds = list(range(train_n_replicates))

    ray.init()
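
    # Stop each trial after `stop_iters` training iterations (2 in debug mode).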
    stop = {
        "training_iteration": 2 if debug else stop_iters,
    }
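
    # Coin Game settings: two players on a 3x3 grid with 20-step episodes;
    # "get_additional_info" asks the env to report extra episode metrics.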
    env_config = {
        "players_ids": ["player_red", "player_blue"],
        "max_steps": 20,
        "grid_size": 3,
        "get_additional_info": True,
    }
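
    # Multi-agent PPO config: one policy per player, a small conv net
    # matched to the grid observations, and one trial per seed.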
    rllib_config = {
        "env": AsymCoinGame if asymmetric_env else CoinGame,
        "env_config": env_config,
  30. "policies": {
  31. env_config["players_ids"][0]: (
  32. None,
  33. AsymCoinGame(env_config).observation_space,
  34. AsymCoinGame.action_space,
  35. {},
  36. ),
  37. env_config["players_ids"][1]: (
  38. None,
  39. AsymCoinGame(env_config).observation_space,
  40. AsymCoinGame.action_space,
  41. {},
  42. ),
  43. },
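        # Agent ids double as policy ids, so each agent trains its own policy.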
  44. "policy_mapping_fn": lambda agent_id, episode, worker, **kwargs: agent_id,
        # Size of batches collected from each worker.
        "rollout_fragment_length": 20,
        # Number of timesteps collected for each SGD round.
        # This defines the size of each SGD epoch.
        "train_batch_size": 512,
  50. "model": {
  51. "dim": env_config["grid_size"],
  52. "conv_filters": [
  53. [16, [3, 3], 1],
  54. [32, [3, 3], 1],
  55. ], # [Channel, [Kernel, Kernel], Stride]]
  56. },
  57. "lr": 5e-3,
  58. "seed": tune.grid_search(seeds),
  59. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
  60. "framework": "tf" if tf else "torch",
  61. }
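
    # One Tune trial per seed in the grid search; checkpoint_frequency=0
    # disables intermediate checkpoints, keeping only the final one.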
    tuner = tune.Tuner(
        PPO,
        param_space=rllib_config,
        run_config=air.RunConfig(
            name="PPO_AsymCG",
            stop=stop,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=0, checkpoint_at_end=True
            ),
        ),
    )
    tuner.fit()

    ray.shutdown()


if __name__ == "__main__":
    args = parser.parse_args()
    debug_mode = True
    use_asymmetric_env = False
    main(debug_mode, args.stop_iters, args.tf, use_asymmetric_env)
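
# Note: with debug_mode=True above, each trial stops after 2 training
# iterations and --stop-iters is ignored; set debug_mode=False to train for
# the full --stop-iters (default 2000). Pass --tf to use TensorFlow instead
# of the default PyTorch, e.g.:
#   python coin_game_env.py --tf --stop-iters 100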