lstm_auto_wrapping.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import numpy as np
  2. import ray
  3. import ray.rllib.agents.ppo as ppo
  4. from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
  5. from ray.rllib.models.catalog import ModelCatalog
  6. from ray.rllib.utils.framework import try_import_torch
  7. torch, _ = try_import_torch()
  8. # __sphinx_doc_begin__
  9. # The custom model that will be wrapped by an LSTM.
  10. class MyCustomModel(TorchModelV2):
  11. def __init__(self, obs_space, action_space, num_outputs, model_config,
  12. name):
  13. super().__init__(obs_space, action_space, num_outputs, model_config,
  14. name)
  15. self.num_outputs = int(np.product(self.obs_space.shape))
  16. self._last_batch_size = None
  17. # Implement your own forward logic, whose output will then be sent
  18. # through an LSTM.
  19. def forward(self, input_dict, state, seq_lens):
  20. obs = input_dict["obs_flat"]
  21. # Store last batch size for value_function output.
  22. self._last_batch_size = obs.shape[0]
  23. # Return 2x the obs (and empty states).
  24. # This will further be sent through an automatically provided
  25. # LSTM head (b/c we are setting use_lstm=True below).
  26. return obs * 2.0, []
  27. def value_function(self):
  28. return torch.from_numpy(np.zeros(shape=(self._last_batch_size, )))
  29. if __name__ == "__main__":
  30. ray.init()
  31. # Register the above custom model.
  32. ModelCatalog.register_custom_model("my_torch_model", MyCustomModel)
  33. # Create the Trainer.
  34. trainer = ppo.PPOTrainer(
  35. env="CartPole-v0",
  36. config={
  37. "framework": "torch",
  38. "model": {
  39. # Auto-wrap the custom(!) model with an LSTM.
  40. "use_lstm": True,
  41. # To further customize the LSTM auto-wrapper.
  42. "lstm_cell_size": 64,
  43. # Specify our custom model from above.
  44. "custom_model": "my_torch_model",
  45. # Extra kwargs to be passed to your model's c'tor.
  46. "custom_model_config": {},
  47. },
  48. })
  49. trainer.train()
  50. # __sphinx_doc_end__