saving_experiences.py

  1. """Simple example of writing experiences to a file using JsonWriter."""
  2. # __sphinx_doc_begin__
  3. import gymnasium as gym
  4. import numpy as np
  5. import os
  6. import ray._private.utils
  7. from ray.rllib.models.preprocessors import get_preprocessor
  8. from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder
  9. from ray.rllib.offline.json_writer import JsonWriter
  10. if __name__ == "__main__":
  11. batch_builder = SampleBatchBuilder() # or MultiAgentSampleBatchBuilder
  12. writer = JsonWriter(
  13. os.path.join(ray._private.utils.get_user_temp_dir(), "demo-out")
  14. )
  15. # You normally wouldn't want to manually create sample batches if a
  16. # simulator is available, but let's do it anyways for example purposes:
  17. env = gym.make("CartPole-v1")
  18. # RLlib uses preprocessors to implement transforms such as one-hot encoding
  19. # and flattening of tuple and dict observations. For CartPole a no-op
  20. # preprocessor is used, but this may be relevant for more complex envs.
  21. prep = get_preprocessor(env.observation_space)(env.observation_space)
  22. print("The preprocessor is", prep)
  23. for eps_id in range(100):
  24. obs, info = env.reset()
  25. prev_action = np.zeros_like(env.action_space.sample())
  26. prev_reward = 0
  27. terminated = truncated = False
  28. t = 0
  29. while not terminated and not truncated:
  30. action = env.action_space.sample()
  31. new_obs, rew, terminated, truncated, info = env.step(action)
  32. batch_builder.add_values(
  33. t=t,
  34. eps_id=eps_id,
  35. agent_index=0,
  36. obs=prep.transform(obs),
  37. actions=action,
  38. action_prob=1.0, # put the true action probability here
  39. action_logp=0.0,
  40. rewards=rew,
  41. prev_actions=prev_action,
  42. prev_rewards=prev_reward,
  43. terminateds=terminated,
  44. truncateds=truncated,
  45. infos=info,
  46. new_obs=prep.transform(new_obs),
  47. )
  48. obs = new_obs
  49. prev_action = action
  50. prev_reward = rew
  51. t += 1
  52. writer.write(batch_builder.build_and_reset())
  53. # __sphinx_doc_end__
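
The written output can be read back to verify it, for example with RLlib's JsonReader. The snippet below is a minimal sketch and not part of the example file above; it assumes that ray.rllib.offline.json_reader.JsonReader accepts the output directory and that its next() method returns one SampleBatch per call (check against your RLlib version).

    import os

    import ray._private.utils
    from ray.rllib.offline.json_reader import JsonReader

    # Point the reader at the same directory the JsonWriter wrote to.
    reader = JsonReader(
        os.path.join(ray._private.utils.get_user_temp_dir(), "demo-out")
    )
    # Pull one batch of recorded experiences and inspect its size.
    batch = reader.next()
    print("Read back a batch with", batch.count, "timesteps")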