"""Example showing how to wrap RLlib default models with custom model APIs.

Constructs a dueling-Q head and a continuous-action Q head via
``ModelCatalog.get_model_v2`` and checks their output shapes.
"""
import argparse

import numpy as np
from gym.spaces import Box, Discrete

from ray.rllib.examples.models.custom_model_api import (
    ContActionQModel,
    DuelingQModel,
    TorchContActionQModel,
    TorchDuelingQModel,
)
from ray.rllib.models.catalog import MODEL_DEFAULTS, ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch

# Framework stubs: each import helper returns None-placeholders when the
# corresponding deep-learning framework is not installed.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
# Command-line interface: choose which DL framework the example runs on.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
if __name__ == "__main__":
    args = parser.parse_args()

    # --- Test 1: API wrapper for a dueling Q-head. ---
    obs_space = Box(-1.0, 1.0, (3, ))
    action_space = Discrete(3)

    # Run in eager mode for value checking and debugging.
    tf1.enable_eager_execution()

    # __sphinx_doc_model_construct_1_begin__
    my_dueling_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=action_space.n,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (DuelingQModel). This way, both `forward` and `get_q_values`
        # are available in the returned class.
        model_interface=DuelingQModel
        if args.framework != "torch" else TorchDuelingQModel,
        name="dueling_q_model",
    )
    # __sphinx_doc_model_construct_1_end__

    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)
    input_dict = SampleBatch(obs=input_, _is_training=False)
    out, state_outs = my_dueling_model(input_dict=input_dict)
    # 256 is the feature width of the default fully-connected net
    # (MODEL_DEFAULTS) feeding the Q-head.
    assert out.shape == (batch_size, 256)
    # Pass `out` into `get_q_values`.
    q_values = my_dueling_model.get_q_values(out)
    assert q_values.shape == (batch_size, action_space.n)

    # --- Test 2: API wrapper for a single-value Q-head computed from
    # obs/action input. ---
    obs_space = Box(-1.0, 1.0, (3, ))
    # Fixed: was `Box(-1.0, -1.0, (2, ))` (low == high), a degenerate
    # space whose every sample is exactly -1.0 — clearly a typo for a
    # [-1, 1] range matching the obs space above.
    action_space = Box(-1.0, 1.0, (2, ))

    # __sphinx_doc_model_construct_2_begin__
    my_cont_action_q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=2,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (ContActionQModel). This way, both `forward` and
        # `get_single_q_value` are available in the returned class.
        model_interface=ContActionQModel
        if args.framework != "torch" else TorchContActionQModel,
        name="cont_action_q_model",
    )
    # __sphinx_doc_model_construct_2_end__

    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)
    input_dict = SampleBatch(obs=input_, _is_training=False)
    out, state_outs = my_cont_action_q_model(input_dict=input_dict)
    assert out.shape == (batch_size, 256)

    # Pass `out` and a (batched) sampled action into the model's
    # `get_single_q_value` to obtain one Q-value per batch item.
    action = np.array([action_space.sample() for _ in range(batch_size)])
    if args.framework == "torch":
        action = torch.from_numpy(action)
    q_value = my_cont_action_q_model.get_single_q_value(out, action)
    assert q_value.shape == (batch_size, 1)