offline_rl.py

  1. """Example on how to use a CQLTrainer to learn from an offline json file.
  2. Important node: Make sure that your offline data file contains only
  3. a single timestep per line to mimic the way SAC pulls samples from
  4. the buffer.
  5. Generate the offline json file by running an SAC algo until it reaches expert
  6. level on your command line. For example:
  7. $ cd ray
  8. $ rllib train -f rllib/tuned_examples/sac/pendulum-sac.yaml --no-ray-ui
  9. Also make sure that in the above SAC yaml file (pendulum-sac.yaml),
  10. you specify an additional "output" key with any path on your local
  11. file system. In that path, the offline json files will be written to.
  12. Use the generated file(s) as "input" in the CQL config below
  13. (`config["input"] = [list of your json files]`), then run this script.
  14. """
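# As a rough sketch (the output path is just an example and up to you), the
# "output" addition to pendulum-sac.yaml described above could look like this:
#
#   pendulum-sac:
#       ...
#       config:
#           ...
#           output: /tmp/sac-pendulum-out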
import numpy as np
import os

from ray.rllib.agents import cql as cql
from ray.rllib.utils.framework import try_import_torch
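# try_import_torch() returns the torch and torch.nn modules; only torch
# itself is needed here.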
torch, _ = try_import_torch()

if __name__ == "__main__":
    # See rllib/tuned_examples/cql/pendulum-cql.yaml for comparison.
    config = cql.CQL_DEFAULT_CONFIG.copy()
    config["num_workers"] = 0  # Run locally.
    config["horizon"] = 200
    config["soft_horizon"] = True
    config["no_done_at_end"] = True
    config["n_step"] = 3
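    # bc_iters=0 skips CQL's behavior-cloning warm-up phase, so the full CQL
    # loss is used from the first update on.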
    config["bc_iters"] = 0
    config["clip_actions"] = False
    config["normalize_actions"] = True
    config["learning_starts"] = 256
    config["rollout_fragment_length"] = 1
    config["prioritized_replay"] = False
    config["tau"] = 0.005
    config["target_entropy"] = "auto"
    config["Q_model"] = {
        "fcnet_hiddens": [256, 256],
        "fcnet_activation": "relu",
    }
    config["policy_model"] = {
        "fcnet_hiddens": [256, 256],
        "fcnet_activation": "relu",
    }
    config["optimization"] = {
        "actor_learning_rate": 3e-4,
        "critic_learning_rate": 3e-4,
        "entropy_learning_rate": 3e-4,
    }
    config["train_batch_size"] = 256
    config["target_network_update_freq"] = 1
    config["timesteps_per_iteration"] = 1000
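    # Learn from the offline json file(s) generated by the SAC run described
    # in the docstring. The path below is a placeholder; point it to your own
    # file(s).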
    data_file = "/path/to/my/json_file.json"
    print("data_file={} exists={}".format(
        data_file, os.path.isfile(data_file)))
    config["input"] = [data_file]
    config["log_level"] = "INFO"
    config["env"] = "Pendulum-v1"

    # Set up evaluation.
    config["evaluation_num_workers"] = 1
    config["evaluation_interval"] = 1
    config["evaluation_duration"] = 10
    # This should be False b/c iterations are very long and this would
    # cause evaluation to lag one iter behind training.
    config["evaluation_parallel_to_training"] = False
    # Evaluate on actual environment.
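    # ("input": "sampler" overrides the offline input for the evaluation
    # workers, so they collect fresh episodes from the live Pendulum env.)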
    config["evaluation_config"] = {"input": "sampler"}

    # Check whether we can learn from the given file in `num_iterations`
    # iterations, up to a reward of `min_reward`.
    num_iterations = 5
    min_reward = -300

    # Test for torch framework (tf not implemented yet).
    trainer = cql.CQLTrainer(config=config)
    learnt = False
    for i in range(num_iterations):
        print(f"Iter {i}")
        eval_results = trainer.train().get("evaluation")
        if eval_results:
            print("... R={}".format(eval_results["episode_reward_mean"]))
            # Learn until some reward is reached on an actual live env.
            if eval_results["episode_reward_mean"] >= min_reward:
                learnt = True
                break
    if not learnt:
        raise ValueError("CQLTrainer did not reach {} reward from expert "
                         "offline data!".format(min_reward))

    # Get policy, model, and replay-buffer.
    pol = trainer.get_policy()
    cql_model = pol.model
    from ray.rllib.agents.cql.cql import replay_buffer

    # If you would like to query CQL's learnt Q-function for arbitrary
    # (cont.) actions, do the following:
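    # (The shapes below, (5, 3) and (5, 1), form a batch of 5 random
    # Pendulum-v1 observations (3-dim) and actions (1-dim).)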
    # Cast to float32, which is what the model's layers expect.
    obs_batch = torch.from_numpy(np.random.random(size=(5, 3))).float()
    action_batch = torch.from_numpy(np.random.random(size=(5, 1))).float()
    q_values = cql_model.get_q_values(obs_batch, action_batch)
    # If you are using the "twin_q", there'll be 2 Q-networks and
    # we usually consider the min of the 2 outputs, like so:
    twin_q_values = cql_model.get_twin_q_values(obs_batch, action_batch)
    final_q_values = torch.min(q_values, twin_q_values)
    print(final_q_values)

    # Example of how to evaluate the trained Trainer
    # using the data from our buffer.
    # Get a sample (MultiAgentBatch -> SampleBatch).
    batch = replay_buffer.replay().policy_batches["default_policy"]
    obs = torch.from_numpy(batch["obs"])
    # Pass the observations through our model to get the
    # features, which we then pass through the Q-head.
    model_out, _ = cql_model({"obs": obs})
    # The estimated Q-values from the (historic) actions in the batch.
    q_values_old = cql_model.get_q_values(model_out,
                                          torch.from_numpy(batch["actions"]))
    # The estimated Q-values for the new actions computed
    # by our trainer policy.
    actions_new = pol.compute_actions_from_input_dict({"obs": obs})[0]
    q_values_new = cql_model.get_q_values(model_out,
                                          torch.from_numpy(actions_new))
    print(f"Q-val batch={q_values_old}")
    print(f"Q-val policy={q_values_new}")

    trainer.stop()