training.py

# flake8: noqa

# __preprocessing_observations_start__
try:
    import gymnasium as gym

    env = gym.make("GymV26Environment-v0", env_id="ALE/Pong-v5")
    obs, infos = env.reset()
except Exception:
    import gym

    env = gym.make("PongNoFrameskip-v4")
    obs = env.reset()

# RLlib uses preprocessors to implement transforms such as one-hot encoding
# and flattening of tuple and dict observations.
from ray.rllib.models.preprocessors import get_preprocessor

prep = get_preprocessor(env.observation_space)(env.observation_space)
# <ray.rllib.models.preprocessors.GenericPixelPreprocessor object at 0x7fc4d049de80>

# Observations should be preprocessed prior to feeding into a model
obs.shape
# (210, 160, 3)
prep.transform(obs).shape
# (84, 84, 3)
# __preprocessing_observations_end__
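
# A hedged follow-up sketch, not part of the original snippet: the same
# get_preprocessor() lookup also covers the one-hot / flattening transforms
# mentioned above, e.g. for a Discrete observation space. The exact
# preprocessor class chosen is an assumption here.
discrete_space = gym.spaces.Discrete(3)
discrete_prep = get_preprocessor(discrete_space)(discrete_space)
discrete_prep.transform(1)
# array([0., 1., 0.], dtype=float32)  <- a one-hot vector of shape (3,)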
# __query_action_dist_start__
# Get a reference to the policy
import numpy as np
from ray.rllib.algorithms.dqn import DQNConfig

algo = (
    DQNConfig()
    .environment("CartPole-v1")
    .framework("tf2")
    .rollouts(num_rollout_workers=0)
    .build()
)
# <ray.rllib.algorithms.dqn.DQN object at 0x7fd020186384>

policy = algo.get_policy()
# <ray.rllib.policy.eager_tf_policy.DQNTFPolicy_eager object at 0x7fd020165470>

# Run a forward pass to get model output logits. Note that complex observations
# must be preprocessed as in the above code block.
logits, _ = policy.model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
# (<tf.Tensor: id=1274, shape=(1, 2), dtype=float32, numpy=...>, [])

# Compute action distribution given logits
policy.dist_class
# <class 'ray.rllib.models.tf.tf_action_dist.Categorical'>
dist = policy.dist_class(logits, policy.model)
# <ray.rllib.models.tf.tf_action_dist.Categorical object at 0x7fd02301d710>

# Query the distribution for samples, sample logps
dist.sample()
# <tf.Tensor: id=661, shape=(1,), dtype=int64, numpy=...>
dist.logp([1])
# <tf.Tensor: id=1298, shape=(1,), dtype=float32, numpy=...>

# Get the estimated values for the most recent forward pass
policy.model.value_function()
# <tf.Tensor: id=670, shape=(1,), dtype=float32, numpy=...>

policy.model.base_model.summary()
  54. """
  55. Model: "model"
  56. _____________________________________________________________________
  57. Layer (type) Output Shape Param # Connected to
  58. =====================================================================
  59. observations (InputLayer) [(None, 4)] 0
  60. _____________________________________________________________________
  61. fc_1 (Dense) (None, 256) 1280 observations[0][0]
  62. _____________________________________________________________________
  63. fc_value_1 (Dense) (None, 256) 1280 observations[0][0]
  64. _____________________________________________________________________
  65. fc_2 (Dense) (None, 256) 65792 fc_1[0][0]
  66. _____________________________________________________________________
  67. fc_value_2 (Dense) (None, 256) 65792 fc_value_1[0][0]
  68. _____________________________________________________________________
  69. fc_out (Dense) (None, 2) 514 fc_2[0][0]
  70. _____________________________________________________________________
  71. value_out (Dense) (None, 1) 257 fc_value_2[0][0]
  72. =====================================================================
  73. Total params: 134,915
  74. Trainable params: 134,915
  75. Non-trainable params: 0
  76. _____________________________________________________________________
  77. """
  78. # __query_action_dist_end__
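
# A small follow-up sketch (not in the original file): the raw logits from the
# forward pass can be turned into action probabilities with a softmax, which
# is a handy cross-check against dist.logp() above.
import tensorflow as tf

probs = tf.nn.softmax(logits)
# <tf.Tensor: shape=(1, 2), dtype=float32, numpy=...>; each row sums to 1, and
# tf.math.log(probs[0, 1]) should match dist.logp([1]) from above.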
# __get_q_values_dqn_start__
# Get a reference to the model through the policy
import numpy as np
from ray.rllib.algorithms.dqn import DQNConfig

algo = DQNConfig().environment("CartPole-v1").framework("tf2").build()
model = algo.get_policy().model
# <ray.rllib.models.catalog.FullyConnectedNetwork_as_DistributionalQModel ...>

# List of all model variables
model.variables()

# Run a forward pass to get base model output. Note that complex observations
# must be preprocessed. An example of preprocessing is shown in
# examples/saving_experiences.py
model_out = model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
# (<tf.Tensor: id=832, shape=(1, 256), dtype=float32, numpy=...)

# Access the base Keras models (all default models have a base)
model.base_model.summary()
  94. """
  95. Model: "model"
  96. _______________________________________________________________________
  97. Layer (type) Output Shape Param # Connected to
  98. =======================================================================
  99. observations (InputLayer) [(None, 4)] 0
  100. _______________________________________________________________________
  101. fc_1 (Dense) (None, 256) 1280 observations[0][0]
  102. _______________________________________________________________________
  103. fc_out (Dense) (None, 256) 65792 fc_1[0][0]
  104. _______________________________________________________________________
  105. value_out (Dense) (None, 1) 257 fc_1[0][0]
  106. =======================================================================
  107. Total params: 67,329
  108. Trainable params: 67,329
  109. Non-trainable params: 0
  110. ______________________________________________________________________________
  111. """
# Access the Q value model (specific to DQN)
print(model.get_q_value_distributions(model_out)[0])
# tf.Tensor([[ 0.13023682 -0.36805138]], shape=(1, 2), dtype=float32)
# ^ exact numbers may differ due to randomness
model.q_value_head.summary()

# Access the state value model (specific to DQN)
print(model.get_state_value(model_out))
# tf.Tensor([[0.09381643]], shape=(1, 1), dtype=float32)
# ^ exact number may differ due to randomness
model.state_value_head.summary()
# __get_q_values_dqn_end__
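
# Hedged follow-up sketch (not part of the original snippet): picking the
# greedy action from the Q-values printed above, using plain numpy.
q_values = model.get_q_value_distributions(model_out)[0]
greedy_action = int(np.argmax(np.asarray(q_values), axis=1)[0])
# e.g. 0 or 1 -- CartPole-v1 has two discrete actions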