123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587 |
- from gym import spaces
- from gym.envs.registration import EnvSpec
- import gym
- import numpy as np
- import pickle
- import unittest
- import ray
- from ray.rllib.agents.a3c import A2CTrainer
- from ray.rllib.agents.pg import PGTrainer
- from ray.rllib.env import MultiAgentEnv
- from ray.rllib.env.base_env import convert_to_base_env
- from ray.rllib.env.vector_env import VectorEnv
- from ray.rllib.models import ModelCatalog
- from ray.rllib.models.tf.tf_modelv2 import TFModelV2
- from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
- from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
- from ray.rllib.evaluate import rollout
- from ray.rllib.tests.test_external_env import SimpleServing
- from ray.tune.registry import register_env
- from ray.rllib.utils.framework import try_import_tf, try_import_torch
- from ray.rllib.utils.numpy import one_hot
- from ray.rllib.utils.spaces.repeated import Repeated
- from ray.rllib.utils.test_utils import check
tf1, tf, tfv = try_import_tf()
_, nn = try_import_torch()

# Deeply nested Dict observation space exercising Dict-in-Dict, Tuple-in-Dict,
# Box, and Discrete components. Used by NestedDictEnv and the dict-agent in
# NestedMultiAgentEnv.
DICT_SPACE = spaces.Dict({
    "sensors": spaces.Dict({
        "position": spaces.Box(low=-100, high=100, shape=(3, )),
        "velocity": spaces.Box(low=-1, high=1, shape=(3, )),
        "front_cam": spaces.Tuple(
            (spaces.Box(low=0, high=1, shape=(10, 10, 3)),
             spaces.Box(low=0, high=1, shape=(10, 10, 3)))),
        "rear_cam": spaces.Box(low=0, high=1, shape=(10, 10, 3)),
    }),
    "inner_state": spaces.Dict({
        "charge": spaces.Discrete(100),
        "job_status": spaces.Dict({
            "task": spaces.Discrete(5),
            "progress": spaces.Box(low=0, high=100, shape=()),
        })
    })
})

# Fixed observations replayed by the test envs; episodes are 5 steps long, so
# 10 samples is more than enough for one episode plus the reset obs.
DICT_SAMPLES = [DICT_SPACE.sample() for _ in range(10)]

# Nested Tuple observation space (Box, Tuple-of-Boxes, Discrete).
TUPLE_SPACE = spaces.Tuple([
    spaces.Box(low=-100, high=100, shape=(3, )),
    spaces.Tuple((spaces.Box(low=0, high=1, shape=(10, 10, 3)),
                  spaces.Box(low=0, high=1, shape=(10, 10, 3)))),
    spaces.Discrete(5),
])
TUPLE_SAMPLES = [TUPLE_SPACE.sample() for _ in range(10)]

# Constraints on the Repeated space.
MAX_PLAYERS = 4
MAX_ITEMS = 7
MAX_EFFECTS = 2
ITEM_SPACE = spaces.Box(-5, 5, shape=(1, ))
EFFECT_SPACE = spaces.Box(9000, 9999, shape=(4, ))
# A "player" is a Dict containing two variable-length Repeated sub-spaces.
PLAYER_SPACE = spaces.Dict({
    "location": spaces.Box(-100, 100, shape=(2, )),
    "items": Repeated(ITEM_SPACE, max_len=MAX_ITEMS),
    "effects": Repeated(EFFECT_SPACE, max_len=MAX_EFFECTS),
    "status": spaces.Box(-1, 1, shape=(10, )),
})
# Top-level space: a variable-length list of players.
REPEATED_SPACE = Repeated(PLAYER_SPACE, max_len=MAX_PLAYERS)
REPEATED_SAMPLES = [REPEATED_SPACE.sample() for _ in range(10)]
class NestedDictEnv(gym.Env):
    """Deterministic 5-step env emitting pre-sampled nested-dict observations.

    Observations are replayed from ``DICT_SAMPLES`` so tests can compare
    exactly what the model received against the known fixtures.
    """

    def __init__(self):
        self.observation_space = DICT_SPACE
        self.action_space = spaces.Discrete(2)
        self._spec = EnvSpec("NestedDictEnv-v0")
        self.steps = 0

    def reset(self):
        self.steps = 0
        return DICT_SAMPLES[0]

    def step(self, action):
        self.steps += 1
        done = self.steps >= 5
        return DICT_SAMPLES[self.steps], 1, done, {}
class NestedTupleEnv(gym.Env):
    """Deterministic 5-step env emitting pre-sampled nested-tuple observations.

    Observations are replayed from ``TUPLE_SAMPLES`` so tests can compare
    exactly what the model received against the known fixtures.
    """

    def __init__(self):
        self.observation_space = TUPLE_SPACE
        self.action_space = spaces.Discrete(2)
        self._spec = EnvSpec("NestedTupleEnv-v0")
        self.steps = 0

    def reset(self):
        self.steps = 0
        return TUPLE_SAMPLES[0]

    def step(self, action):
        self.steps += 1
        done = self.steps >= 5
        return TUPLE_SAMPLES[self.steps], 1, done, {}
class RepeatedSpaceEnv(gym.Env):
    """Deterministic 5-step env emitting pre-sampled Repeated-space observations.

    Observations are replayed from ``REPEATED_SAMPLES`` so tests can compare
    exactly what the model received against the known fixtures.
    """

    def __init__(self):
        self.observation_space = REPEATED_SPACE
        self.action_space = spaces.Discrete(2)
        self._spec = EnvSpec("RepeatedSpaceEnv-v0")
        self.steps = 0

    def reset(self):
        self.steps = 0
        return REPEATED_SAMPLES[0]

    def step(self, action):
        self.steps += 1
        done = self.steps >= 5
        return REPEATED_SAMPLES[self.steps], 1, done, {}
class NestedMultiAgentEnv(MultiAgentEnv):
    """Two-agent env where one agent sees Dict obs and the other Tuple obs.

    Observations are replayed from the pre-generated ``DICT_SAMPLES`` /
    ``TUPLE_SAMPLES`` fixtures so tests can verify that nested observations
    are reconstructed correctly per policy. Episodes last exactly 5 steps.
    """

    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Dict({
            "dict_agent": DICT_SPACE,
            "tuple_agent": TUPLE_SPACE
        })
        self.action_space = spaces.Dict({
            "dict_agent": spaces.Discrete(1),
            "tuple_agent": spaces.Discrete(1)
        })
        self._agent_ids = {"dict_agent", "tuple_agent"}
        self.steps = 0

    def reset(self):
        # Bug fix: zero the step counter on reset, matching the single-agent
        # envs above. Without this, a second episode keeps incrementing
        # `self.steps` and eventually indexes past the 10 pre-generated
        # samples, raising IndexError in step().
        self.steps = 0
        return {
            "dict_agent": DICT_SAMPLES[0],
            "tuple_agent": TUPLE_SAMPLES[0],
        }

    def step(self, actions):
        self.steps += 1
        obs = {
            "dict_agent": DICT_SAMPLES[self.steps],
            "tuple_agent": TUPLE_SAMPLES[self.steps],
        }
        rew = {
            "dict_agent": 0,
            "tuple_agent": 0,
        }
        # Episode terminates for all agents after 5 steps.
        dones = {"__all__": self.steps >= 5}
        infos = {
            "dict_agent": {},
            "tuple_agent": {},
        }
        return obs, rew, dones, infos
class InvalidModel(TorchModelV2):
    """Deliberately broken model: a TorchModelV2 that does NOT inherit
    nn.Module. Used by test_invalid_model to assert RLlib rejects it."""

    def forward(self, input_dict, state, seq_lens):
        # Intentionally returns nonsense output/state; never expected to run.
        return "not", "valid"
class InvalidModel2(TFModelV2):
    """Deliberately broken TF model: returns a tensor (not a list) as the
    state output. Used by test_invalid_model2 to assert RLlib rejects it."""

    def forward(self, input_dict, state, seq_lens):
        # State output must be a list; returning a bare tensor triggers the
        # "State output is not a list" ValueError the test expects.
        return tf.constant(0), tf.constant(0)
class TorchSpyModel(TorchModelV2, nn.Module):
    """Torch model that records the nested observations it receives.

    Each forward pass pickles (position, front_cam[0], task) into the ray
    internal KV store under a monotonically increasing key, so the test
    process can later fetch and compare them against the env fixtures.
    Only sensors->position is actually fed through a small FC net.
    """

    # Class-level counter keying successive captures in the KV store.
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        # Build the FC net on the original (pre-flattening) position space.
        self.fc = FullyConnectedNetwork(
            obs_space.original_space["sensors"].spaces["position"],
            action_space, num_outputs, model_config, name)

    def forward(self, input_dict, state, seq_lens):
        # Detach to plain numpy so the capture is independent of autograd
        # and device placement.
        pos = input_dict["obs"]["sensors"]["position"].detach().cpu().numpy()
        front_cam = input_dict["obs"]["sensors"]["front_cam"][
            0].detach().cpu().numpy()
        task = input_dict["obs"]["inner_state"]["job_status"][
            "task"].detach().cpu().numpy()
        # Persist what the model saw; the KV put must happen before the
        # index increment so keys start at the current capture_index.
        ray.experimental.internal_kv._internal_kv_put(
            "torch_spy_in_{}".format(TorchSpyModel.capture_index),
            pickle.dumps((pos, front_cam, task)),
            overwrite=True)
        TorchSpyModel.capture_index += 1
        return self.fc({
            "obs": input_dict["obs"]["sensors"]["position"]
        }, state, seq_lens)

    def value_function(self):
        return self.fc.value_function()
class TorchRepeatedSpyModel(TorchModelV2, nn.Module):
    """Torch model that records Repeated-space observations it receives.

    Each forward pass pickles the fully unbatched observation into the ray
    internal KV store so the test process can compare it against the
    REPEATED_SAMPLES fixtures. Only the first player's location is fed
    through a small FC net.
    """

    # Class-level counter keying successive captures in the KV store.
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        # Build the FC net on the per-player "location" child space of the
        # original Repeated space.
        self.fc = FullyConnectedNetwork(
            obs_space.original_space.child_space["location"], action_space,
            num_outputs, model_config, name)

    def forward(self, input_dict, state, seq_lens):
        # Capture must precede the index increment so keys start at the
        # current capture_index.
        ray.experimental.internal_kv._internal_kv_put(
            "torch_rspy_in_{}".format(TorchRepeatedSpyModel.capture_index),
            pickle.dumps(input_dict["obs"].unbatch_all()),
            overwrite=True)
        TorchRepeatedSpyModel.capture_index += 1
        # [:, 0] selects the first player slot in the batched Repeated obs.
        return self.fc({
            "obs": input_dict["obs"].values["location"][:, 0]
        }, state, seq_lens)

    def value_function(self):
        return self.fc.value_function()
- def to_list(value):
- if isinstance(value, list):
- return [to_list(x) for x in value]
- elif isinstance(value, dict):
- return {k: to_list(v) for k, v in value.items()}
- elif isinstance(value, np.ndarray):
- return value.tolist()
- elif isinstance(value, int):
- return value
- else:
- return value.detach().cpu().numpy().tolist()
class DictSpyModel(TFModelV2):
    """TF model that records the nested dict observations it receives.

    A tf1.py_func side effect pickles (position, front_cam[0], task) into
    the ray internal KV store each forward pass; the test process fetches
    these captures and compares them to the env fixtures. Only
    sensors->position is fed through the actual dense layer.
    """

    # Class-level counter keying successive captures in the KV store.
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # num_outputs is passed as None here and resolved below.
        super().__init__(obs_space, action_space, None, model_config, name)
        # Will only feed in sensors->pos.
        input_ = tf.keras.layers.Input(
            shape=self.obs_space["sensors"]["position"].shape)
        self.num_outputs = num_outputs or 64
        out = tf.keras.layers.Dense(self.num_outputs)(input_)
        self._main_layer = tf.keras.models.Model([input_], [out])

    def forward(self, input_dict, state, seq_lens):
        def spy(pos, front_cam, task):
            # TF runs this function in an isolated context, so we have to use
            # redis to communicate back to our suite
            ray.experimental.internal_kv._internal_kv_put(
                "d_spy_in_{}".format(DictSpyModel.capture_index),
                pickle.dumps((pos, front_cam, task)),
                overwrite=True)
            DictSpyModel.capture_index += 1
            return np.array(0, dtype=np.int64)

        # Wrap the capture as a stateful op so TF cannot optimize it away.
        spy_fn = tf1.py_func(
            spy, [
                input_dict["obs"]["sensors"]["position"],
                input_dict["obs"]["sensors"]["front_cam"][0],
                input_dict["obs"]["inner_state"]["job_status"]["task"]
            ],
            tf.int64,
            stateful=True)

        # Force the spy op to run before the real forward computation.
        with tf1.control_dependencies([spy_fn]):
            output = self._main_layer(
                [input_dict["obs"]["sensors"]["position"]])

        return output, []
class TupleSpyModel(TFModelV2):
    """TF model that records the nested tuple observations it receives.

    A tf1.py_func side effect pickles (obs[0], obs[1][0], obs[2]) into the
    ray internal KV store each forward pass; the test process fetches these
    captures and compares them to the env fixtures. Only obs[0] is fed
    through the actual dense layer.
    """

    # Class-level counter keying successive captures in the KV store.
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # num_outputs is passed as None here and resolved below.
        super().__init__(obs_space, action_space, None, model_config, name)
        # Will only feed in 0th index of observation Tuple space.
        input_ = tf.keras.layers.Input(shape=self.obs_space[0].shape)
        self.num_outputs = num_outputs or 64
        out = tf.keras.layers.Dense(self.num_outputs)(input_)
        self._main_layer = tf.keras.models.Model([input_], [out])

    def forward(self, input_dict, state, seq_lens):
        def spy(pos, cam, task):
            # TF runs this function in an isolated context, so we have to use
            # redis to communicate back to our suite
            ray.experimental.internal_kv._internal_kv_put(
                "t_spy_in_{}".format(TupleSpyModel.capture_index),
                pickle.dumps((pos, cam, task)),
                overwrite=True)
            TupleSpyModel.capture_index += 1
            return np.array(0, dtype=np.int64)

        # Wrap the capture as a stateful op so TF cannot optimize it away.
        spy_fn = tf1.py_func(
            spy, [
                input_dict["obs"][0],
                input_dict["obs"][1][0],
                input_dict["obs"][2],
            ],
            tf.int64,
            stateful=True)

        # Force the spy op to run before the real forward computation.
        with tf1.control_dependencies([spy_fn]):
            output = tf1.layers.dense(input_dict["obs"][0], self.num_outputs)

        return output, []
class NestedObservationSpacesTest(unittest.TestCase):
    """End-to-end tests that RLlib flattens and correctly reconstructs
    nested (Dict / Tuple / Repeated) observation spaces across env wrappers
    (gym, vectorized, serving, base-env), frameworks (tf / torch), and
    multi-agent configurations."""

    @classmethod
    def setUpClass(cls):
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_invalid_model(self):
        # A TorchModelV2 that does not also inherit nn.Module must be
        # rejected at trainer construction time.
        ModelCatalog.register_custom_model("invalid", InvalidModel)
        self.assertRaisesRegex(
            ValueError,
            "Subclasses of TorchModelV2 must also inherit from nn.Module",
            lambda: PGTrainer(
                env="CartPole-v0",
                config={
                    "model": {
                        "custom_model": "invalid",
                    },
                    "framework": "torch",
                }))

    def test_invalid_model2(self):
        # A TF model whose forward() returns a non-list state must be
        # rejected at trainer construction time.
        ModelCatalog.register_custom_model("invalid2", InvalidModel2)
        self.assertRaisesRegex(
            ValueError, "State output is not a list",
            lambda: PGTrainer(
                env="CartPole-v0", config={
                    "model": {
                        "custom_model": "invalid2",
                    },
                    "framework": "tf",
                }))

    def do_test_nested_dict(self, make_env, test_lstm=False):
        """Train PG briefly on a dict-obs env and verify the model saw the
        exact pre-generated nested observations (via the DictSpyModel
        captures in the ray KV store)."""
        ModelCatalog.register_custom_model("composite", DictSpyModel)
        register_env("nested", make_env)
        pg = PGTrainer(
            env="nested",
            config={
                "num_workers": 0,
                "rollout_fragment_length": 5,
                "train_batch_size": 5,
                "model": {
                    "custom_model": "composite",
                    "use_lstm": test_lstm,
                },
                "framework": "tf",
            })
        # Skip first passes as they came from the TorchPolicy loss
        # initialization.
        DictSpyModel.capture_index = 0
        pg.train()

        # Check that the model sees the correct reconstructed observations
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "d_spy_in_{}".format(i)))
            pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist()
            cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist()
            task_i = DICT_SAMPLES[i]["inner_state"]["job_status"]["task"]
            # seen[*][0]: first row of the captured batch.
            self.assertEqual(seen[0][0].tolist(), pos_i)
            self.assertEqual(seen[1][0].tolist(), cam_i)
            check(seen[2][0], task_i)

    def do_test_nested_tuple(self, make_env):
        """Train PG briefly on a tuple-obs env and verify the model saw the
        exact pre-generated nested observations (via the TupleSpyModel
        captures in the ray KV store)."""
        ModelCatalog.register_custom_model("composite2", TupleSpyModel)
        register_env("nested2", make_env)
        pg = PGTrainer(
            env="nested2",
            config={
                "num_workers": 0,
                "rollout_fragment_length": 5,
                "train_batch_size": 5,
                "model": {
                    "custom_model": "composite2",
                },
                "framework": "tf",
            })
        # Skip first passes as they came from the TorchPolicy loss
        # initialization.
        TupleSpyModel.capture_index = 0
        pg.train()

        # Check that the model sees the correct reconstructed observations
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "t_spy_in_{}".format(i)))
            pos_i = TUPLE_SAMPLES[i][0].tolist()
            cam_i = TUPLE_SAMPLES[i][1][0].tolist()
            task_i = TUPLE_SAMPLES[i][2]
            self.assertEqual(seen[0][0].tolist(), pos_i)
            self.assertEqual(seen[1][0].tolist(), cam_i)
            check(seen[2][0], task_i)

    def test_nested_dict_gym(self):
        self.do_test_nested_dict(lambda _: NestedDictEnv())

    def test_nested_dict_gym_lstm(self):
        self.do_test_nested_dict(lambda _: NestedDictEnv(), test_lstm=True)

    def test_nested_dict_vector(self):
        self.do_test_nested_dict(
            lambda _: VectorEnv.vectorize_gym_envs(lambda i: NestedDictEnv()))

    def test_nested_dict_serving(self):
        self.do_test_nested_dict(lambda _: SimpleServing(NestedDictEnv()))

    def test_nested_dict_async(self):
        self.do_test_nested_dict(
            lambda _: convert_to_base_env(NestedDictEnv()))

    def test_nested_tuple_gym(self):
        self.do_test_nested_tuple(lambda _: NestedTupleEnv())

    def test_nested_tuple_vector(self):
        self.do_test_nested_tuple(
            lambda _: VectorEnv.vectorize_gym_envs(lambda i: NestedTupleEnv()))

    def test_nested_tuple_serving(self):
        self.do_test_nested_tuple(lambda _: SimpleServing(NestedTupleEnv()))

    def test_nested_tuple_async(self):
        self.do_test_nested_tuple(
            lambda _: convert_to_base_env(NestedTupleEnv()))

    def test_multi_agent_complex_spaces(self):
        # Two policies with different nested obs spaces in one trainer;
        # each spy model captures what its policy actually saw.
        ModelCatalog.register_custom_model("dict_spy", DictSpyModel)
        ModelCatalog.register_custom_model("tuple_spy", TupleSpyModel)
        register_env("nested_ma", lambda _: NestedMultiAgentEnv())
        act_space = spaces.Discrete(2)
        pg = PGTrainer(
            env="nested_ma",
            config={
                "num_workers": 0,
                "rollout_fragment_length": 5,
                "train_batch_size": 5,
                "multiagent": {
                    "policies": {
                        "tuple_policy": (
                            None, TUPLE_SPACE, act_space,
                            {"model": {"custom_model": "tuple_spy"}}),
                        "dict_policy": (
                            None, DICT_SPACE, act_space,
                            {"model": {"custom_model": "dict_spy"}}),
                    },
                    "policy_mapping_fn": lambda aid, **kwargs: {
                        "tuple_agent": "tuple_policy",
                        "dict_agent": "dict_policy"}[aid],
                },
                "framework": "tf",
            })
        # Skip first passes as they came from the TorchPolicy loss
        # initialization.
        TupleSpyModel.capture_index = DictSpyModel.capture_index = 0
        pg.train()

        # Verify the dict-policy captures against the dict fixtures.
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "d_spy_in_{}".format(i)))
            pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist()
            cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist()
            task_i = DICT_SAMPLES[i]["inner_state"]["job_status"]["task"]
            self.assertEqual(seen[0][0].tolist(), pos_i)
            self.assertEqual(seen[1][0].tolist(), cam_i)
            check(seen[2][0], task_i)

        # Verify the tuple-policy captures against the tuple fixtures.
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "t_spy_in_{}".format(i)))
            pos_i = TUPLE_SAMPLES[i][0].tolist()
            cam_i = TUPLE_SAMPLES[i][1][0].tolist()
            task_i = TUPLE_SAMPLES[i][2]
            self.assertEqual(seen[0][0].tolist(), pos_i)
            self.assertEqual(seen[1][0].tolist(), cam_i)
            check(seen[2][0], task_i)

    def test_rollout_dict_space(self):
        # Train, checkpoint, restore, and roll out on a dict-obs env to
        # ensure nested spaces survive the save/restore cycle.
        register_env("nested", lambda _: NestedDictEnv())
        agent = PGTrainer(env="nested", config={"framework": "tf"})
        agent.train()
        path = agent.save()
        agent.stop()

        # Test train works on restore
        agent2 = PGTrainer(env="nested", config={"framework": "tf"})
        agent2.restore(path)
        agent2.train()

        # Test rollout works on restore
        rollout(agent2, "nested", 100)

    def test_py_torch_model(self):
        # Same reconstruction check as do_test_nested_dict, but for the
        # torch framework via A2C and TorchSpyModel.
        ModelCatalog.register_custom_model("composite", TorchSpyModel)
        register_env("nested", lambda _: NestedDictEnv())
        a2c = A2CTrainer(
            env="nested",
            config={
                "num_workers": 0,
                "rollout_fragment_length": 5,
                "train_batch_size": 5,
                "model": {
                    "custom_model": "composite",
                },
                "framework": "torch",
            })
        # Skip first passes as they came from the TorchPolicy loss
        # initialization.
        TorchSpyModel.capture_index = 0
        a2c.train()

        # Check that the model sees the correct reconstructed observations
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "torch_spy_in_{}".format(i)))
            pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist()
            cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist()
            # Discrete obs arrive one-hot encoded on the torch side.
            task_i = one_hot(
                DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5)
            # Only look at the last entry (-1) in `seen` as we reset (re-use)
            # the ray-kv indices before training.
            self.assertEqual(seen[0][-1].tolist(), pos_i)
            self.assertEqual(seen[1][-1].tolist(), cam_i)
            check(seen[2][-1], task_i)

    # TODO(ekl) should probably also add a test for TF/eager
    def test_torch_repeated(self):
        # Verify Repeated-space observations are reconstructed correctly
        # on the torch side via TorchRepeatedSpyModel captures.
        ModelCatalog.register_custom_model("r1", TorchRepeatedSpyModel)
        register_env("repeat", lambda _: RepeatedSpaceEnv())
        a2c = A2CTrainer(
            env="repeat",
            config={
                "num_workers": 0,
                "rollout_fragment_length": 5,
                "train_batch_size": 5,
                "model": {
                    "custom_model": "r1",
                },
                "framework": "torch",
            })
        # Skip first passes as they came from the TorchPolicy loss
        # initialization.
        TorchRepeatedSpyModel.capture_index = 0
        a2c.train()

        # Check that the model sees the correct reconstructed observations
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "torch_rspy_in_{}".format(i)))
            # Only look at the last entry (-1) in `seen` as we reset (re-use)
            # the ray-kv indices before training.
            self.assertEqual(
                to_list(seen[:][-1]), to_list(REPEATED_SAMPLES[i]))
- if __name__ == "__main__":
- import pytest
- import sys
- sys.exit(pytest.main(["-v", __file__]))
|