import gym
import unittest

import ray
from ray import tune
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.examples.models.neural_computer import DNCMemory
from ray.rllib.utils.framework import try_import_torch

torch, _ = try_import_torch()


class TestDNC(unittest.TestCase):
    """Tests for the ``DNCMemory`` example model."""

    # Stop criteria handed to ``tune.run`` in ``test_dnc_learning``.
    stop = {
        "episode_reward_mean": 100.0,
        "timesteps_total": 10000000,
    }

    @classmethod
    def setUpClass(cls) -> None:
        """Start Ray once for the whole test class."""
        # ignore_reinit_error lets this run even if Ray is already up.
        ray.init(num_cpus=4, ignore_reinit_error=True)

    @classmethod
    def tearDownClass(cls) -> None:
        """Shut Ray down after all tests in the class have run."""
        ray.shutdown()

    def test_pack_unpack(self):
        """Check that ``unpack_state`` -> ``pack_state`` is a round trip.

        The initial state is given a leading batch dimension, filled with
        random values, unpacked and re-packed; the result must contain the
        same number of tensors with identical values.
        """
        d = DNCMemory(
            gym.spaces.Discrete(1), gym.spaces.Discrete(1), 1, {}, "")

        # Add batch dim
        packed_state = [m.unsqueeze(0) for m in d.get_initial_state()]
        # Randomize in place so the round trip is checked on non-trivial
        # values rather than on whatever the initial state contains.
        for m in packed_state:
            m.random_()
        # Keep copies in case pack/unpack modifies its inputs in place.
        original_packed = [m.clone() for m in packed_state]

        unpacked = d.unpack_state(packed_state)
        packed = d.pack_state(*unpacked)

        self.assertTrue(len(packed) > 0)
        self.assertEqual(len(packed), len(original_packed))

        # Lengths were asserted equal above, so zip drops nothing.
        for packed_m, original_m in zip(packed, original_packed):
            self.assertTrue(torch.all(packed_m == original_m))

    def test_dnc_learning(self):
        """Run A2C with ``DNCMemory`` as custom model on StatelessCartPole.

        The model is registered under the name ``"dnc"`` and training is
        launched through ``tune.run`` with the class-level ``stop``
        criteria. The result of the run is not inspected here.
        """
        ModelCatalog.register_custom_model("dnc", DNCMemory)

        config = {
            "env": StatelessCartPole,
            "gamma": 0.99,
            "num_envs_per_worker": 5,
            "framework": "torch",
            "num_workers": 1,
            "num_cpus_per_worker": 2.0,
            "lr": 0.01,
            "entropy_coeff": 0.0005,
            "vf_loss_coeff": 1e-5,
            "model": {
                "custom_model": "dnc",
                "max_seq_len": 64,
                "custom_model_config": {
                    "nr_cells": 10,
                    "cell_size": 8,
                },
            },
        }

        tune.run("A2C", config=config, stop=self.stop, verbose=1)


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))