import shutil
import tempfile
import unittest

import gymnasium as gym
import numpy as np
import tree

import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_module.marl_module import (
    MultiAgentRLModuleSpec,
    MultiAgentRLModule,
)
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.test_utils import check, framework_iterator

PPO_MODULES = {"tf2": PPOTfRLModule, "torch": PPOTorchRLModule}
NUM_AGENTS = 2

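# All tests below exercise the same save/load round trip; a minimal sketch,
# assuming an already built `module` and its `module_specs` dict (names as
# used in the tests below):
#
#   module.save_to_checkpoint(ckpt_dir)
#   spec = MultiAgentRLModuleSpec(
#       module_specs=module_specs, load_state_path=ckpt_dir
#   )
#   config = config.rl_module(rl_module_spec=spec, _enable_rl_module_api=True)
#   algo = config.build()  # state is restored from ckpt_dir
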
class TestAlgorithmRLModuleRestore(unittest.TestCase):
    """Test RLModule loading from an RLModule spec on a local node."""

    def setUp(self) -> None:
        ray.init()

    def tearDown(self) -> None:
        ray.shutdown()

    @staticmethod
    def get_ppo_config(num_agents=NUM_AGENTS):
        def policy_mapping_fn(agent_id, episode, worker, **kwargs):
            # Map each agent to its own policy, i.e. agent i -> "policy_i".
            return f"policy_{agent_id}"

        scaling_config = {
            "num_learner_workers": 0,
            "num_gpus_per_learner_worker": 0,
        }

        policies = {f"policy_{i}" for i in range(num_agents)}
        config = (
            PPOConfig()
            .rollouts(rollout_fragment_length=4)
            .environment(MultiAgentCartPole, env_config={"num_agents": num_agents})
            .training(
                num_sgd_iter=1,
                train_batch_size=8,
                sgd_minibatch_size=8,
                _enable_learner_api=True,
            )
            .multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
            .resources(**scaling_config)
        )
        return config
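
    # Note: `num_learner_workers=0` in the scaling config above keeps the
    # Learner on the local process (no remote learner actors), matching the
    # "local node" setup described in the class docstring.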

    def test_e2e_load_simple_marl_module(self):
        """Test training a PPO algorithm e2e from a checkpointed MARL module."""
        config = self.get_ppo_config()
        env = MultiAgentCartPole({"num_agents": NUM_AGENTS})
        for fw in framework_iterator(config, frameworks=["tf2", "torch"]):
            # Create a MARL module and save it to a checkpoint directory.
            module_specs = {}
            module_class = PPO_MODULES[fw]
            for i in range(NUM_AGENTS):
                module_specs[f"policy_{i}"] = SingleAgentRLModuleSpec(
                    module_class=module_class,
                    observation_space=env.observation_space,
                    action_space=env.action_space,
                    # Differently sized hidden layers make the modules
                    # distinguishable by their weights.
                    model_config_dict={"fcnet_hiddens": [32 * (i + 1)]},
                    catalog_class=PPOCatalog,
                )
            marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs)
            marl_module = marl_module_spec.build()
            marl_module_weights = convert_to_numpy(marl_module.get_state())
            marl_checkpoint_path = tempfile.mkdtemp()
            marl_module.save_to_checkpoint(marl_checkpoint_path)

            # Create a new MARL spec that loads from the checkpoint above.
            marl_module_spec_from_checkpoint = MultiAgentRLModuleSpec(
                module_specs=module_specs,
                load_state_path=marl_checkpoint_path,
            )
            config = config.rl_module(
                rl_module_spec=marl_module_spec_from_checkpoint,
                _enable_rl_module_api=True,
            )

            # Build the algorithm and check that its weights match the
            # original MARL module's.
            algo = config.build()
            algo_module_weights = algo.learner_group.get_weights()
            check(algo_module_weights, marl_module_weights)
            algo.train()
            algo.stop()
            del algo
            shutil.rmtree(marl_checkpoint_path)

    def test_e2e_load_complex_marl_module(self):
        """Test training a PPO algorithm e2e from checkpointed MARL and RL modules."""
        config = self.get_ppo_config()
        env = MultiAgentCartPole({"num_agents": NUM_AGENTS})
        for fw in framework_iterator(config, frameworks=["tf2", "torch"]):
            # Create a MARL module and save it to a checkpoint directory.
            module_specs = {}
            module_class = PPO_MODULES[fw]
            for i in range(NUM_AGENTS):
                module_specs[f"policy_{i}"] = SingleAgentRLModuleSpec(
                    module_class=module_class,
                    observation_space=env.observation_space,
                    action_space=env.action_space,
                    model_config_dict={"fcnet_hiddens": [32 * (i + 1)]},
                    catalog_class=PPOCatalog,
                )
            marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs)
            marl_module = marl_module_spec.build()
            marl_checkpoint_path = tempfile.mkdtemp()
            marl_module.save_to_checkpoint(marl_checkpoint_path)

            # Create a separate RLModule to override the "policy_1" module with.
            module_to_swap_in = SingleAgentRLModuleSpec(
                module_class=module_class,
                observation_space=env.observation_space,
                action_space=env.action_space,
                model_config_dict={"fcnet_hiddens": [64]},
                catalog_class=PPOCatalog,
            ).build()
            module_to_swap_in_path = tempfile.mkdtemp()
            module_to_swap_in.save_to_checkpoint(module_to_swap_in_path)

            # Create a new MARL spec that loads from the MARL checkpoint, with
            # "policy_1" loading from the swapped-in module's checkpoint
            # instead.
            module_specs["policy_1"] = SingleAgentRLModuleSpec(
                module_class=module_class,
                observation_space=env.observation_space,
                action_space=env.action_space,
                model_config_dict={"fcnet_hiddens": [64]},
                catalog_class=PPOCatalog,
                load_state_path=module_to_swap_in_path,
            )
            marl_module_spec_from_checkpoint = MultiAgentRLModuleSpec(
                module_specs=module_specs,
                load_state_path=marl_checkpoint_path,
            )
            config = config.rl_module(
                rl_module_spec=marl_module_spec_from_checkpoint,
                _enable_rl_module_api=True,
            )

            # Build the algorithm and check that its weights match the MARL
            # module with the swapped-in "policy_1".
            algo = config.build()
            algo_module_weights = algo.learner_group.get_weights()
            marl_module_with_swapped_in_module = MultiAgentRLModule()
            marl_module_with_swapped_in_module.add_module(
                "policy_0", marl_module["policy_0"]
            )
            marl_module_with_swapped_in_module.add_module(
                "policy_1", module_to_swap_in
            )
            check(
                algo_module_weights,
                convert_to_numpy(marl_module_with_swapped_in_module.get_state()),
            )
            algo.train()
            algo.stop()
            del algo
            shutil.rmtree(marl_checkpoint_path)
            shutil.rmtree(module_to_swap_in_path)

    def test_e2e_load_rl_module(self):
        """Test training a PPO algorithm e2e from a checkpointed RLModule."""
        scaling_config = {
            "num_learner_workers": 0,
            "num_gpus_per_learner_worker": 0,
        }
        config = (
            PPOConfig()
            .rollouts(rollout_fragment_length=4)
            .environment("CartPole-v1")
            .training(
                num_sgd_iter=1,
                train_batch_size=8,
                sgd_minibatch_size=8,
                _enable_learner_api=True,
            )
            .resources(**scaling_config)
        )
        env = gym.make("CartPole-v1")
        for fw in framework_iterator(config, frameworks=["tf2", "torch"]):
            # Create a single-agent RLModule and save it to a checkpoint
            # directory.
            module_class = PPO_MODULES[fw]
            module_spec = SingleAgentRLModuleSpec(
                module_class=module_class,
                observation_space=env.observation_space,
                action_space=env.action_space,
                model_config_dict={"fcnet_hiddens": [32]},
                catalog_class=PPOCatalog,
            )
            module = module_spec.build()
            module_ckpt_path = tempfile.mkdtemp()
            module.save_to_checkpoint(module_ckpt_path)

            # Create a new module spec that loads from the checkpoint above.
            module_to_load_spec = SingleAgentRLModuleSpec(
                module_class=module_class,
                observation_space=env.observation_space,
                action_space=env.action_space,
                model_config_dict={"fcnet_hiddens": [32]},
                catalog_class=PPOCatalog,
                load_state_path=module_ckpt_path,
            )
            config = config.rl_module(
                rl_module_spec=module_to_load_spec,
                _enable_rl_module_api=True,
            )

            # Build the algorithm and check that its weights match the
            # original module's.
            algo = config.build()
            algo_module_weights = algo.learner_group.get_weights()
            check(
                algo_module_weights[DEFAULT_POLICY_ID],
                convert_to_numpy(module.get_state()),
            )
            algo.train()
            algo.stop()
            del algo
            shutil.rmtree(module_ckpt_path)
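
    # `modules_to_load` (exercised in the next test) restricts which
    # sub-modules are restored from a MARL checkpoint; modules left out keep
    # their freshly initialized weights unless they specify their own
    # `load_state_path`. A minimal sketch, using the same API as below:
    #
    #   spec = MultiAgentRLModuleSpec(
    #       module_specs=module_specs,
    #       load_state_path=marl_checkpoint_path,
    #       modules_to_load={"policy_0"},  # only "policy_0" comes from the ckpt
    #   )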

    def test_e2e_load_complex_marl_module_with_modules_to_load(self):
        """Test training a PPO algorithm e2e from checkpointed MARL and RL modules.

        Additionally, check that `modules_to_load` can be used to exclude a
        module of the checkpointed MARL module from being loaded.
        """
        num_agents = 3
        config = self.get_ppo_config(num_agents=num_agents)
        env = MultiAgentCartPole({"num_agents": num_agents})
        for fw in framework_iterator(config, frameworks=["tf2", "torch"]):
            # Create a MARL module and save it to a checkpoint directory.
            module_specs = {}
            module_class = PPO_MODULES[fw]
            for i in range(num_agents):
                module_specs[f"policy_{i}"] = SingleAgentRLModuleSpec(
                    module_class=module_class,
                    observation_space=env.observation_space,
                    action_space=env.action_space,
                    model_config_dict={"fcnet_hiddens": [32 * (i + 1)]},
                    catalog_class=PPOCatalog,
                )
            marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs)
            marl_module = marl_module_spec.build()
            marl_checkpoint_path = tempfile.mkdtemp()
            marl_module.save_to_checkpoint(marl_checkpoint_path)

            # Create a separate RLModule to override the "policy_1" module with.
            module_to_swap_in = SingleAgentRLModuleSpec(
                module_class=module_class,
                observation_space=env.observation_space,
                action_space=env.action_space,
                model_config_dict={"fcnet_hiddens": [64]},
                catalog_class=PPOCatalog,
            ).build()
            module_to_swap_in_path = tempfile.mkdtemp()
            module_to_swap_in.save_to_checkpoint(module_to_swap_in_path)

            # Create a new MARL spec that loads only "policy_0" from the MARL
            # checkpoint and "policy_1" from the swapped-in module's
            # checkpoint.
            module_specs["policy_1"] = SingleAgentRLModuleSpec(
                module_class=module_class,
                observation_space=env.observation_space,
                action_space=env.action_space,
                model_config_dict={"fcnet_hiddens": [64]},
                catalog_class=PPOCatalog,
                load_state_path=module_to_swap_in_path,
            )
            marl_module_spec_from_checkpoint = MultiAgentRLModuleSpec(
                module_specs=module_specs,
                load_state_path=marl_checkpoint_path,
                modules_to_load={
                    "policy_0",
                },
            )
            config = config.rl_module(
                rl_module_spec=marl_module_spec_from_checkpoint,
                _enable_rl_module_api=True,
            )

            # Build the algorithm and check the weights of each module.
            algo = config.build()
            algo_module_weights = algo.learner_group.get_weights()

            # Weights of "policy_0" should be the same as in the loaded MARL
            # module, since it was included in `modules_to_load`.
            check(
                algo_module_weights["policy_0"],
                convert_to_numpy(marl_module["policy_0"].get_state()),
            )
            # Weights of "policy_1" should be the same as in
            # `module_to_swap_in`, since its load path was specified separately
            # in its own RLModule spec inside
            # `marl_module_spec_from_checkpoint`.
            check(
                algo_module_weights["policy_1"],
                convert_to_numpy(module_to_swap_in.get_state()),
            )
            # Weights of "policy_2" should differ from the loaded MARL module,
            # since it was not included in `modules_to_load`. Comparing the
            # sums of all weight tensors is enough to tell the two apart.
            policy_2_algo_module_weight_sum = np.sum(
                [
                    np.sum(s)
                    for s in tree.flatten(
                        convert_to_numpy(algo_module_weights["policy_2"])
                    )
                ]
            )
            policy_2_marl_module_weight_sum = np.sum(
                [
                    np.sum(s)
                    for s in tree.flatten(
                        convert_to_numpy(marl_module["policy_2"].get_state())
                    )
                ]
            )
            check(
                policy_2_algo_module_weight_sum,
                policy_2_marl_module_weight_sum,
                false=True,
            )
            algo.train()
            algo.stop()
            del algo
            shutil.rmtree(marl_checkpoint_path)
            shutil.rmtree(module_to_swap_in_path)


if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main(["-v", __file__]))