#!/usr/bin/env python3
"""Example Algorithm for RLLIB + SUMO Utils.

Author: Lara CODECA lara.codeca@gmail.com
See:
https://github.com/lcodeca/rllibsumoutils
https://github.com/lcodeca/rllibsumodocker
for further details.
"""
import argparse
from copy import deepcopy
import logging
import os
import pathlib
from pprint import pformat

import ray
from ray import air, tune
from ray.rllib.algorithms.ppo import ppo
from ray.rllib.examples.simulators.sumo import marlenvironment
from ray.rllib.utils.test_utils import check_learning_achieved

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger("ppotrain")
  22. parser = argparse.ArgumentParser()
  23. parser.add_argument(
  24. "--sumo-connect-lib",
  25. type=str,
  26. default="libsumo",
  27. choices=["libsumo", "traci"],
  28. help="The SUMO connector to import. Requires the env variable SUMO_HOME set.",
  29. )
  30. parser.add_argument(
  31. "--sumo-gui",
  32. action="store_true",
  33. help="Enables the SUMO GUI. Possible only with TraCI connector.",
  34. )
  35. parser.add_argument(
  36. "--sumo-config-file",
  37. type=str,
  38. default=None,
  39. help="The SUMO configuration file for the scenario.",
  40. )
  41. parser.add_argument(
  42. "--from-checkpoint",
  43. type=str,
  44. default=None,
  45. help="Full path to a checkpoint file for restoring a previously saved "
  46. "Algorithm state.",
  47. )
  48. parser.add_argument("--num-workers", type=int, default=0)
  49. parser.add_argument(
  50. "--as-test",
  51. action="store_true",
  52. help="Whether this script should be run as a test: --stop-reward must "
  53. "be achieved within --stop-timesteps AND --stop-iters.",
  54. )
  55. parser.add_argument(
  56. "--stop-iters", type=int, default=10, help="Number of iterations to train."
  57. )
  58. parser.add_argument(
  59. "--stop-timesteps", type=int, default=1000000, help="Number of timesteps to train."
  60. )
  61. parser.add_argument(
  62. "--stop-reward",
  63. type=float,
  64. default=30000.0,
  65. help="Reward at which we stop training.",
  66. )
  67. if __name__ == "__main__":
  68. args = parser.parse_args()
  69. ray.init()
  70. tune.register_env("sumo_test_env", marlenvironment.env_creator)
  71. # Algorithm.
  72. policy_class = ppo.PPOTF1Policy
  73. config = (
  74. ppo.PPOConfig()
  75. .framework("tf")
  76. .rollouts(
  77. batch_mode="complete_episodes",
  78. num_rollout_workers=args.num_workers,
  79. )
  80. .training(
  81. gamma=0.99,
  82. lambda_=0.95,
  83. lr=0.001,
  84. sgd_minibatch_size=256,
  85. train_batch_size=4000,
  86. )
  87. .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
  88. .reporting(min_time_s_per_iteration=5)
  89. )
  90. # Load default Scenario configuration for the LEARNING ENVIRONMENT
  91. scenario_config = deepcopy(marlenvironment.DEFAULT_SCENARIO_CONFING)
  92. scenario_config["seed"] = 42
  93. scenario_config["log_level"] = "INFO"
  94. scenario_config["sumo_config"]["sumo_connector"] = args.sumo_connect_lib
  95. scenario_config["sumo_config"]["sumo_gui"] = args.sumo_gui
  96. if args.sumo_config_file is not None:
  97. scenario_config["sumo_config"]["sumo_cfg"] = args.sumo_config_file
  98. else:
  99. filename = "{}/simulators/sumo/scenario/sumo.cfg.xml".format(
  100. pathlib.Path(__file__).parent.absolute()
  101. )
  102. scenario_config["sumo_config"]["sumo_cfg"] = filename
  103. scenario_config["sumo_config"]["sumo_params"] = ["--collision.action", "warn"]
  104. scenario_config["sumo_config"]["trace_file"] = True
  105. scenario_config["sumo_config"]["end_of_sim"] = 3600 # [s]
  106. scenario_config["sumo_config"][
  107. "update_freq"
  108. ] = 10 # number of traci.simulationStep()
  109. # for each learning step.
  110. scenario_config["sumo_config"]["log_level"] = "INFO"
  111. logger.info("Scenario Configuration: \n %s", pformat(scenario_config))
  112. # Associate the agents with their configuration.
  113. agent_init = {
  114. "agent_0": deepcopy(marlenvironment.DEFAULT_AGENT_CONFING),
  115. "agent_1": deepcopy(marlenvironment.DEFAULT_AGENT_CONFING),
  116. }
  117. logger.info("Agents Configuration: \n %s", pformat(agent_init))
  118. # MARL Environment Init
  119. env_config = {
  120. "agent_init": agent_init,
  121. "scenario_config": scenario_config,
  122. }
  123. marl_env = marlenvironment.SUMOTestMultiAgentEnv(env_config)
  124. # Config for PPO from the MARLEnv.
  125. policies = {}
  126. for agent in marl_env.get_agents():
  127. agent_policy_params = {}
  128. policies[agent] = (
  129. policy_class,
  130. marl_env.get_obs_space(agent),
  131. marl_env.get_action_space(agent),
  132. agent_policy_params,
  133. )
  134. config.multi_agent(
  135. policies=policies,
  136. policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: agent_id,
  137. policies_to_train=["ppo_policy"],
  138. )
  139. config.environment("sumo_test_env", env_config=env_config)
  140. logger.info("PPO Configuration: \n %s", pformat(config.to_dict()))
  141. stop = {
  142. "training_iteration": args.stop_iters,
  143. "timesteps_total": args.stop_timesteps,
  144. "episode_reward_mean": args.stop_reward,
  145. }
  146. # Run the experiment.
  147. results = tune.Tuner(
  148. "PPO",
  149. param_space=config,
  150. run_config=air.RunConfig(
  151. stop=stop,
  152. verbose=1,
  153. checkpoint_config=air.CheckpointConfig(
  154. checkpoint_frequency=10,
  155. ),
  156. ),
  157. ).fit()
  158. # And check the results.
  159. if args.as_test:
  160. check_learning_achieved(results, args.stop_reward)
  161. ray.shutdown()