# test_algorithm_rl_module_restore.py

import gymnasium as gym
import numpy as np
import shutil
import tempfile
import tree
import unittest

import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_module.marl_module import (
    MultiAgentRLModuleSpec,
    MultiAgentRLModule,
)
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.test_utils import check, framework_iterator
from ray.rllib.utils.numpy import convert_to_numpy
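
# framework_iterator yields the framework strings "tf2" and "torch"; map each to
# the matching PPO RLModule class so every test can build modules for the
# framework currently under test.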
PPO_MODULES = {"tf2": PPOTfRLModule, "torch": PPOTorchRLModule}
NUM_AGENTS = 2


class TestAlgorithmRLModuleRestore(unittest.TestCase):
    """Test RLModule loading from an RLModule spec across a local node."""

    def setUp(self) -> None:
        ray.init()

    def tearDown(self) -> None:
        ray.shutdown()
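
    # Helper shared by the tests below: builds a small multi-agent PPO config
    # (one policy per agent, learner API enabled, local learner only).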
    @staticmethod
    def get_ppo_config(num_agents=NUM_AGENTS):
        def policy_mapping_fn(agent_id, episode, worker, **kwargs):
            # policy_id is policy_i where i is the agent id
            pol_id = f"policy_{agent_id}"
            return pol_id

        scaling_config = {
            "num_learner_workers": 0,
            "num_gpus_per_learner_worker": 0,
        }

        policies = {f"policy_{i}" for i in range(num_agents)}
        config = (
            PPOConfig()
            .rollouts(rollout_fragment_length=4)
            .environment(MultiAgentCartPole, env_config={"num_agents": num_agents})
            .training(num_sgd_iter=1, train_batch_size=8, sgd_minibatch_size=8)
            .multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
            .training(_enable_learner_api=True)
            .resources(**scaling_config)
        )
        return config

    def test_e2e_load_simple_marl_module(self):
        """Test if we can train a PPO algorithm with a checkpointed MARL module e2e."""
        config = self.get_ppo_config()
        env = MultiAgentCartPole({"num_agents": NUM_AGENTS})
        for fw in framework_iterator(config, frameworks=["tf2", "torch"]):
            # create a marl_module to load and save it to a checkpoint directory
            module_specs = {}
            module_class = PPO_MODULES[fw]
            for i in range(NUM_AGENTS):
                module_specs[f"policy_{i}"] = SingleAgentRLModuleSpec(
                    module_class=module_class,
                    observation_space=env.observation_space,
                    action_space=env.action_space,
                    model_config_dict={"fcnet_hiddens": [32 * (i + 1)]},
                    catalog_class=PPOCatalog,
                )
            marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs)
            marl_module = marl_module_spec.build()
            marl_module_weights = convert_to_numpy(marl_module.get_state())
            marl_checkpoint_path = tempfile.mkdtemp()
            marl_module.save_to_checkpoint(marl_checkpoint_path)

            # create a new MARL spec that loads its state from the checkpoint above
            marl_module_spec_from_checkpoint = MultiAgentRLModuleSpec(
                module_specs=module_specs,
                load_state_path=marl_checkpoint_path,
            )
            config = config.rl_module(
                rl_module_spec=marl_module_spec_from_checkpoint,
                _enable_rl_module_api=True,
            )

            # build the algorithm and check that its weights match the original
            # MARL module's
            algo = config.build()
            algo_module_weights = algo.learner_group.get_weights()
            check(algo_module_weights, marl_module_weights)
            algo.train()
            algo.stop()
            del algo
            shutil.rmtree(marl_checkpoint_path)

    def test_e2e_load_complex_marl_module(self):
        """Test if we can train a PPO algorithm with a ckpt MARL and RL module e2e."""
        config = self.get_ppo_config()
        env = MultiAgentCartPole({"num_agents": NUM_AGENTS})
        for fw in framework_iterator(config, frameworks=["tf2", "torch"]):
            # create a marl_module to load and save it to a checkpoint directory
            module_specs = {}
            module_class = PPO_MODULES[fw]
            for i in range(NUM_AGENTS):
                module_specs[f"policy_{i}"] = SingleAgentRLModuleSpec(
                    module_class=module_class,
                    observation_space=env.observation_space,
                    action_space=env.action_space,
                    model_config_dict={"fcnet_hiddens": [32 * (i + 1)]},
                    catalog_class=PPOCatalog,
                )
            marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs)
            marl_module = marl_module_spec.build()
            marl_checkpoint_path = tempfile.mkdtemp()
            marl_module.save_to_checkpoint(marl_checkpoint_path)

            # create an RLModule to swap in for the "policy_1" module and
            # checkpoint it as well
            module_to_swap_in = SingleAgentRLModuleSpec(
                module_class=module_class,
                observation_space=env.observation_space,
                action_space=env.action_space,
                model_config_dict={"fcnet_hiddens": [64]},
                catalog_class=PPOCatalog,
            ).build()
            module_to_swap_in_path = tempfile.mkdtemp()
            module_to_swap_in.save_to_checkpoint(module_to_swap_in_path)

            # create a new MARL spec with the checkpoint from the marl_checkpoint
            # and the module_to_swap_in_checkpoint
            module_specs["policy_1"] = SingleAgentRLModuleSpec(
                module_class=module_class,
                observation_space=env.observation_space,
                action_space=env.action_space,
                model_config_dict={"fcnet_hiddens": [64]},
                catalog_class=PPOCatalog,
                load_state_path=module_to_swap_in_path,
            )
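            # The MARL spec below restores all modules from marl_checkpoint_path,
            # except "policy_1", whose per-module load_state_path above takes
            # precedence and restores it from module_to_swap_in_path instead.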
            marl_module_spec_from_checkpoint = MultiAgentRLModuleSpec(
                module_specs=module_specs,
                load_state_path=marl_checkpoint_path,
            )
            config = config.rl_module(
                rl_module_spec=marl_module_spec_from_checkpoint,
                _enable_rl_module_api=True,
            )

            # build the algorithm and check that its weights match the originals
            # (policy_0 from the MARL checkpoint, policy_1 from the swapped-in
            # module's checkpoint)
            algo = config.build()
            algo_module_weights = algo.learner_group.get_weights()
            marl_module_with_swapped_in_module = MultiAgentRLModule()
            marl_module_with_swapped_in_module.add_module(
                "policy_0", marl_module["policy_0"]
            )
            marl_module_with_swapped_in_module.add_module("policy_1", module_to_swap_in)
            check(
                algo_module_weights,
                convert_to_numpy(marl_module_with_swapped_in_module.get_state()),
            )
            algo.train()
            algo.stop()
            del algo
            shutil.rmtree(marl_checkpoint_path)

    def test_e2e_load_rl_module(self):
        """Test if we can train a PPO algorithm with a ckpt RL module e2e."""
        scaling_config = {
            "num_learner_workers": 0,
            "num_gpus_per_learner_worker": 0,
        }
        config = (
            PPOConfig()
            .rollouts(rollout_fragment_length=4)
            .environment("CartPole-v1")
            .training(num_sgd_iter=1, train_batch_size=8, sgd_minibatch_size=8)
            .training(_enable_learner_api=True)
            .resources(**scaling_config)
        )
        env = gym.make("CartPole-v1")
        for fw in framework_iterator(config, frameworks=["tf2", "torch"]):
            # create an RLModule to load and save it to a checkpoint directory
            module_class = PPO_MODULES[fw]
            module_spec = SingleAgentRLModuleSpec(
                module_class=module_class,
                observation_space=env.observation_space,
                action_space=env.action_space,
                model_config_dict={"fcnet_hiddens": [32]},
                catalog_class=PPOCatalog,
            )
            module = module_spec.build()
            module_ckpt_path = tempfile.mkdtemp()
            module.save_to_checkpoint(module_ckpt_path)
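            # create a second spec that points at the checkpoint above via
            # load_state_path, so the built algorithm restores the saved weights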
            module_to_load_spec = SingleAgentRLModuleSpec(
                module_class=module_class,
                observation_space=env.observation_space,
                action_space=env.action_space,
                model_config_dict={"fcnet_hiddens": [32]},
                catalog_class=PPOCatalog,
                load_state_path=module_ckpt_path,
            )
            config = config.rl_module(
                rl_module_spec=module_to_load_spec,
                _enable_rl_module_api=True,
            )

            # build the algorithm and check that its weights match the original
            # RL module's
            algo = config.build()
            algo_module_weights = algo.learner_group.get_weights()
            check(
                algo_module_weights[DEFAULT_POLICY_ID],
                convert_to_numpy(module.get_state()),
            )
            algo.train()
            algo.stop()
            del algo
            shutil.rmtree(module_ckpt_path)

    def test_e2e_load_complex_marl_module_with_modules_to_load(self):
        """Test if we can train a PPO algorithm with a ckpt MARL and RL module e2e.

        Additionally, check that `modules_to_load` lets us exclude one of the
        checkpointed MARL module's sub-modules from being loaded.
        """
        num_agents = 3
        config = self.get_ppo_config(num_agents=num_agents)
        env = MultiAgentCartPole({"num_agents": num_agents})
        for fw in framework_iterator(config, frameworks=["tf2", "torch"]):
            # create a marl_module to load and save it to a checkpoint directory
            module_specs = {}
            module_class = PPO_MODULES[fw]
            for i in range(num_agents):
                module_specs[f"policy_{i}"] = SingleAgentRLModuleSpec(
                    module_class=module_class,
                    observation_space=env.observation_space,
                    action_space=env.action_space,
                    model_config_dict={"fcnet_hiddens": [32 * (i + 1)]},
                    catalog_class=PPOCatalog,
                )
            marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs)
            marl_module = marl_module_spec.build()
            marl_checkpoint_path = tempfile.mkdtemp()
            marl_module.save_to_checkpoint(marl_checkpoint_path)

            # create an RLModule to swap in for the "policy_1" module and
            # checkpoint it as well
            module_to_swap_in = SingleAgentRLModuleSpec(
                module_class=module_class,
                observation_space=env.observation_space,
                action_space=env.action_space,
                model_config_dict={"fcnet_hiddens": [64]},
                catalog_class=PPOCatalog,
            ).build()
            module_to_swap_in_path = tempfile.mkdtemp()
            module_to_swap_in.save_to_checkpoint(module_to_swap_in_path)

            # create a new MARL spec with the checkpoint from the marl_checkpoint
            # and the module_to_swap_in_checkpoint
            module_specs["policy_1"] = SingleAgentRLModuleSpec(
                module_class=module_class,
                observation_space=env.observation_space,
                action_space=env.action_space,
                model_config_dict={"fcnet_hiddens": [64]},
                catalog_class=PPOCatalog,
                load_state_path=module_to_swap_in_path,
            )
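            # modules_to_load={"policy_0"} means only "policy_0" is restored from
            # the MARL checkpoint; "policy_1" is restored from its own
            # load_state_path above, and "policy_2" keeps its freshly initialized
            # weights.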
            marl_module_spec_from_checkpoint = MultiAgentRLModuleSpec(
                module_specs=module_specs,
                load_state_path=marl_checkpoint_path,
                modules_to_load={
                    "policy_0",
                },
            )
            config = config.rl_module(
                rl_module_spec=marl_module_spec_from_checkpoint,
                _enable_rl_module_api=True,
            )

            # build the algorithm and check the weights of its modules against
            # the originals
            algo = config.build()
            algo_module_weights = algo.learner_group.get_weights()

            # weights of "policy_0" should be the same as in the loaded marl module
            # since we specified it as part of the modules_to_load
            check(
                algo_module_weights["policy_0"],
                convert_to_numpy(marl_module["policy_0"].get_state()),
            )
            # weights of "policy_1" should be the same as in the module_to_swap_in
            # since we specified its load path separately in an rl_module_spec inside
            # the marl_module_spec_from_checkpoint
            check(
                algo_module_weights["policy_1"],
                convert_to_numpy(module_to_swap_in.get_state()),
            )
            # weights of "policy_2" should be different from the loaded marl module
            # since we didn't specify it as part of the modules_to_load
            policy_2_algo_module_weight_sum = np.sum(
                [
                    np.sum(s)
                    for s in tree.flatten(
                        convert_to_numpy(algo_module_weights["policy_2"])
                    )
                ]
            )
            policy_2_marl_module_weight_sum = np.sum(
                [
                    np.sum(s)
                    for s in tree.flatten(
                        convert_to_numpy(marl_module["policy_2"].get_state())
                    )
                ]
            )
            check(
                policy_2_algo_module_weight_sum,
                policy_2_marl_module_weight_sum,
                false=True,
            )
            algo.train()
            algo.stop()
            del algo
            shutil.rmtree(marl_checkpoint_path)
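

# Run this file directly to execute the tests via pytest, e.g.
# `python test_algorithm_rl_module_restore.py` or
# `pytest -v test_algorithm_rl_module_restore.py`.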
if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))