# iterated_prisoners_dilemma_env.py
  1. ##########
  2. # Contribution by the Center on Long-Term Risk:
  3. # https://github.com/longtermrisk/marltoolbox
  4. ##########
  5. import argparse
  6. import os
  7. import ray
  8. from ray import tune
  9. from ray.rllib.agents.pg import PGTrainer
  10. from ray.rllib.examples.env.matrix_sequential_social_dilemma import \
  11. IteratedPrisonersDilemma
  12. parser = argparse.ArgumentParser()
  13. parser.add_argument(
  14. "--framework",
  15. choices=["tf", "tf2", "tfe", "torch"],
  16. default="tf",
  17. help="The DL framework specifier.")
  18. parser.add_argument("--stop-iters", type=int, default=200)
  19. def main(debug, stop_iters=200, tf=False):
  20. train_n_replicates = 1 if debug else 1
  21. seeds = list(range(train_n_replicates))
  22. ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)
  23. rllib_config, stop_config = get_rllib_config(seeds, debug, stop_iters, tf)
  24. tune_analysis = tune.run(
  25. PGTrainer,
  26. config=rllib_config,
  27. stop=stop_config,
  28. checkpoint_freq=0,
  29. checkpoint_at_end=True,
  30. name="PG_IPD")
  31. ray.shutdown()
  32. return tune_analysis
  33. def get_rllib_config(seeds, debug=False, stop_iters=200, tf=False):
  34. stop_config = {
  35. "training_iteration": 2 if debug else stop_iters,
  36. }
  37. env_config = {
  38. "players_ids": ["player_row", "player_col"],
  39. "max_steps": 20,
  40. "get_additional_info": True,
  41. }
  42. rllib_config = {
  43. "env": IteratedPrisonersDilemma,
  44. "env_config": env_config,
  45. "multiagent": {
  46. "policies": {
  47. env_config["players_ids"][0]: (
  48. None, IteratedPrisonersDilemma.OBSERVATION_SPACE,
  49. IteratedPrisonersDilemma.ACTION_SPACE, {}),
  50. env_config["players_ids"][1]: (
  51. None, IteratedPrisonersDilemma.OBSERVATION_SPACE,
  52. IteratedPrisonersDilemma.ACTION_SPACE, {}),
  53. },
  54. "policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
  55. },
  56. "seed": tune.grid_search(seeds),
  57. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
  58. "framework": args.framework,
  59. }
  60. return rllib_config, stop_config
  61. if __name__ == "__main__":
  62. debug_mode = True
  63. args = parser.parse_args()
  64. main(debug_mode, args.stop_iters, args.tf)