rlmodule_guide.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. # flake8: noqa
  2. from ray.rllib.utils.annotations import override
  3. from ray.rllib.core.models.specs.typing import SpecType
  4. from ray.rllib.core.models.specs.specs_base import TensorSpec
  5. # __enabling-rlmodules-in-configs-begin__
  6. import torch
  7. from pprint import pprint
  8. from ray.rllib.algorithms.ppo import PPOConfig
  9. config = (
  10. PPOConfig()
  11. .framework("torch")
  12. .environment("CartPole-v1")
  13. .rl_module(_enable_rl_module_api=True)
  14. .training(_enable_learner_api=True)
  15. )
  16. algorithm = config.build()
  17. # run for 2 training steps
  18. for _ in range(2):
  19. result = algorithm.train()
  20. pprint(result)
  21. # __enabling-rlmodules-in-configs-end__
  22. # __constructing-rlmodules-sa-begin__
  23. import gymnasium as gym
  24. from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
  25. from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
  26. env = gym.make("CartPole-v1")
  27. spec = SingleAgentRLModuleSpec(
  28. module_class=DiscreteBCTorchModule,
  29. observation_space=env.observation_space,
  30. action_space=env.action_space,
  31. model_config_dict={"fcnet_hiddens": [64]},
  32. )
  33. module = spec.build()
  34. # __constructing-rlmodules-sa-end__
  35. # __constructing-rlmodules-ma-begin__
  36. import gymnasium as gym
  37. from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
  38. from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
  39. from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
  40. spec = MultiAgentRLModuleSpec(
  41. module_specs={
  42. "module_1": SingleAgentRLModuleSpec(
  43. module_class=DiscreteBCTorchModule,
  44. observation_space=gym.spaces.Box(low=-1, high=1, shape=(10,)),
  45. action_space=gym.spaces.Discrete(2),
  46. model_config_dict={"fcnet_hiddens": [32]},
  47. ),
  48. "module_2": SingleAgentRLModuleSpec(
  49. module_class=DiscreteBCTorchModule,
  50. observation_space=gym.spaces.Box(low=-1, high=1, shape=(5,)),
  51. action_space=gym.spaces.Discrete(2),
  52. model_config_dict={"fcnet_hiddens": [16]},
  53. ),
  54. },
  55. )
  56. marl_module = spec.build()
  57. # __constructing-rlmodules-ma-end__
  58. # __pass-specs-to-configs-sa-begin__
  59. import gymnasium as gym
  60. from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
  61. from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
  62. from ray.rllib.core.testing.bc_algorithm import BCConfigTest
  63. config = (
  64. BCConfigTest()
  65. .environment("CartPole-v1")
  66. .rl_module(
  67. _enable_rl_module_api=True,
  68. rl_module_spec=SingleAgentRLModuleSpec(module_class=DiscreteBCTorchModule),
  69. )
  70. .training(
  71. model={"fcnet_hiddens": [32, 32]},
  72. _enable_learner_api=True,
  73. )
  74. )
  75. algo = config.build()
  76. # __pass-specs-to-configs-sa-end__
  77. # __pass-specs-to-configs-ma-begin__
  78. import gymnasium as gym
  79. from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
  80. from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
  81. from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
  82. from ray.rllib.core.testing.bc_algorithm import BCConfigTest
  83. from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
  84. config = (
  85. BCConfigTest()
  86. .environment(MultiAgentCartPole, env_config={"num_agents": 2})
  87. .rl_module(
  88. _enable_rl_module_api=True,
  89. rl_module_spec=MultiAgentRLModuleSpec(
  90. module_specs=SingleAgentRLModuleSpec(module_class=DiscreteBCTorchModule)
  91. ),
  92. )
  93. .training(
  94. model={"fcnet_hiddens": [32, 32]},
  95. _enable_learner_api=True,
  96. )
  97. )
  98. # __pass-specs-to-configs-ma-end__
  99. # __convert-sa-to-ma-begin__
  100. import gymnasium as gym
  101. from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
  102. from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
  103. env = gym.make("CartPole-v1")
  104. spec = SingleAgentRLModuleSpec(
  105. module_class=DiscreteBCTorchModule,
  106. observation_space=env.observation_space,
  107. action_space=env.action_space,
  108. model_config_dict={"fcnet_hiddens": [64]},
  109. )
  110. module = spec.build()
  111. marl_module = module.as_multi_agent()
  112. # __convert-sa-to-ma-end__
  113. # __write-custom-sa-rlmodule-torch-begin__
  114. from typing import Mapping, Any
  115. from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
  116. from ray.rllib.core.rl_module.rl_module import RLModuleConfig
  117. from ray.rllib.utils.nested_dict import NestedDict
  118. import torch
  119. import torch.nn as nn
  120. class DiscreteBCTorchModule(TorchRLModule):
  121. def __init__(self, config: RLModuleConfig) -> None:
  122. super().__init__(config)
  123. def setup(self):
  124. input_dim = self.config.observation_space.shape[0]
  125. hidden_dim = self.config.model_config_dict["fcnet_hiddens"][0]
  126. output_dim = self.config.action_space.n
  127. self.policy = nn.Sequential(
  128. nn.Linear(input_dim, hidden_dim),
  129. nn.ReLU(),
  130. nn.Linear(hidden_dim, output_dim),
  131. )
  132. self.input_dim = input_dim
  133. def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]:
  134. with torch.no_grad():
  135. return self._forward_train(batch)
  136. def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:
  137. with torch.no_grad():
  138. return self._forward_train(batch)
  139. def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]:
  140. action_logits = self.policy(batch["obs"])
  141. return {"action_dist": torch.distributions.Categorical(logits=action_logits)}
  142. # __write-custom-sa-rlmodule-torch-end__
  143. # __write-custom-sa-rlmodule-tf-begin__
  144. from typing import Mapping, Any
  145. from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule
  146. from ray.rllib.core.rl_module.rl_module import RLModuleConfig
  147. from ray.rllib.utils.nested_dict import NestedDict
  148. import tensorflow as tf
  149. class DiscreteBCTfModule(TfRLModule):
  150. def __init__(self, config: RLModuleConfig) -> None:
  151. super().__init__(config)
  152. def setup(self):
  153. input_dim = self.config.observation_space.shape[0]
  154. hidden_dim = self.config.model_config_dict["fcnet_hiddens"][0]
  155. output_dim = self.config.action_space.n
  156. self.policy = tf.keras.Sequential(
  157. [
  158. tf.keras.layers.Dense(hidden_dim, activation="relu"),
  159. tf.keras.layers.Dense(output_dim),
  160. ]
  161. )
  162. self.input_dim = input_dim
  163. def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]:
  164. return self._forward_train(batch)
  165. def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:
  166. return self._forward_train(batch)
  167. def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]:
  168. action_logits = self.policy(batch["obs"])
  169. return {"action_dist": tf.distributions.Categorical(logits=action_logits)}
  170. # __write-custom-sa-rlmodule-tf-end__
  171. # __extend-spec-checking-single-level-begin__
  172. class DiscreteBCTorchModule(TorchRLModule):
  173. ...
  174. @override(TorchRLModule)
  175. def input_specs_exploration(self) -> SpecType:
  176. # Enforce that input nested dict to exploration method has a key "obs"
  177. return ["obs"]
  178. @override(TorchRLModule)
  179. def output_specs_exploration(self) -> SpecType:
  180. # Enforce that output nested dict from exploration method has a key
  181. # "action_dist"
  182. return ["action_dist"]
  183. # __extend-spec-checking-single-level-end__
  184. # __extend-spec-checking-nested-begin__
  185. class DiscreteBCTorchModule(TorchRLModule):
  186. ...
  187. @override(TorchRLModule)
  188. def input_specs_exploration(self) -> SpecType:
  189. # Enforce that input nested dict to exploration method has a key "obs"
  190. # and within that key, it has a key "global" and "local". There should
  191. # also be a key "action_mask"
  192. return [("obs", "global"), ("obs", "local"), "action_mask"]
  193. # __extend-spec-checking-nested-end__
  194. # __extend-spec-checking-torch-specs-begin__
  195. class DiscreteBCTorchModule(TorchRLModule):
  196. ...
  197. @override(TorchRLModule)
  198. def input_specs_exploration(self) -> SpecType:
  199. # Enforce that input nested dict to exploration method has a key "obs"
  200. # and its value is a torch.Tensor with shape (b, h) where b is the
  201. # batch size (determined at run-time) and h is the hidden size
  202. # (fixed at 10).
  203. return {"obs": TensorSpec("b, h", h=10, framework="torch")}
  204. # __extend-spec-checking-torch-specs-end__
  205. # __extend-spec-checking-type-specs-begin__
  206. class DiscreteBCTorchModule(TorchRLModule):
  207. ...
  208. @override(TorchRLModule)
  209. def output_specs_exploration(self) -> SpecType:
  210. # Enforce that output nested dict from exploration method has a key
  211. # "action_dist" and its value is a torch.distribution.Categorical
  212. return {"action_dist": torch.distribution.Categorical}
  213. # __extend-spec-checking-type-specs-end__
  214. # __write-custom-marlmodule-shared-enc-begin__
  215. from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
  216. from ray.rllib.core.rl_module.marl_module import (
  217. MultiAgentRLModuleConfig,
  218. MultiAgentRLModule,
  219. )
  220. from ray.rllib.utils.nested_dict import NestedDict
  221. import torch
  222. import torch.nn as nn
  223. class BCTorchRLModuleWithSharedGlobalEncoder(TorchRLModule):
  224. """An RLModule with a shared encoder between agents for global observation."""
  225. def __init__(
  226. self, encoder: nn.Module, local_dim: int, hidden_dim: int, action_dim: int
  227. ) -> None:
  228. super().__init__(config=None)
  229. self.encoder = encoder
  230. self.policy_head = nn.Sequential(
  231. nn.Linear(hidden_dim + local_dim, hidden_dim),
  232. nn.ReLU(),
  233. nn.Linear(hidden_dim, action_dim),
  234. )
  235. def _forward_inference(self, batch):
  236. with torch.no_grad():
  237. return self._common_forward(batch)
  238. def _forward_exploration(self, batch):
  239. with torch.no_grad():
  240. return self._common_forward(batch)
  241. def _forward_train(self, batch):
  242. return self._common_forward(batch)
  243. def _common_forward(self, batch):
  244. obs = batch["obs"]
  245. global_enc = self.encoder(obs["global"])
  246. policy_in = torch.cat([global_enc, obs["local"]], dim=-1)
  247. action_logits = self.policy_head(policy_in)
  248. return {"action_dist": torch.distributions.Categorical(logits=action_logits)}
  249. class BCTorchMultiAgentModuleWithSharedEncoder(MultiAgentRLModule):
  250. def __init__(self, config: MultiAgentRLModuleConfig) -> None:
  251. super().__init__(config)
  252. def setup(self):
  253. module_specs = self.config.modules
  254. module_spec = next(iter(module_specs.values()))
  255. global_dim = module_spec.observation_space["global"].shape[0]
  256. hidden_dim = module_spec.model_config_dict["fcnet_hiddens"][0]
  257. shared_encoder = nn.Sequential(
  258. nn.Linear(global_dim, hidden_dim),
  259. nn.ReLU(),
  260. nn.Linear(hidden_dim, hidden_dim),
  261. )
  262. rl_modules = {}
  263. for module_id, module_spec in module_specs.items():
  264. rl_modules[module_id] = BCTorchRLModuleWithSharedGlobalEncoder(
  265. encoder=shared_encoder,
  266. local_dim=module_spec.observation_space["local"].shape[0],
  267. hidden_dim=hidden_dim,
  268. action_dim=module_spec.action_space.n,
  269. )
  270. self._rl_modules = rl_modules
  271. # __write-custom-marlmodule-shared-enc-end__
  272. # __pass-custom-marlmodule-shared-enc-begin__
  273. import gymnasium as gym
  274. from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
  275. from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
  276. spec = MultiAgentRLModuleSpec(
  277. marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder,
  278. module_specs={
  279. "local_2d": SingleAgentRLModuleSpec(
  280. observation_space=gym.spaces.Dict(
  281. {
  282. "global": gym.spaces.Box(low=-1, high=1, shape=(2,)),
  283. "local": gym.spaces.Box(low=-1, high=1, shape=(2,)),
  284. }
  285. ),
  286. action_space=gym.spaces.Discrete(2),
  287. model_config_dict={"fcnet_hiddens": [64]},
  288. ),
  289. "local_5d": SingleAgentRLModuleSpec(
  290. observation_space=gym.spaces.Dict(
  291. {
  292. "global": gym.spaces.Box(low=-1, high=1, shape=(2,)),
  293. "local": gym.spaces.Box(low=-1, high=1, shape=(5,)),
  294. }
  295. ),
  296. action_space=gym.spaces.Discrete(5),
  297. model_config_dict={"fcnet_hiddens": [64]},
  298. ),
  299. },
  300. )
  301. module = spec.build()
  302. # __pass-custom-marlmodule-shared-enc-end__