123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- import gymnasium as gym
- import itertools
- import numpy as np
- import tempfile
- import unittest
- import ray
- from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig
- from ray.rllib.core.testing.utils import get_learner_group
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.utils.test_utils import check
- FAKE_BATCH = {
- SampleBatch.OBS: np.array(
- [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
- dtype=np.float32,
- ),
- SampleBatch.NEXT_OBS: np.array(
- [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
- dtype=np.float32,
- ),
- SampleBatch.ACTIONS: np.array([0, 1, 1]),
- SampleBatch.PREV_ACTIONS: np.array([0, 1, 1]),
- SampleBatch.REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
- SampleBatch.PREV_REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
- SampleBatch.TERMINATEDS: np.array([False, False, True]),
- SampleBatch.TRUNCATEDS: np.array([False, False, False]),
- SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
- SampleBatch.ACTION_DIST_INPUTS: np.array(
- [[-2.0, 0.5], [-3.0, -0.3], [-0.1, 2.5]], dtype=np.float32
- ),
- SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32),
- SampleBatch.EPS_ID: np.array([0, 0, 0]),
- SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
- }
- REMOTE_SCALING_CONFIGS = {
- "remote-cpu": LearnerGroupScalingConfig(num_workers=1),
- "remote-gpu": LearnerGroupScalingConfig(num_workers=1, num_gpus_per_worker=1),
- "multi-gpu-ddp": LearnerGroupScalingConfig(num_workers=2, num_gpus_per_worker=1),
- "multi-cpu-ddp": LearnerGroupScalingConfig(num_workers=2, num_cpus_per_worker=2),
- # "multi-gpu-ddp-pipeline": LearnerGroupScalingConfig(
- # num_workers=2, num_gpus_per_worker=2
- # ),
- }
- class TestLearnerGroupCheckpointing(unittest.TestCase):
- def setUp(self) -> None:
- ray.init()
- def tearDown(self) -> None:
- ray.shutdown()
- def test_save_load_state(self):
- fws = ["tf2", "torch"]
- scaling_modes = REMOTE_SCALING_CONFIGS.keys()
- test_iterator = itertools.product(fws, scaling_modes)
- batch = SampleBatch(FAKE_BATCH)
- for fw, scaling_mode in test_iterator:
- print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.")
- env = gym.make("CartPole-v1")
- scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode]
- initial_learner_group = get_learner_group(
- fw, env, scaling_config, eager_tracing=True
- )
- # checkpoint the initial learner state for later comparison
- initial_learner_checkpoint_dir = tempfile.TemporaryDirectory().name
- initial_learner_group.save_state(initial_learner_checkpoint_dir)
- initial_learner_group_weights = initial_learner_group.get_weights()
- # do a single update
- initial_learner_group.update(batch.as_multi_agent(), reduce_fn=None)
- # checkpoint the learner state after 1 update for later comparison
- learner_after_1_update_checkpoint_dir = tempfile.TemporaryDirectory().name
- initial_learner_group.save_state(learner_after_1_update_checkpoint_dir)
- # remove that learner, construct a new one, and load the state of the old
- # learner into the new one
- initial_learner_group.shutdown()
- del initial_learner_group
- new_learner_group = get_learner_group(
- fw, env, scaling_config, eager_tracing=True
- )
- new_learner_group.load_state(learner_after_1_update_checkpoint_dir)
- # do another update
- results_with_break = new_learner_group.update(
- batch.as_multi_agent(), reduce_fn=None
- )
- weights_after_1_update_with_break = new_learner_group.get_weights()
- new_learner_group.shutdown()
- del new_learner_group
- # construct a new learner group and load the initial state of the learner
- learner_group = get_learner_group(
- fw, env, scaling_config, eager_tracing=True
- )
- learner_group.load_state(initial_learner_checkpoint_dir)
- check(learner_group.get_weights(), initial_learner_group_weights)
- learner_group.update(batch.as_multi_agent(), reduce_fn=None)
- results_without_break = learner_group.update(
- batch.as_multi_agent(), reduce_fn=None
- )
- weights_after_1_update_without_break = learner_group.get_weights()
- learner_group.shutdown()
- del learner_group
- # compare the results of the two updates
- check(results_with_break, results_without_break)
- check(
- weights_after_1_update_with_break, weights_after_1_update_without_break
- )
- if __name__ == "__main__":
- import pytest
- import sys
- sys.exit(pytest.main(["-v", __file__]))
|