12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- import gym
- import unittest
- import ray
- from ray import tune
- from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
- from ray.rllib.models.catalog import ModelCatalog
- from ray.rllib.examples.models.neural_computer import DNCMemory
- from ray.rllib.utils.framework import try_import_torch
# ``try_import_torch`` returns a pair; only its first element (the torch
# module) is used here, the second is discarded.
(torch, _) = try_import_torch()
class TestDNC(unittest.TestCase):
    """Tests for the ``DNCMemory`` custom RLlib model (torch framework).

    Covers the recurrent-state pack/unpack round-trip and an end-to-end A2C
    learning run on ``StatelessCartPole``.
    """

    # Tune stopping criteria for the learning test: the trial ends as soon as
    # EITHER the mean episode reward reaches 100 OR 10M env timesteps elapse.
    stop = {
        "episode_reward_mean": 100.0,
        "timesteps_total": 10000000,
    }

    @classmethod
    def setUpClass(cls) -> None:
        # ignore_reinit_error: do not fail if Ray was already initialized.
        ray.init(num_cpus=4, ignore_reinit_error=True)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_pack_unpack(self):
        """``pack_state(*unpack_state(s))`` must reproduce ``s`` exactly."""
        d = DNCMemory(
            gym.spaces.Discrete(1), gym.spaces.Discrete(1), 1, {}, "")
        # Add a leading batch dim to each initial-state tensor and fill it
        # with random values so the round-trip is checked on non-trivial data.
        packed_state = [m.unsqueeze(0) for m in d.get_initial_state()]
        for m in packed_state:
            m.random_()
        # Snapshot before unpacking in case unpack/pack alias or mutate inputs.
        original_packed = [m.clone() for m in packed_state]

        unpacked = d.unpack_state(packed_state)
        packed = d.pack_state(*unpacked)

        self.assertGreater(len(packed), 0)
        self.assertEqual(len(packed), len(original_packed))
        for idx, (got, expected) in enumerate(zip(packed, original_packed)):
            # torch.equal also requires identical shapes, whereas
            # torch.all(a == b) would accept broadcast-compatible mismatches.
            self.assertTrue(
                torch.equal(got, expected),
                f"state tensor {idx} differs after unpack/pack round-trip")

    def test_dnc_learning(self):
        """Train A2C with the DNC model on StatelessCartPole and verify that
        the reward target in ``stop`` is actually reached."""
        ModelCatalog.register_custom_model("dnc", DNCMemory)
        config = {
            "env": StatelessCartPole,
            "gamma": 0.99,
            "num_envs_per_worker": 5,
            "framework": "torch",
            "num_workers": 1,
            "num_cpus_per_worker": 2.0,
            "lr": 0.01,
            "entropy_coeff": 0.0005,
            "vf_loss_coeff": 1e-5,
            "model": {
                "custom_model": "dnc",
                "max_seq_len": 64,
                "custom_model_config": {
                    "nr_cells": 10,
                    "cell_size": 8,
                },
            },
        }
        analysis = tune.run("A2C", config=config, stop=self.stop, verbose=1)
        # ``stop`` also ends the trial once ``timesteps_total`` is hit, so a
        # completed run alone does not prove the model learned anything.
        last_result = analysis.trials[0].last_result
        self.assertGreaterEqual(
            last_result["episode_reward_mean"],
            self.stop["episode_reward_mean"],
            "DNC model did not reach the target mean episode reward")
if __name__ == "__main__":
    # Allow running this file directly: delegate to pytest in verbose mode
    # and hand its exit status back to the shell.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|