multi_agent_parameter_sharing.py

from ray import tune
from ray.tune.registry import register_env
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from pettingzoo.sisl import waterworld_v0

# Based on code from github.com/parametersharingmadrl/parametersharingmadrl

if __name__ == "__main__":
    # RDQN - Rainbow DQN
    # ADQN - Apex DQN
    register_env("waterworld", lambda _: PettingZooEnv(waterworld_v0.env()))

    tune.run(
        "APEX_DDPG",
        stop={"episodes_total": 60000},
        checkpoint_freq=10,
        config={
            # Environment specific.
            "env": "waterworld",
            # General
            "num_gpus": 1,
            "num_workers": 2,
            "num_envs_per_worker": 8,
            "learning_starts": 1000,
            "buffer_size": int(1e5),
            "compress_observations": True,
            "rollout_fragment_length": 20,
            "train_batch_size": 512,
            "gamma": 0.99,
            "n_step": 3,
            "lr": 0.0001,
            "prioritized_replay_alpha": 0.5,
            "final_prioritized_replay_beta": 1.0,
            "target_network_update_freq": 50000,
            "timesteps_per_iteration": 25000,
            # Method specific.
            "multiagent": {
                # We only have one policy (calling it "shared_policy").
                # Class, obs/act-spaces, and config will be derived
                # automatically.
                "policies": {"shared_policy"},
                # Always map every agent to the shared policy.
                "policy_mapping_fn": (
                    lambda agent_id, episode, **kwargs: "shared_policy"),
            },
        },
    )
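
Since the tune.run call above writes checkpoints (checkpoint_freq=10), a trained run can be loaded back and the single shared policy queried for actions. Below is a minimal evaluation sketch, assuming the Ray 1.x agents API (ray.rllib.agents.ddpg.ApexDDPGTrainer) that matches the config style of this example and the PettingZoo AEC interface of the waterworld_v0 era; checkpoint_path is a hypothetical placeholder for a checkpoint directory produced by training.

import ray
from ray.tune.registry import register_env
from ray.rllib.agents.ddpg import ApexDDPGTrainer
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from pettingzoo.sisl import waterworld_v0

ray.init()
register_env("waterworld", lambda _: PettingZooEnv(waterworld_v0.env()))

checkpoint_path = "/path/to/checkpoint"  # hypothetical: fill in from a run

# Rebuild a trainer with the same multiagent setup used for training,
# then load the trained weights from the checkpoint.
trainer = ApexDDPGTrainer(
    env="waterworld",
    config={
        "num_workers": 1,  # APEX requires at least one rollout worker
        "multiagent": {
            "policies": {"shared_policy"},
            "policy_mapping_fn": (
                lambda agent_id, episode, **kwargs: "shared_policy"),
        },
    },
)
trainer.restore(checkpoint_path)

# Every agent maps to the same policy, so one compute_action call with
# policy_id="shared_policy" serves all of them.
env = waterworld_v0.env()
env.reset()
agent = env.agents[0]
action = trainer.compute_action(env.observe(agent),
                                policy_id="shared_policy")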