# flake8: noqa

# __create-algo-checkpoint-begin__
# Create a PPO algorithm object using a config object ..
from ray.rllib.algorithms.ppo import PPOConfig

my_ppo_config = PPOConfig().environment("CartPole-v1")
my_ppo = my_ppo_config.build()

# .. train one iteration ..
my_ppo.train()
# .. and call `save()` to create a checkpoint.
save_result = my_ppo.save()

path_to_checkpoint = save_result.checkpoint.path
print(
    "An Algorithm checkpoint has been created inside directory: "
    f"'{path_to_checkpoint}'."
)

# Let's terminate the algo for demonstration purposes.
my_ppo.stop()
# Doing this will lead to an error.
# my_ppo.train()
# __create-algo-checkpoint-end__
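
# Optional sanity check: an Algorithm checkpoint is simply a directory on disk.
# Depending on your Ray/RLlib version, it typically contains an
# `rllib_checkpoint.json` metadata file, the pickled algorithm state, and a
# `policies/` sub-directory.
import os

print(sorted(os.listdir(path_to_checkpoint)))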

# __restore-from-algo-checkpoint-begin__
from ray.rllib.algorithms.algorithm import Algorithm

# Use the Algorithm's `from_checkpoint` utility to get a new algo instance
# that has the exact same state as the old one, from which the checkpoint was
# created in the first place:
my_new_ppo = Algorithm.from_checkpoint(path_to_checkpoint)

# Continue training.
my_new_ppo.train()
# __restore-from-algo-checkpoint-end__
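
# Optional check (not part of the checkpointing workflow itself): the restored
# Algorithm can immediately serve actions again. The observation below is an
# arbitrary CartPole-v1 observation.
import numpy as np

check_action = my_new_ppo.compute_single_action(
    np.array([0.0, 0.1, 0.2, 0.3], dtype=np.float32)
)
print(f"Restored algo computed action: {check_action}")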

my_new_ppo.stop()

# __restore-from-algo-checkpoint-2-begin__
# Re-build a fresh algorithm.
my_new_ppo = my_ppo_config.build()
# Restore the old (checkpointed) state.
my_new_ppo.restore(save_result)
# Continue training.
my_new_ppo.train()
# __restore-from-algo-checkpoint-2-end__

my_new_ppo.stop()

# __multi-agent-checkpoints-begin__
import os

# Use our example multi-agent CartPole environment to train in.
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

# Set up a multi-agent Algorithm, training two policies independently.
my_ma_config = PPOConfig().multi_agent(
    # Which policies should RLlib create and train?
    policies={"pol1", "pol2"},
    # Let RLlib know which agents in the environment (we'll have "agent1"
    # and "agent2") map to which policies.
    policy_mapping_fn=(
        lambda agent_id, episode, worker, **kw: (
            "pol1" if agent_id == "agent1" else "pol2"
        )
    ),
    # Setting these is not necessary. All policies will always be trained by
    # default. However, since we do provide a list of IDs here, we need to
    # remain in charge of changing this `policies_to_train` list, should we
    # ever alter the Algorithm (e.g. remove one of the policies or add a new
    # one).
    policies_to_train=["pol1", "pol2"],  # Again, `None` would be totally fine here.
)

# Add the MultiAgentCartPole env to our config and build our Algorithm.
my_ma_config.environment(
    MultiAgentCartPole,
    env_config={
        "num_agents": 2,
    },
)

my_ma_algo = my_ma_config.build()

my_ma_algo.train()

ma_checkpoint_dir = my_ma_algo.save().checkpoint.path

print(
    "An Algorithm checkpoint has been created inside directory: "
    f"'{ma_checkpoint_dir}'.\n"
    "Individual Policy checkpoints can be found in "
    f"'{os.path.join(ma_checkpoint_dir, 'policies')}'."
)

# Create a new Algorithm instance from the above checkpoint, just as you would for
# a single-agent setup:
my_ma_algo_clone = Algorithm.from_checkpoint(ma_checkpoint_dir)
# __multi-agent-checkpoints-end__
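
# Optional sanity check (not part of the original workflow): both policies
# should be present again in the restored multi-agent Algorithm.
assert my_ma_algo_clone.get_policy("pol1") is not None
assert my_ma_algo_clone.get_policy("pol2") is not None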

my_ma_algo_clone.stop()

# __multi-agent-checkpoints-restore-policy-sub-set-begin__
# Here, we use the same (multi-agent Algorithm) checkpoint as above, but only restore
# it with the first Policy ("pol1").
my_ma_algo_only_pol1 = Algorithm.from_checkpoint(
    ma_checkpoint_dir,
    # Tell the `from_checkpoint` util to create a new Algo, but only with "pol1" in it.
    policy_ids=["pol1"],
    # Make sure to update the mapping function (we must not map to "pol2" anymore
    # to avoid a runtime error). Now both agents ("agent1" and "agent2") map to
    # the same policy.
    policy_mapping_fn=lambda agent_id, episode, worker, **kw: "pol1",
    # Since we defined this above, we have to re-define it here with the updated
    # PolicyIDs; otherwise RLlib will throw an error (it will think that there is
    # an unknown PolicyID in this list ("pol2")).
    policies_to_train=["pol1"],
)

# Make sure pol2 is NOT in this Algorithm anymore.
assert my_ma_algo_only_pol1.get_policy("pol2") is None

# Continue training (only with pol1).
my_ma_algo_only_pol1.train()
# __multi-agent-checkpoints-restore-policy-sub-set-end__
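
# Optional check: "pol1" survived the partial restore and is still available.
assert my_ma_algo_only_pol1.get_policy("pol1") is not None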

my_ma_algo_only_pol1.stop()

# __create-policy-checkpoint-begin__
# Retrieve the Policy object from an Algorithm.
# Note that for normal, single-agent Algorithms, the Policy ID is "default_policy".
policy1 = my_ma_algo.get_policy(policy_id="pol1")

# Tell RLlib to store an individual policy checkpoint (only for "pol1") inside
# /tmp/my_policy_checkpoint
policy1.export_checkpoint("/tmp/my_policy_checkpoint")
# __create-policy-checkpoint-end__
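
# Optional: a Policy checkpoint is also just a directory. Its exact contents
# depend on the RLlib version (typically a pickled policy state file plus an
# `rllib_checkpoint.json` metadata file).
print(sorted(os.listdir("/tmp/my_policy_checkpoint")))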

# __restore-policy-begin__
import numpy as np

from ray.rllib.policy.policy import Policy

# Use the `from_checkpoint` utility of the Policy class:
my_restored_policy = Policy.from_checkpoint("/tmp/my_policy_checkpoint")

# Use the restored policy for serving actions.
obs = np.array([0.0, 0.1, 0.2, 0.3])  # individual CartPole observation
action = my_restored_policy.compute_single_action(obs)

print(f"Computed action {action} from given CartPole observation.")
# __restore-policy-end__
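
# Note: per the RLlib docs, `Policy.from_checkpoint` also accepts an Algorithm
# checkpoint directory, in which case it returns a dict mapping policy IDs to
# the restored Policy objects (behavior may vary across RLlib versions). A
# quick sketch, re-using the multi-agent checkpoint from above:
restored_policies = Policy.from_checkpoint(ma_checkpoint_dir)
print(f"Restored policies: {sorted(restored_policies.keys())}")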

# __restore-algorithm-from-checkpoint-with-fewer-policies-begin__
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

# Set up an Algorithm with 5 Policies.
algo_w_5_policies = (
    PPOConfig()
    .environment(
        env=MultiAgentCartPole,
        env_config={
            "num_agents": 5,
        },
    )
    .multi_agent(
        policies={"pol0", "pol1", "pol2", "pol3", "pol4"},
        # Map "agent0" -> "pol0", etc...
        policy_mapping_fn=(
            lambda agent_id, episode, worker, **kwargs: f"pol{agent_id}"
        ),
    )
    .build()
)

# .. train one iteration ..
algo_w_5_policies.train()
# .. and call `save()` to create a checkpoint.
path_to_checkpoint = algo_w_5_policies.save().checkpoint.path
print(
    "An Algorithm checkpoint has been created inside directory: "
    f"'{path_to_checkpoint}'. It should contain 5 policies in the 'policies/' sub dir."
)
# Let's terminate the algo for demonstration purposes.
algo_w_5_policies.stop()

# We will now recreate a new algo from this checkpoint, but only with 2 of the
# original policies ("pol0" and "pol1"). Note that this will require us to change the
# `policy_mapping_fn` (instead of mapping 5 agents to 5 policies, we now have
# to map 5 agents to only 2 policies).
def new_policy_mapping_fn(agent_id, episode, worker, **kwargs):
    return "pol0" if agent_id in ["agent0", "agent1"] else "pol1"


algo_w_2_policies = Algorithm.from_checkpoint(
    checkpoint=path_to_checkpoint,
    policy_ids={"pol0", "pol1"},  # <- restore only those policy IDs here.
    policy_mapping_fn=new_policy_mapping_fn,  # <- use this new mapping fn.
)

# Test whether we can train with this new setup.
algo_w_2_policies.train()

# Terminate the new algo.
algo_w_2_policies.stop()
# __restore-algorithm-from-checkpoint-with-fewer-policies-end__

# __export-models-begin__
from ray.rllib.algorithms.ppo import PPOConfig

# Create a new Algorithm (which contains a Policy, which contains a NN Model).
# Switch on `export_native_model_files` so that native model files get included
# in the Policy checkpoints.
ppo_config = (
    PPOConfig().environment("Pendulum-v1").checkpointing(export_native_model_files=True)
)

# The default framework is TensorFlow, but if you would like to do this example with
# PyTorch, uncomment the following line of code:
# ppo_config.framework("torch")

# Create the Algorithm and train one iteration.
ppo = ppo_config.build()
ppo.train()

# Get the underlying PPOTF1Policy (or PPOTorchPolicy) object.
ppo_policy = ppo.get_policy()
# __export-models-end__

# Export the Keras NN model (that our PPOTF1Policy inside the PPO Algorithm uses)
# to disk ...

# 1) .. using the Policy object:
# __export-models-1-begin__
ppo_policy.export_model("/tmp/my_nn_model")
# .. check /tmp/my_nn_model/ for the model files.

# For Keras, you should be able to recover the model via:
# keras_model = tf.saved_model.load("/tmp/my_nn_model/")
# And pass in a Pendulum-v1 observation:
# results = keras_model(tf.convert_to_tensor(
#     np.array([[0.0, 0.1, 0.2]]), dtype=np.float32)
# )

# For PyTorch, do:
# pytorch_model = torch.load("/tmp/my_nn_model/model.pt")
# results = pytorch_model(
#     input_dict={
#         "obs": torch.from_numpy(np.array([[0.0, 0.1, 0.2]], dtype=np.float32)),
#     },
#     state=[torch.tensor(0)],  # dummy value
#     seq_lens=torch.tensor(0),  # dummy value
# )
# __export-models-1-end__

# 2) .. via the Policy's checkpointing method:
# __export-models-2-begin__
checkpoint_dir = ppo_policy.export_checkpoint("/tmp/ppo_policy")
# .. check /tmp/ppo_policy/model/ for the model files.

# You should be able to recover the Keras model via:
# keras_model = tf.saved_model.load("/tmp/ppo_policy/model")
# And pass in a Pendulum-v1 observation:
# results = keras_model(tf.convert_to_tensor(
#     np.array([[0.0, 0.1, 0.2]]), dtype=np.float32)
# )
# __export-models-2-end__

# 3) .. via the Algorithm (Policy) checkpoint:
# __export-models-3-begin__
checkpoint_dir = ppo.save().checkpoint.path
# .. check `checkpoint_dir` for the Algorithm checkpoint files.

# For Keras, you should be able to recover the model via:
# keras_model = tf.saved_model.load(checkpoint_dir + "/policies/default_policy/model/")
# And pass in a Pendulum-v1 observation:
# results = keras_model(tf.convert_to_tensor(
#     np.array([[0.0, 0.1, 0.2]]), dtype=np.float32)
# )
# __export-models-3-end__

# __export-models-as-onnx-begin__
# Using the same Policy object, we can also export our NN Model in the ONNX format
# (the `onnx` arg specifies the ONNX opset version to use for the export):
ppo_policy.export_model("/tmp/my_nn_model", onnx=11)
# __export-models-as-onnx-end__
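
# Optional: the name of the exported ONNX file can differ between frameworks
# and RLlib versions, so rather than assuming a particular file name, simply
# inspect the export directory (assuming the ONNX export above succeeded):
print(sorted(os.listdir("/tmp/my_nn_model")))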