registry.py 5.4 KB


  1. """Registry of algorithm names for `rllib train --run=<alg_name>`"""
import traceback
from typing import Tuple, Union

from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
from ray.rllib.utils.deprecation import Deprecated
  5. def _import_a2c():
  6. from ray.rllib.agents import a3c
  7. return a3c.A2CTrainer, a3c.a2c.A2C_DEFAULT_CONFIG
  8. def _import_a3c():
  9. from ray.rllib.agents import a3c
  10. return a3c.A3CTrainer, a3c.DEFAULT_CONFIG
  11. def _import_apex():
  12. from ray.rllib.agents import dqn
  13. return dqn.ApexTrainer, dqn.apex.APEX_DEFAULT_CONFIG
  14. def _import_apex_ddpg():
  15. from ray.rllib.agents import ddpg
  16. return ddpg.ApexDDPGTrainer, ddpg.apex.APEX_DDPG_DEFAULT_CONFIG
  17. def _import_appo():
  18. from ray.rllib.agents import ppo
  19. return ppo.APPOTrainer, ppo.appo.DEFAULT_CONFIG
  20. def _import_ars():
  21. from ray.rllib.agents import ars
  22. return ars.ARSTrainer, ars.DEFAULT_CONFIG
  23. def _import_bc():
  24. from ray.rllib.agents import marwil
  25. return marwil.BCTrainer, marwil.DEFAULT_CONFIG
  26. def _import_cql():
  27. from ray.rllib.agents import cql
  28. return cql.CQLTrainer, cql.CQL_DEFAULT_CONFIG
  29. def _import_ddpg():
  30. from ray.rllib.agents import ddpg
  31. return ddpg.DDPGTrainer, ddpg.DEFAULT_CONFIG
  32. def _import_ddppo():
  33. from ray.rllib.agents import ppo
  34. return ppo.DDPPOTrainer, ppo.DEFAULT_CONFIG
  35. def _import_dqn():
  36. from ray.rllib.agents import dqn
  37. return dqn.DQNTrainer, dqn.DEFAULT_CONFIG
  38. def _import_dreamer():
  39. from ray.rllib.agents import dreamer
  40. return dreamer.DREAMERTrainer, dreamer.DEFAULT_CONFIG
  41. def _import_es():
  42. from ray.rllib.agents import es
  43. return es.ESTrainer, es.DEFAULT_CONFIG
  44. def _import_impala():
  45. from ray.rllib.agents import impala
  46. return impala.ImpalaTrainer, impala.DEFAULT_CONFIG
  47. def _import_maml():
  48. from ray.rllib.agents import maml
  49. return maml.MAMLTrainer, maml.DEFAULT_CONFIG
  50. def _import_marwil():
  51. from ray.rllib.agents import marwil
  52. return marwil.MARWILTrainer, marwil.DEFAULT_CONFIG
  53. def _import_mbmpo():
  54. from ray.rllib.agents import mbmpo
  55. return mbmpo.MBMPOTrainer, mbmpo.DEFAULT_CONFIG
  56. def _import_pg():
  57. from ray.rllib.agents import pg
  58. return pg.PGTrainer, pg.DEFAULT_CONFIG
  59. def _import_ppo():
  60. from ray.rllib.agents import ppo
  61. return ppo.PPOTrainer, ppo.DEFAULT_CONFIG
  62. def _import_qmix():
  63. from ray.rllib.agents import qmix
  64. return qmix.QMixTrainer, qmix.DEFAULT_CONFIG
  65. def _import_r2d2():
  66. from ray.rllib.agents import dqn
  67. return dqn.R2D2Trainer, dqn.R2D2_DEFAULT_CONFIG
  68. def _import_sac():
  69. from ray.rllib.agents import sac
  70. return sac.SACTrainer, sac.DEFAULT_CONFIG
  71. def _import_rnnsac():
  72. from ray.rllib.agents import sac
  73. return sac.RNNSACTrainer, sac.RNNSAC_DEFAULT_CONFIG
  74. def _import_simple_q():
  75. from ray.rllib.agents import dqn
  76. return dqn.SimpleQTrainer, dqn.simple_q.DEFAULT_CONFIG
  77. def _import_slate_q():
  78. from ray.rllib.agents import slateq
  79. return slateq.SlateQTrainer, slateq.DEFAULT_CONFIG
  80. def _import_td3():
  81. from ray.rllib.agents import ddpg
  82. return ddpg.TD3Trainer, ddpg.td3.TD3_DEFAULT_CONFIG
  83. ALGORITHMS = {
  84. "A2C": _import_a2c,
  85. "A3C": _import_a3c,
  86. "APEX": _import_apex,
  87. "APEX_DDPG": _import_apex_ddpg,
  88. "APPO": _import_appo,
  89. "ARS": _import_ars,
  90. "BC": _import_bc,
  91. "CQL": _import_cql,
  92. "ES": _import_es,
  93. "DDPG": _import_ddpg,
  94. "DDPPO": _import_ddppo,
  95. "DQN": _import_dqn,
  96. "SlateQ": _import_slate_q,
  97. "DREAMER": _import_dreamer,
  98. "IMPALA": _import_impala,
  99. "MAML": _import_maml,
  100. "MARWIL": _import_marwil,
  101. "MBMPO": _import_mbmpo,
  102. "PG": _import_pg,
  103. "PPO": _import_ppo,
  104. "QMIX": _import_qmix,
  105. "R2D2": _import_r2d2,
  106. "SAC": _import_sac,
  107. "RNNSAC": _import_rnnsac,
  108. "SimpleQ": _import_simple_q,
  109. "TD3": _import_td3,
  110. }
  111. def get_trainer_class(alg: str, return_config=False) -> type:
  112. """Returns the class of a known Trainer given its name."""
  113. try:
  114. return _get_trainer_class(alg, return_config=return_config)
  115. except ImportError:
  116. from ray.rllib.agents.mock import _trainer_import_failed
  117. class_ = _trainer_import_failed(traceback.format_exc())
  118. config = class_.get_default_config()
  119. if return_config:
  120. return class_, config
  121. return class_
  122. @Deprecated(new="get_trainer_class", error=False)
  123. def get_agent_class(alg: str) -> type:
  124. return get_trainer_class(alg)
  125. def _get_trainer_class(alg: str, return_config=False) -> type:
  126. if alg in ALGORITHMS:
  127. class_, config = ALGORITHMS[alg]()
  128. elif alg in CONTRIBUTED_ALGORITHMS:
  129. class_, config = CONTRIBUTED_ALGORITHMS[alg]()
  130. elif alg == "script":
  131. from ray.tune import script_runner
  132. class_, config = script_runner.ScriptRunner, {}
  133. elif alg == "__fake":
  134. from ray.rllib.agents.mock import _MockTrainer
  135. class_, config = _MockTrainer, _MockTrainer.get_default_config()
  136. elif alg == "__sigmoid_fake_data":
  137. from ray.rllib.agents.mock import _SigmoidFakeData
  138. class_, config = _SigmoidFakeData, _SigmoidFakeData.get_default_config(
  139. )
  140. elif alg == "__parameter_tuning":
  141. from ray.rllib.agents.mock import _ParameterTuningTrainer
  142. class_, config = _ParameterTuningTrainer, \
  143. _ParameterTuningTrainer.get_default_config()
  144. else:
  145. raise Exception(("Unknown algorithm {}.").format(alg))
  146. if return_config:
  147. return class_, config
  148. return class_