marl_module.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582
  1. from dataclasses import dataclass, field
  2. import pathlib
  3. import pprint
  4. from typing import Any, Dict, KeysView, Mapping, Optional, Set, Type, Union
  5. from ray.util.annotations import PublicAPI
  6. from ray.rllib.utils.annotations import override, ExperimentalAPI
  7. from ray.rllib.utils.nested_dict import NestedDict
  8. from ray.rllib.core.models.specs.typing import SpecType
  9. from ray.rllib.policy.sample_batch import MultiAgentBatch
  10. from ray.rllib.core.rl_module.rl_module import (
  11. RLModule,
  12. RLMODULE_METADATA_FILE_NAME,
  13. RLMODULE_STATE_DIR_NAME,
  14. SingleAgentRLModuleSpec,
  15. )
  16. # TODO (Kourosh): change this to module_id later to enforce consistency
  17. from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
  18. from ray.rllib.utils.policy import validate_policy_id
  19. from ray.rllib.utils.serialization import serialize_type, deserialize_type
# Identifier of a single RLModule inside a MultiAgentRLModule; a plain string alias
# used as the key type of the module mappings throughout this file.
ModuleID = str
  21. @PublicAPI(stability="alpha")
  22. class MultiAgentRLModule(RLModule):
  23. """Base class for multi-agent RLModules.
  24. This class holds a mapping from module_ids to the underlying RLModules. It provides
  25. a convenient way of accessing each individual module, as well as accessing all of
  26. them with only one API call. Whether or not a given module is trainable is
  27. determined by the caller of this class (not the instance of this class itself).
  28. The extension of this class can include any arbitrary neural networks as part of
  29. the multi-agent module. For example, a multi-agent module can include a shared
  30. encoder network that is used by all the individual RLModules. It is up to the user
  31. to decide how to implement this class.
  32. The default implementation assumes the data communicated as input and output of
  33. the APIs in this class are `MultiAgentBatch` types. The `MultiAgentRLModule` simply
  34. loops through each `module_id`, and runs the forward pass of the corresponding
  35. `RLModule` object with the associated `SampleBatch` within the `MultiAgentBatch`.
  36. It also assumes that the underlying RLModules do not share any parameters or
  37. communication with one another. The behavior of modules with such advanced
  38. communication would be undefined by default. To share parameters or communication
  39. between the underlying RLModules, you should implement your own
  40. `MultiAgentRLModule` subclass.
  41. """
  42. def __init__(self, config: Optional["MultiAgentRLModuleConfig"] = None) -> None:
  43. """Initializes a MultiagentRLModule instance.
  44. Args:
  45. config: The MultiAgentRLModuleConfig to use.
  46. """
  47. super().__init__(config or MultiAgentRLModuleConfig())
  48. def setup(self):
  49. """Sets up the underlying RLModules."""
  50. self._rl_modules = {}
  51. self.__check_module_configs(self.config.modules)
  52. for module_id, module_spec in self.config.modules.items():
  53. self._rl_modules[module_id] = module_spec.build()
  54. @classmethod
  55. def __check_module_configs(cls, module_configs: Dict[ModuleID, Any]):
  56. """Checks the module configs for validity.
  57. The module_configs be a mapping from module_ids to SingleAgentRLModuleSpec
  58. objects.
  59. Args:
  60. module_configs: The module configs to check.
  61. Raises:
  62. ValueError: If the module configs are invalid.
  63. """
  64. for module_id, module_spec in module_configs.items():
  65. if not isinstance(module_spec, SingleAgentRLModuleSpec):
  66. raise ValueError(
  67. f"Module {module_id} is not a SingleAgentRLModuleSpec object."
  68. )
  69. def keys(self) -> KeysView[ModuleID]:
  70. """Returns a keys view over the module IDs in this MultiAgentRLModule."""
  71. return self._rl_modules.keys()
  72. @override(RLModule)
  73. def as_multi_agent(self) -> "MultiAgentRLModule":
  74. """Returns a multi-agent wrapper around this module.
  75. This method is overridden to avoid double wrapping.
  76. Returns:
  77. The instance itself.
  78. """
  79. return self
  80. def add_module(
  81. self,
  82. module_id: ModuleID,
  83. module: RLModule,
  84. *,
  85. override: bool = False,
  86. ) -> None:
  87. """Adds a module at run time to the multi-agent module.
  88. Args:
  89. module_id: The module ID to add. If the module ID already exists and
  90. override is False, an error is raised. If override is True, the module
  91. is replaced.
  92. module: The module to add.
  93. override: Whether to override the module if it already exists.
  94. Raises:
  95. ValueError: If the module ID already exists and override is False.
  96. Warnings are raised if the module id is not valid according to the logic of
  97. validate_policy_id().
  98. """
  99. validate_policy_id(module_id)
  100. if module_id in self._rl_modules and not override:
  101. raise ValueError(
  102. f"Module ID {module_id} already exists. If your intention is to "
  103. "override, set override=True."
  104. )
  105. self._rl_modules[module_id] = module
  106. def remove_module(
  107. self, module_id: ModuleID, *, raise_err_if_not_found: bool = True
  108. ) -> None:
  109. """Removes a module at run time from the multi-agent module.
  110. Args:
  111. module_id: The module ID to remove.
  112. raise_err_if_not_found: Whether to raise an error if the module ID is not
  113. found.
  114. Raises:
  115. ValueError: If the module ID does not exist and raise_err_if_not_found is
  116. True.
  117. """
  118. if raise_err_if_not_found:
  119. self._check_module_exists(module_id)
  120. del self._rl_modules[module_id]
  121. def __getitem__(self, module_id: ModuleID) -> RLModule:
  122. """Returns the module with the given module ID.
  123. Args:
  124. module_id: The module ID to get.
  125. Returns:
  126. The module with the given module ID.
  127. """
  128. self._check_module_exists(module_id)
  129. return self._rl_modules[module_id]
  130. @override(RLModule)
  131. def output_specs_train(self) -> SpecType:
  132. return []
  133. @override(RLModule)
  134. def output_specs_inference(self) -> SpecType:
  135. return []
  136. @override(RLModule)
  137. def output_specs_exploration(self) -> SpecType:
  138. return []
  139. @override(RLModule)
  140. def _default_input_specs(self) -> SpecType:
  141. """Multi-agent RLModule should not check the input specs.
  142. The underlying single-agent RLModules will check the input specs.
  143. """
  144. return []
  145. @override(RLModule)
  146. def _forward_train(
  147. self, batch: MultiAgentBatch, **kwargs
  148. ) -> Union[Mapping[str, Any], Dict[ModuleID, Mapping[str, Any]]]:
  149. """Runs the forward_train pass.
  150. TODO(avnishn, kourosh): Review type hints for forward methods.
  151. Args:
  152. batch: The batch of multi-agent data (i.e. mapping from module ids to
  153. SampleBaches).
  154. Returns:
  155. The output of the forward_train pass the specified modules.
  156. """
  157. return self._run_forward_pass("forward_train", batch, **kwargs)
  158. @override(RLModule)
  159. def _forward_inference(
  160. self, batch: MultiAgentBatch, **kwargs
  161. ) -> Union[Mapping[str, Any], Dict[ModuleID, Mapping[str, Any]]]:
  162. """Runs the forward_inference pass.
  163. TODO(avnishn, kourosh): Review type hints for forward methods.
  164. Args:
  165. batch: The batch of multi-agent data (i.e. mapping from module ids to
  166. SampleBaches).
  167. Returns:
  168. The output of the forward_inference pass the specified modules.
  169. """
  170. return self._run_forward_pass("forward_inference", batch, **kwargs)
  171. @override(RLModule)
  172. def _forward_exploration(
  173. self, batch: MultiAgentBatch, **kwargs
  174. ) -> Union[Mapping[str, Any], Dict[ModuleID, Mapping[str, Any]]]:
  175. """Runs the forward_exploration pass.
  176. TODO(avnishn, kourosh): Review type hints for forward methods.
  177. Args:
  178. batch: The batch of multi-agent data (i.e. mapping from module ids to
  179. SampleBaches).
  180. Returns:
  181. The output of the forward_exploration pass the specified modules.
  182. """
  183. return self._run_forward_pass("forward_exploration", batch, **kwargs)
  184. @override(RLModule)
  185. def get_state(
  186. self, module_ids: Optional[Set[ModuleID]] = None
  187. ) -> Mapping[ModuleID, Any]:
  188. """Returns the state of the multi-agent module.
  189. This method returns the state of each module specified by module_ids. If
  190. module_ids is None, the state of all modules is returned.
  191. Args:
  192. module_ids: The module IDs to get the state of. If None, the state of all
  193. modules is returned.
  194. Returns:
  195. A nested state dict with the first layer being the module ID and the second
  196. is the state of the module. The returned dict values are framework-specific
  197. tensors.
  198. """
  199. if module_ids is None:
  200. module_ids = self._rl_modules.keys()
  201. return {
  202. module_id: self._rl_modules[module_id].get_state()
  203. for module_id in module_ids
  204. }
  205. @override(RLModule)
  206. def set_state(self, state_dict: Mapping[ModuleID, Any]) -> None:
  207. """Sets the state of the multi-agent module.
  208. It is assumed that the state_dict is a mapping from module IDs to their
  209. corressponding state. This method sets the state of each module by calling
  210. their set_state method. If you want to set the state of some of the RLModules
  211. within this MultiAgentRLModule your state_dict can only include the state of
  212. those RLModules. Override this method to customize the state_dict for custom
  213. more advanced multi-agent use cases.
  214. Args:
  215. state_dict: The state dict to set.
  216. """
  217. for module_id, state in state_dict.items():
  218. self._rl_modules[module_id].set_state(state)
  219. @override(RLModule)
  220. def save_state(self, path: Union[str, pathlib.Path]) -> None:
  221. """Saves the weights of this MultiAgentRLModule to dir.
  222. Args:
  223. path: The path to the directory to save the checkpoint to.
  224. """
  225. path = pathlib.Path(path)
  226. path.mkdir(parents=True, exist_ok=True)
  227. for module_id, module in self._rl_modules.items():
  228. module.save_to_checkpoint(str(path / module_id))
  229. @override(RLModule)
  230. def load_state(
  231. self,
  232. path: Union[str, pathlib.Path],
  233. modules_to_load: Optional[Set[ModuleID]] = None,
  234. ) -> None:
  235. """Loads the weights of an MultiAgentRLModule from dir.
  236. NOTE:
  237. If you want to load a module that is not already
  238. in this MultiAgentRLModule, you should add it to this MultiAgentRLModule
  239. before loading the checkpoint.
  240. Args:
  241. path: The path to the directory to load the state from.
  242. modules_to_load: The modules whose state is to be loaded from the path. If
  243. this is None, all modules that are checkpointed will be loaded into this
  244. marl module.
  245. """
  246. path = pathlib.Path(path)
  247. if not modules_to_load:
  248. modules_to_load = set(self._rl_modules.keys())
  249. path.mkdir(parents=True, exist_ok=True)
  250. for submodule_id in modules_to_load:
  251. if submodule_id not in self._rl_modules:
  252. raise ValueError(
  253. f"Module {submodule_id} from `modules_to_load`: "
  254. f"{modules_to_load} not found in this MultiAgentRLModule."
  255. )
  256. submodule = self._rl_modules[submodule_id]
  257. submodule_weights_dir = path / submodule_id / RLMODULE_STATE_DIR_NAME
  258. if not submodule_weights_dir.exists():
  259. raise ValueError(
  260. f"Submodule {submodule_id}'s module state directory: "
  261. f"{submodule_weights_dir} not found in checkpoint dir {path}."
  262. )
  263. submodule.load_state(submodule_weights_dir)
  264. @override(RLModule)
  265. def save_to_checkpoint(self, checkpoint_dir_path: Union[str, pathlib.Path]) -> None:
  266. path = pathlib.Path(checkpoint_dir_path)
  267. path.mkdir(parents=True, exist_ok=True)
  268. self.save_state(path)
  269. self._save_module_metadata(path, MultiAgentRLModuleSpec)
  270. @classmethod
  271. @override(RLModule)
  272. def from_checkpoint(cls, checkpoint_dir_path: Union[str, pathlib.Path]) -> None:
  273. path = pathlib.Path(checkpoint_dir_path)
  274. metadata_path = path / RLMODULE_METADATA_FILE_NAME
  275. marl_module = cls._from_metadata_file(metadata_path)
  276. marl_module.load_state(path)
  277. return marl_module
  278. def __repr__(self) -> str:
  279. return f"MARL({pprint.pformat(self._rl_modules)})"
  280. def _run_forward_pass(
  281. self,
  282. forward_fn_name: str,
  283. batch: NestedDict[Any],
  284. **kwargs,
  285. ) -> Dict[ModuleID, Mapping[ModuleID, Any]]:
  286. """This is a helper method that runs the forward pass for the given module.
  287. It uses forward_fn_name to get the forward pass method from the RLModule
  288. (e.g. forward_train vs. forward_exploration) and runs it on the given batch.
  289. Args:
  290. forward_fn_name: The name of the forward pass method to run.
  291. batch: The batch of multi-agent data (i.e. mapping from module ids to
  292. SampleBaches).
  293. **kwargs: Additional keyword arguments to pass to the forward function.
  294. Returns:
  295. The output of the forward pass the specified modules. The output is a
  296. mapping from module ID to the output of the forward pass.
  297. """
  298. module_ids = list(batch.shallow_keys())
  299. for module_id in module_ids:
  300. self._check_module_exists(module_id)
  301. outputs = {}
  302. for module_id in module_ids:
  303. rl_module = self._rl_modules[module_id]
  304. forward_fn = getattr(rl_module, forward_fn_name)
  305. outputs[module_id] = forward_fn(batch[module_id], **kwargs)
  306. return outputs
  307. def _check_module_exists(self, module_id: ModuleID) -> None:
  308. if module_id not in self._rl_modules:
  309. raise KeyError(
  310. f"Module with module_id {module_id} not found. "
  311. f"Available modules: {set(self.keys())}"
  312. )
  313. @PublicAPI(stability="alpha")
  314. @dataclass
  315. class MultiAgentRLModuleSpec:
  316. """A utility spec class to make it constructing MARL modules easier.
  317. Users can extend this class to modify the behavior of base class. For example to
  318. share neural networks across the modules, the build method can be overriden to
  319. create the shared module first and then pass it to custom module classes that would
  320. then use it as a shared module.
  321. Args:
  322. marl_module_class: The class of the multi-agent RLModule to construct. By
  323. default it is set to MultiAgentRLModule class. This class simply loops
  324. throught each module and calls their foward methods.
  325. module_specs: The module specs for each individual module. It can be either a
  326. SingleAgentRLModuleSpec used for all module_ids or a dictionary mapping
  327. from module IDs to SingleAgentRLModuleSpecs for each individual module.
  328. load_state_path: The path to the module state to load from. NOTE: This must be
  329. an absolute path. NOTE: If the load_state_path of this spec is set, and
  330. the load_state_path of one of the SingleAgentRLModuleSpecs' is also set,
  331. the weights of that RL Module will be loaded from the path specified in
  332. the SingleAgentRLModuleSpec. This is useful if you want to load the weights
  333. of a MARL module and also manually load the weights of some of the RL
  334. modules within that MARL module from other checkpoints.
  335. modules_to_load: A set of module ids to load from the checkpoint. This is
  336. only used if load_state_path is set. If this is None, all modules are
  337. loaded.
  338. """
  339. marl_module_class: Type[MultiAgentRLModule] = MultiAgentRLModule
  340. module_specs: Union[
  341. SingleAgentRLModuleSpec, Dict[ModuleID, SingleAgentRLModuleSpec]
  342. ] = None
  343. load_state_path: Optional[str] = None
  344. modules_to_load: Optional[Set[ModuleID]] = None
  345. def __post_init__(self):
  346. if self.module_specs is None:
  347. raise ValueError(
  348. "Module_specs cannot be None. It should be either a "
  349. "SingleAgentRLModuleSpec or a dictionary mapping from module IDs to "
  350. "SingleAgentRLModuleSpecs for each individual module."
  351. )
  352. def get_marl_config(self) -> "MultiAgentRLModuleConfig":
  353. """Returns the MultiAgentRLModuleConfig for this spec."""
  354. return MultiAgentRLModuleConfig(modules=self.module_specs)
  355. @OverrideToImplementCustomLogic
  356. def build(
  357. self, module_id: Optional[ModuleID] = None
  358. ) -> Union[SingleAgentRLModuleSpec, "MultiAgentRLModule"]:
  359. """Builds either the multi-agent module or the single-agent module.
  360. If module_id is None, it builds the multi-agent module. Otherwise, it builds
  361. the single-agent module with the given module_id.
  362. Note: If when build is called the module_specs is not a dictionary, it will
  363. raise an error, since it should have been updated by the caller to inform us
  364. about the module_ids.
  365. Args:
  366. module_id: The module_id of the single-agent module to build. If None, it
  367. builds the multi-agent module.
  368. Returns:
  369. The built module. If module_id is None, it returns the multi-agent module.
  370. """
  371. self._check_before_build()
  372. if module_id:
  373. return self.module_specs[module_id].build()
  374. module_config = self.get_marl_config()
  375. module = self.marl_module_class(module_config)
  376. return module
  377. def add_modules(
  378. self,
  379. module_specs: Dict[ModuleID, SingleAgentRLModuleSpec],
  380. overwrite: bool = True,
  381. ) -> None:
  382. """Add new module specs to the spec or updates existing ones.
  383. Args:
  384. module_specs: The mapping for the module_id to the single-agent module
  385. specs to be added to this multi-agent module spec.
  386. overwrite: Whether to overwrite the existing module specs if they already
  387. exist. If False, they will be updated only.
  388. """
  389. if self.module_specs is None:
  390. self.module_specs = {}
  391. for module_id, module_spec in module_specs.items():
  392. if overwrite or module_id not in self.module_specs:
  393. self.module_specs[module_id] = module_spec
  394. else:
  395. self.module_specs[module_id].update(module_spec)
  396. @classmethod
  397. def from_module(self, module: MultiAgentRLModule) -> "MultiAgentRLModuleSpec":
  398. """Creates a MultiAgentRLModuleSpec from a MultiAgentRLModule.
  399. Args:
  400. module: The MultiAgentRLModule to create the spec from.
  401. Returns:
  402. The MultiAgentRLModuleSpec.
  403. """
  404. # we want to get the spec of the underlying unwrapped module that way we can
  405. # easily reconstruct it. The only wrappers that we expect to support today are
  406. # wrappers that allow us to do distributed training. Those will be added back
  407. # by the learner if necessary.
  408. module_specs = {
  409. module_id: SingleAgentRLModuleSpec.from_module(rl_module.unwrapped())
  410. for module_id, rl_module in module._rl_modules.items()
  411. }
  412. marl_module_class = module.__class__
  413. return MultiAgentRLModuleSpec(
  414. marl_module_class=marl_module_class, module_specs=module_specs
  415. )
  416. def _check_before_build(self):
  417. if not isinstance(self.module_specs, dict):
  418. raise ValueError(
  419. f"When build() is called on {self.__class__}, the module_specs "
  420. "should be a dictionary mapping from module IDs to "
  421. "SingleAgentRLModuleSpecs for each individual module."
  422. )
  423. def to_dict(self) -> Dict[str, Any]:
  424. """Converts the MultiAgentRLModuleSpec to a dictionary."""
  425. return {
  426. "marl_module_class": serialize_type(self.marl_module_class),
  427. "module_specs": {
  428. module_id: module_spec.to_dict()
  429. for module_id, module_spec in self.module_specs.items()
  430. },
  431. }
  432. @classmethod
  433. def from_dict(cls, d) -> "MultiAgentRLModuleSpec":
  434. """Creates a MultiAgentRLModuleSpec from a dictionary."""
  435. return MultiAgentRLModuleSpec(
  436. marl_module_class=deserialize_type(d["marl_module_class"]),
  437. module_specs={
  438. module_id: SingleAgentRLModuleSpec.from_dict(module_spec)
  439. for module_id, module_spec in d["module_specs"].items()
  440. },
  441. )
  442. def update(self, other: "MultiAgentRLModuleSpec", overwrite=False) -> None:
  443. """Updates this spec with the other spec.
  444. Traverses this MultiAgentRLModuleSpec's module_specs and updates them with
  445. the module specs from the other MultiAgentRLModuleSpec.
  446. Args:
  447. other: The other spec to update this spec with.
  448. overwrite: Whether to overwrite the existing module specs if they already
  449. exist. If False, they will be updated only.
  450. """
  451. assert type(other) is MultiAgentRLModuleSpec
  452. if isinstance(other.module_specs, dict):
  453. self.add_modules(other.module_specs, overwrite=overwrite)
  454. else:
  455. if not self.module_specs:
  456. self.module_specs = other.module_specs
  457. else:
  458. self.module_specs.update(other.module_specs)
  459. @ExperimentalAPI
  460. @dataclass
  461. class MultiAgentRLModuleConfig:
  462. modules: Mapping[ModuleID, SingleAgentRLModuleSpec] = field(default_factory=dict)
  463. def to_dict(self):
  464. return {
  465. "modules": {
  466. module_id: module_spec.to_dict()
  467. for module_id, module_spec in self.modules.items()
  468. }
  469. }
  470. @classmethod
  471. def from_dict(cls, d) -> "MultiAgentRLModuleConfig":
  472. return cls(
  473. modules={
  474. module_id: SingleAgentRLModuleSpec.from_dict(module_spec)
  475. for module_id, module_spec in d["modules"].items()
  476. }
  477. )