custom_model_api.py
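"""Example of wrapping RLlib's default models with custom model-API classes.

Constructs, via ModelCatalog.get_model_v2, a dueling Q-head model
(DuelingQModel / TorchDuelingQModel) and a continuous-action Q-head model
(ContActionQModel / TorchContActionQModel), then sanity-checks their
outputs on dummy input batches.
"""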
import argparse

from gym.spaces import Box, Discrete
import numpy as np

from ray.rllib.examples.models.custom_model_api import DuelingQModel, \
    TorchDuelingQModel, ContActionQModel, TorchContActionQModel
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
if __name__ == "__main__":
    args = parser.parse_args()

    # Test API wrapper for dueling Q-head.
    obs_space = Box(-1.0, 1.0, (3, ))
    action_space = Discrete(3)

    # Run in eager mode for value checking and debugging.
    tf1.enable_eager_execution()
    # __sphinx_doc_model_construct_1_begin__
    my_dueling_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=action_space.n,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (DuelingQModel). This way, both `forward` and `get_q_values`
        # are available in the returned class.
        model_interface=DuelingQModel
        if args.framework != "torch" else TorchDuelingQModel,
        name="dueling_q_model",
    )
    # __sphinx_doc_model_construct_1_end__
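    # For reference, a minimal sketch of what a dueling Q-head such as
    # `DuelingQModel` could compute (hypothetical attribute names; the
    # standard dueling decomposition combines a state-value V(s) with
    # per-action advantages A(s, a)):
    #
    #   def get_q_values(self, underlying_output):
    #       v = self.value_head(underlying_output)      # (batch, 1)
    #       a = self.advantage_head(underlying_output)  # (batch, num_actions)
    #       # Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a')
    #       return v + a - tf.reduce_mean(a, axis=1, keepdims=True)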
    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)
    input_dict = SampleBatch(obs=input_, _is_training=False)
    out, state_outs = my_dueling_model(input_dict=input_dict)
    assert out.shape == (10, 256)
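    # The 256 feature size comes from MODEL_DEFAULTS: the default fully
    # connected stack (`fcnet_hiddens=[256, 256]`) ends in a 256-unit layer.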
    # Pass `out` into `get_q_values`.
    q_values = my_dueling_model.get_q_values(out)
    assert q_values.shape == (10, action_space.n)
    # Test API wrapper for single-value Q-head from obs/action input.
    obs_space = Box(-1.0, 1.0, (3, ))
    action_space = Box(-1.0, 1.0, (2, ))
    # __sphinx_doc_model_construct_2_begin__
    my_cont_action_q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=2,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (ContActionQModel). This way, both `forward` and
        # `get_single_q_value` are available in the returned class.
        model_interface=ContActionQModel
        if args.framework != "torch" else TorchContActionQModel,
        name="cont_action_q_model",
    )
    # __sphinx_doc_model_construct_2_end__
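    # For reference, a minimal sketch of what a continuous-action Q-head
    # such as `ContActionQModel` could compute (hypothetical attribute
    # names): concatenate the observation features with the action and
    # feed the result through a single-output head.
    #
    #   def get_single_q_value(self, underlying_output, action):
    #       q_input = tf.concat([underlying_output, action], axis=-1)
    #       return self.q_head(q_input)  # (batch, 1)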
    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)
    input_dict = SampleBatch(obs=input_, _is_training=False)
    out, state_outs = my_cont_action_q_model(input_dict=input_dict)
    assert out.shape == (10, 256)
    # Pass `out` and an action into `get_single_q_value`.
    action = np.array([action_space.sample() for _ in range(batch_size)])
    if args.framework == "torch":
        action = torch.from_numpy(action)
    q_value = my_cont_action_q_model.get_single_q_value(out, action)
    assert q_value.shape == (10, 1)
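    # To run this script, pick one of the supported frameworks, e.g.:
    #   python custom_model_api.py --framework=tf
    #   python custom_model_api.py --framework=torch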