# custom_model_api.py
  1. import argparse
  2. from gym.spaces import Box, Discrete
  3. import numpy as np
  4. from ray.rllib.examples.models.custom_model_api import DuelingQModel, \
  5. TorchDuelingQModel, ContActionQModel, TorchContActionQModel
  6. from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
  7. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  8. tf1, tf, tfv = try_import_tf()
  9. torch, _ = try_import_torch()
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument(
  12. "--framework",
  13. choices=["tf", "tf2", "tfe", "torch"],
  14. default="tf",
  15. help="The DL framework specifier.")
  16. if __name__ == "__main__":
  17. args = parser.parse_args()
  18. # Test API wrapper for dueling Q-head.
  19. obs_space = Box(-1.0, 1.0, (3, ))
  20. action_space = Discrete(3)
  21. # Run in eager mode for value checking and debugging.
  22. tf1.enable_eager_execution()
  23. # __sphinx_doc_model_construct_1_begin__
  24. my_dueling_model = ModelCatalog.get_model_v2(
  25. obs_space=obs_space,
  26. action_space=action_space,
  27. num_outputs=action_space.n,
  28. model_config=MODEL_DEFAULTS,
  29. framework=args.framework,
  30. # Providing the `model_interface` arg will make the factory
  31. # wrap the chosen default model with our new model API class
  32. # (DuelingQModel). This way, both `forward` and `get_q_values`
  33. # are available in the returned class.
  34. model_interface=DuelingQModel
  35. if args.framework != "torch" else TorchDuelingQModel,
  36. name="dueling_q_model",
  37. )
  38. # __sphinx_doc_model_construct_1_end__
  39. batch_size = 10
  40. input_ = np.array([obs_space.sample() for _ in range(batch_size)])
  41. # Note that for PyTorch, you will have to provide torch tensors here.
  42. if args.framework == "torch":
  43. input_ = torch.from_numpy(input_)
  44. input_dict = {
  45. "obs": input_,
  46. "is_training": False,
  47. }
  48. out, state_outs = my_dueling_model(input_dict=input_dict)
  49. assert out.shape == (10, 256)
  50. # Pass `out` into `get_q_values`
  51. q_values = my_dueling_model.get_q_values(out)
  52. assert q_values.shape == (10, action_space.n)
  53. # Test API wrapper for single value Q-head from obs/action input.
  54. obs_space = Box(-1.0, 1.0, (3, ))
  55. action_space = Box(-1.0, -1.0, (2, ))
  56. # __sphinx_doc_model_construct_2_begin__
  57. my_cont_action_q_model = ModelCatalog.get_model_v2(
  58. obs_space=obs_space,
  59. action_space=action_space,
  60. num_outputs=2,
  61. model_config=MODEL_DEFAULTS,
  62. framework=args.framework,
  63. # Providing the `model_interface` arg will make the factory
  64. # wrap the chosen default model with our new model API class
  65. # (DuelingQModel). This way, both `forward` and `get_q_values`
  66. # are available in the returned class.
  67. model_interface=ContActionQModel
  68. if args.framework != "torch" else TorchContActionQModel,
  69. name="cont_action_q_model",
  70. )
  71. # __sphinx_doc_model_construct_2_end__
  72. batch_size = 10
  73. input_ = np.array([obs_space.sample() for _ in range(batch_size)])
  74. # Note that for PyTorch, you will have to provide torch tensors here.
  75. if args.framework == "torch":
  76. input_ = torch.from_numpy(input_)
  77. input_dict = {
  78. "obs": input_,
  79. "is_training": False,
  80. }
  81. # Note that for PyTorch, you will have to provide torch tensors here.
  82. out, state_outs = my_cont_action_q_model(input_dict=input_dict)
  83. assert out.shape == (10, 256)
  84. # Pass `out` and an action into `my_cont_action_q_model`
  85. action = np.array([action_space.sample() for _ in range(batch_size)])
  86. if args.framework == "torch":
  87. action = torch.from_numpy(action)
  88. q_value = my_cont_action_q_model.get_single_q_value(out, action)
  89. assert q_value.shape == (10, 1)