registry.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. """Registry of algorithm names for `rllib train --run=<alg_name>`"""
  2. import importlib
  3. import re
  4. import traceback
  5. from typing import Tuple, Type, TYPE_CHECKING, Union
  6. from ray.rllib.utils.deprecation import Deprecated
  7. if TYPE_CHECKING:
  8. from ray.rllib.algorithms.algorithm import Algorithm
  9. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  10. def _import_a2c():
  11. import ray.rllib.algorithms.a2c as a2c
  12. return a2c.A2C, a2c.A2C.get_default_config()
  13. def _import_a3c():
  14. import ray.rllib.algorithms.a3c as a3c
  15. return a3c.A3C, a3c.A3C.get_default_config()
  16. def _import_alpha_star():
  17. import ray.rllib.algorithms.alpha_star as alpha_star
  18. return alpha_star.AlphaStar, alpha_star.AlphaStar.get_default_config()
  19. def _import_alpha_zero():
  20. import ray.rllib.algorithms.alpha_zero as alpha_zero
  21. return alpha_zero.AlphaZero, alpha_zero.AlphaZero.get_default_config()
  22. def _import_apex():
  23. import ray.rllib.algorithms.apex_dqn as apex_dqn
  24. return apex_dqn.ApexDQN, apex_dqn.ApexDQN.get_default_config()
  25. def _import_apex_ddpg():
  26. import ray.rllib.algorithms.apex_ddpg as apex_ddpg
  27. return apex_ddpg.ApexDDPG, apex_ddpg.ApexDDPG.get_default_config()
  28. def _import_appo():
  29. import ray.rllib.algorithms.appo as appo
  30. return appo.APPO, appo.APPO.get_default_config()
  31. def _import_ars():
  32. import ray.rllib.algorithms.ars as ars
  33. return ars.ARS, ars.ARS.get_default_config()
  34. def _import_bandit_lints():
  35. from ray.rllib.algorithms.bandit.bandit import BanditLinTS
  36. return BanditLinTS, BanditLinTS.get_default_config()
  37. def _import_bandit_linucb():
  38. from ray.rllib.algorithms.bandit.bandit import BanditLinUCB
  39. return BanditLinUCB, BanditLinUCB.get_default_config()
  40. def _import_bc():
  41. import ray.rllib.algorithms.bc as bc
  42. return bc.BC, bc.BC.get_default_config()
  43. def _import_cql():
  44. import ray.rllib.algorithms.cql as cql
  45. return cql.CQL, cql.CQL.get_default_config()
  46. def _import_crr():
  47. from ray.rllib.algorithms import crr
  48. return crr.CRR, crr.CRR.get_default_config()
  49. def _import_ddpg():
  50. import ray.rllib.algorithms.ddpg as ddpg
  51. return ddpg.DDPG, ddpg.DDPG.get_default_config()
  52. def _import_ddppo():
  53. import ray.rllib.algorithms.ddppo as ddppo
  54. return ddppo.DDPPO, ddppo.DDPPO.get_default_config()
  55. def _import_dqn():
  56. import ray.rllib.algorithms.dqn as dqn
  57. return dqn.DQN, dqn.DQN.get_default_config()
  58. def _import_dreamer():
  59. import ray.rllib.algorithms.dreamer as dreamer
  60. return dreamer.Dreamer, dreamer.Dreamer.get_default_config()
  61. def _import_dreamerv3():
  62. import ray.rllib.algorithms.dreamerv3 as dreamerv3
  63. return dreamerv3.DreamerV3, dreamerv3.DreamerV3.get_default_config()
  64. def _import_dt():
  65. import ray.rllib.algorithms.dt as dt
  66. return dt.DT, dt.DT.get_default_config()
  67. def _import_es():
  68. import ray.rllib.algorithms.es as es
  69. return es.ES, es.ES.get_default_config()
  70. def _import_impala():
  71. import ray.rllib.algorithms.impala as impala
  72. return impala.Impala, impala.Impala.get_default_config()
  73. def _import_maddpg():
  74. import ray.rllib.algorithms.maddpg as maddpg
  75. return maddpg.MADDPG, maddpg.MADDPG.get_default_config()
  76. def _import_maml():
  77. import ray.rllib.algorithms.maml as maml
  78. return maml.MAML, maml.MAML.get_default_config()
  79. def _import_marwil():
  80. import ray.rllib.algorithms.marwil as marwil
  81. return marwil.MARWIL, marwil.MARWIL.get_default_config()
  82. def _import_mbmpo():
  83. import ray.rllib.algorithms.mbmpo as mbmpo
  84. return mbmpo.MBMPO, mbmpo.MBMPO.get_default_config()
  85. def _import_pg():
  86. import ray.rllib.algorithms.pg as pg
  87. return pg.PG, pg.PG.get_default_config()
  88. def _import_ppo():
  89. import ray.rllib.algorithms.ppo as ppo
  90. return ppo.PPO, ppo.PPO.get_default_config()
  91. def _import_qmix():
  92. import ray.rllib.algorithms.qmix as qmix
  93. return qmix.QMix, qmix.QMix.get_default_config()
  94. def _import_r2d2():
  95. import ray.rllib.algorithms.r2d2 as r2d2
  96. return r2d2.R2D2, r2d2.R2D2.get_default_config()
  97. def _import_random_agent():
  98. import ray.rllib.algorithms.random_agent as random_agent
  99. return random_agent.RandomAgent, random_agent.RandomAgent.get_default_config()
  100. def _import_rnnsac():
  101. from ray.rllib.algorithms import sac
  102. return sac.RNNSAC, sac.RNNSAC.get_default_config()
  103. def _import_sac():
  104. import ray.rllib.algorithms.sac as sac
  105. return sac.SAC, sac.SAC.get_default_config()
  106. def _import_simple_q():
  107. import ray.rllib.algorithms.simple_q as simple_q
  108. return simple_q.SimpleQ, simple_q.SimpleQ.get_default_config()
  109. def _import_slate_q():
  110. import ray.rllib.algorithms.slateq as slateq
  111. return slateq.SlateQ, slateq.SlateQ.get_default_config()
  112. def _import_td3():
  113. import ray.rllib.algorithms.td3 as td3
  114. return td3.TD3, td3.TD3.get_default_config()
  115. def _import_leela_chess_zero():
  116. import ray.rllib.algorithms.leela_chess_zero as lc0
  117. return lc0.LeelaChessZero, lc0.LeelaChessZero.get_default_config()
# Registry of algorithm names (as accepted by `rllib train --run=<alg_name>`)
# mapped to zero-argument loader functions. Each loader imports its algorithm
# module only when called and returns an `(algorithm class, default config)`
# tuple, so building this dict does not import any algorithm code.
# Some keys differ from the class name they load (e.g. "APEX" -> ApexDQN,
# "IMPALA" -> Impala, "QMIX" -> QMix, "Random" -> RandomAgent);
# ALGORITHMS_CLASS_TO_NAME holds the reverse mapping.
ALGORITHMS = {
    "A2C": _import_a2c,
    "A3C": _import_a3c,
    "AlphaZero": _import_alpha_zero,
    "APEX": _import_apex,
    "APEX_DDPG": _import_apex_ddpg,
    "ARS": _import_ars,
    "BanditLinTS": _import_bandit_lints,
    "BanditLinUCB": _import_bandit_linucb,
    "BC": _import_bc,
    "CQL": _import_cql,
    "CRR": _import_crr,
    "ES": _import_es,
    "DDPG": _import_ddpg,
    "DDPPO": _import_ddppo,
    "DQN": _import_dqn,
    "Dreamer": _import_dreamer,
    "DreamerV3": _import_dreamerv3,
    "DT": _import_dt,
    "IMPALA": _import_impala,
    "APPO": _import_appo,
    "AlphaStar": _import_alpha_star,
    "MADDPG": _import_maddpg,
    "MAML": _import_maml,
    "MARWIL": _import_marwil,
    "MBMPO": _import_mbmpo,
    "PG": _import_pg,
    "PPO": _import_ppo,
    "QMIX": _import_qmix,
    "R2D2": _import_r2d2,
    "Random": _import_random_agent,
    "RNNSAC": _import_rnnsac,
    "SAC": _import_sac,
    "SimpleQ": _import_simple_q,
    "SlateQ": _import_slate_q,
    "TD3": _import_td3,
    "LeelaChessZero": _import_leela_chess_zero,
}
# Reverse lookup: algorithm class name -> registry key in ALGORITHMS.
# Keys are the class attribute names loaded by the `_import_*` functions
# (e.g. "ApexDQN", "Impala", "QMix", "RandomAgent"); values are the matching
# ALGORITHMS keys ("APEX", "IMPALA", "QMIX", "Random").
# NOTE(review): nothing derives this from ALGORITHMS, so it must be kept in
# sync by hand whenever an algorithm is added or renamed.
ALGORITHMS_CLASS_TO_NAME = {
    "A2C": "A2C",
    "A3C": "A3C",
    "AlphaZero": "AlphaZero",
    "ApexDQN": "APEX",
    "ApexDDPG": "APEX_DDPG",
    "ARS": "ARS",
    "BanditLinTS": "BanditLinTS",
    "BanditLinUCB": "BanditLinUCB",
    "BC": "BC",
    "CQL": "CQL",
    "CRR": "CRR",
    "ES": "ES",
    "DDPG": "DDPG",
    "DDPPO": "DDPPO",
    "DQN": "DQN",
    "Dreamer": "Dreamer",
    "DreamerV3": "DreamerV3",
    "DT": "DT",
    "Impala": "IMPALA",
    "APPO": "APPO",
    "AlphaStar": "AlphaStar",
    "MADDPG": "MADDPG",
    "MAML": "MAML",
    "MARWIL": "MARWIL",
    "MBMPO": "MBMPO",
    "PG": "PG",
    "PPO": "PPO",
    "QMix": "QMIX",
    "R2D2": "R2D2",
    "RandomAgent": "Random",
    "RNNSAC": "RNNSAC",
    "SAC": "SAC",
    "SimpleQ": "SimpleQ",
    "SlateQ": "SlateQ",
    "TD3": "TD3",
    "LeelaChessZero": "LeelaChessZero",
}
  194. @Deprecated(
  195. new="ray.tune.registry.get_trainable_cls([algo name], return_config=False) and cls="
  196. "ray.tune.registry.get_trainable_cls([algo name]); cls.get_default_config();",
  197. error=False,
  198. )
  199. def get_algorithm_class(
  200. alg: str,
  201. return_config=False,
  202. ) -> Union[Type["Algorithm"], Tuple[Type["Algorithm"], "AlgorithmConfig"]]:
  203. """Returns the class of a known Algorithm given its name."""
  204. try:
  205. return _get_algorithm_class(alg, return_config=return_config)
  206. except ImportError:
  207. from ray.rllib.algorithms.mock import _algorithm_import_failed
  208. class_ = _algorithm_import_failed(traceback.format_exc())
  209. config = class_.get_default_config()
  210. if return_config:
  211. return class_, config
  212. return class_
  213. def _get_algorithm_class(alg: str) -> type:
  214. # This helps us get around a circular import (tune calls rllib._register_all when
  215. # checking if a rllib Trainable is registered)
  216. if alg in ALGORITHMS:
  217. return ALGORITHMS[alg]()[0]
  218. elif alg == "script":
  219. from ray.tune import script_runner
  220. return script_runner.ScriptRunner
  221. elif alg == "__fake":
  222. from ray.rllib.algorithms.mock import _MockTrainer
  223. return _MockTrainer
  224. elif alg == "__sigmoid_fake_data":
  225. from ray.rllib.algorithms.mock import _SigmoidFakeData
  226. return _SigmoidFakeData
  227. elif alg == "__parameter_tuning":
  228. from ray.rllib.algorithms.mock import _ParameterTuningTrainer
  229. return _ParameterTuningTrainer
  230. else:
  231. raise Exception("Unknown algorithm {}.".format(alg))
# Mapping from policy class name to the module that defines it, written as a
# dotted path relative to `ray.rllib.algorithms`: get_policy_class imports
# "ray.rllib.algorithms." + path and fetches the attribute named by the key.
# TF1 and TF2 variants of a policy map to the same module.
# TODO(jungong) : Finish migrating all the policies to PolicyV2, so we can list
# all the TF eager policies here.
POLICIES = {
    "A3CTF1Policy": "a3c.a3c_tf_policy",
    "A3CTF2Policy": "a3c.a3c_tf_policy",
    "A3CTorchPolicy": "a3c.a3c_torch_policy",
    "AlphaZeroPolicy": "alpha_zero.alpha_zero_policy",
    "APPOTF1Policy": "appo.appo_tf_policy",
    "APPOTF2Policy": "appo.appo_tf_policy",
    "APPOTorchPolicy": "appo.appo_torch_policy",
    "ARSTFPolicy": "ars.ars_tf_policy",
    "ARSTorchPolicy": "ars.ars_torch_policy",
    "BanditTFPolicy": "bandit.bandit_tf_policy",
    "BanditTorchPolicy": "bandit.bandit_torch_policy",
    "CQLTFPolicy": "cql.cql_tf_policy",
    "CQLTorchPolicy": "cql.cql_torch_policy",
    "CRRTorchPolicy": "crr.torch.crr_torch_policy",
    "DDPGTF1Policy": "ddpg.ddpg_tf_policy",
    "DDPGTF2Policy": "ddpg.ddpg_tf_policy",
    "DDPGTorchPolicy": "ddpg.ddpg_torch_policy",
    "DQNTFPolicy": "dqn.dqn_tf_policy",
    "DQNTorchPolicy": "dqn.dqn_torch_policy",
    "DreamerTorchPolicy": "dreamer.dreamer_torch_policy",
    "DTTorchPolicy": "dt.dt_torch_policy",
    "ESTFPolicy": "es.es_tf_policy",
    "ESTorchPolicy": "es.es_torch_policy",
    "ImpalaTF1Policy": "impala.impala_tf_policy",
    "ImpalaTF2Policy": "impala.impala_tf_policy",
    "ImpalaTorchPolicy": "impala.impala_torch_policy",
    "MADDPGTFPolicy": "maddpg.maddpg_tf_policy",
    "MAMLTF1Policy": "maml.maml_tf_policy",
    "MAMLTF2Policy": "maml.maml_tf_policy",
    "MAMLTorchPolicy": "maml.maml_torch_policy",
    "MARWILTF1Policy": "marwil.marwil_tf_policy",
    "MARWILTF2Policy": "marwil.marwil_tf_policy",
    "MARWILTorchPolicy": "marwil.marwil_torch_policy",
    "MBMPOTorchPolicy": "mbmpo.mbmpo_torch_policy",
    "PGTF1Policy": "pg.pg_tf_policy",
    "PGTF2Policy": "pg.pg_tf_policy",
    "PGTorchPolicy": "pg.pg_torch_policy",
    "QMixTorchPolicy": "qmix.qmix_policy",
    "R2D2TFPolicy": "r2d2.r2d2_tf_policy",
    "R2D2TorchPolicy": "r2d2.r2d2_torch_policy",
    "SACTFPolicy": "sac.sac_tf_policy",
    "SACTorchPolicy": "sac.sac_torch_policy",
    "RNNSACTorchPolicy": "sac.rnnsac_torch_policy",
    "SimpleQTF1Policy": "simple_q.simple_q_tf_policy",
    "SimpleQTF2Policy": "simple_q.simple_q_tf_policy",
    "SimpleQTorchPolicy": "simple_q.simple_q_torch_policy",
    "SlateQTFPolicy": "slateq.slateq_tf_policy",
    "SlateQTorchPolicy": "slateq.slateq_torch_policy",
    "PPOTF1Policy": "ppo.ppo_tf_policy",
    "PPOTF2Policy": "ppo.ppo_tf_policy",
    "PPOTorchPolicy": "ppo.ppo_torch_policy",
}
  288. def get_policy_class_name(policy_class: type):
  289. """Returns a string name for the provided policy class.
  290. Args:
  291. policy_class: RLlib policy class, e.g. A3CTorchPolicy, DQNTFPolicy, etc.
  292. Returns:
  293. A string name uniquely mapped to the given policy class.
  294. """
  295. # TF2 policy classes may get automatically converted into new class types
  296. # that have eager tracing capability.
  297. # These policy classes have the "_traced" postfix in their names.
  298. # When checkpointing these policy classes, we should save the name of the
  299. # original policy class instead. So that users have the choice of turning
  300. # on eager tracing during inference time.
  301. name = re.sub("_traced$", "", policy_class.__name__)
  302. if name in POLICIES:
  303. return name
  304. return None
  305. def get_policy_class(name: str):
  306. """Return an actual policy class given the string name.
  307. Args:
  308. name: string name of the policy class.
  309. Returns:
  310. Actual policy class for the given name.
  311. """
  312. if name not in POLICIES:
  313. return None
  314. path = POLICIES[name]
  315. module = importlib.import_module("ray.rllib.algorithms." + path)
  316. if not hasattr(module, name):
  317. return None
  318. return getattr(module, name)