# utils.py

import os

from gym import wrappers

from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils import add_mixins
from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError


def gym_env_creator(env_context: EnvContext, env_descriptor: str):
    """Tries to create a gym env given an EnvContext object and descriptor.

    Note: This function tries to construct the env from a string descriptor
    using only the RL env packages that happen to be installed (such as gym,
    pybullet_envs, vizdoomgym, etc.). These packages are not installation
    requirements of RLlib. If you would like to support more such env
    packages, add the necessary imports and construction logic below.

    Args:
        env_context (EnvContext): The env context object to configure the
            env. Note that this is a config dict, plus the properties:
            `worker_index`, `vector_index`, and `remote`.
        env_descriptor (str): The env descriptor, e.g. CartPole-v0,
            MsPacmanNoFrameskip-v4, VizdoomBasic-v0, or
            CartPoleContinuousBulletEnv-v0.

    Returns:
        gym.Env: The actual gym environment object.

    Raises:
        EnvError: If the env cannot be constructed.
    """
    import gym

    # Allow for PyBullet or VizdoomGym envs to be used as well (via string).
    # This allows for doing things like
    # `env=CartPoleContinuousBulletEnv-v0` or `env=VizdoomBasic-v0`.
    try:
        import pybullet_envs
        pybullet_envs.getList()
    except (ModuleNotFoundError, ImportError):
        pass
    try:
        import vizdoomgym
        vizdoomgym.__name__  # Dummy access to silence unused-import linters.
    except (ModuleNotFoundError, ImportError):
        pass

    # Try creating a gym env. If this fails, we can output a decent error
    # message.
    try:
        return gym.make(env_descriptor, **env_context)
    except gym.error.Error:
        raise EnvError(ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_descriptor))
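
# Example usage (a minimal sketch, not part of the original module): creating
# CartPole-v0 through the descriptor path above. Assumes gym is installed; an
# empty config dict means no extra kwargs are forwarded to `gym.make()`.
#
#     from ray.rllib.env.env_context import EnvContext
#     env = gym_env_creator(EnvContext({}, worker_index=0), "CartPole-v0")
#     obs = env.reset()
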
class VideoMonitor(wrappers.Monitor):
    # Same as the original method, but doesn't use the StatsRecorder, as that
    # would try to add up multi-agent reward dicts, which throws errors.
    def _after_step(self, observation, reward, done, info):
        if not self.enabled:
            return done

        # Use done["__all__"] b/c this is a multi-agent dict.
        if done["__all__"] and self.env_semantics_autoreset:
            # For envs with BlockingReset wrapping VNCEnv, this observation
            # will be the first one of the new episode.
            self.reset_video_recorder()
            self.episode_id += 1
            self._flush()

        # Record video.
        self.video_recorder.capture_frame()

        return done
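
    # For reference (illustrative values, hypothetical agent ids): a
    # multi-agent `done` is a dict keyed by agent id plus the special
    # "__all__" key, e.g.
    #
    #     done = {"agent_0": True, "agent_1": False, "__all__": False}
    #
    # which is why the scalar episode bookkeeping in gym's StatsRecorder
    # would fail here.
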
def record_env_wrapper(env, record_env, log_dir, policy_config):
    if record_env:
        path_ = record_env if isinstance(record_env, str) else log_dir
        # Relative path: Add logdir here, otherwise this would not work
        # for non-local workers.
        if not os.path.isabs(path_):
            path_ = os.path.join(log_dir, path_)
        print(f"Setting the path for recording to {path_}")
        wrapper_cls = VideoMonitor if isinstance(env, MultiAgentEnv) \
            else wrappers.Monitor
        wrapper_cls = add_mixins(wrapper_cls, [MultiAgentEnv], reversed=True)
        env = wrapper_cls(
            env,
            path_,
            resume=True,
            force=True,
            video_callable=lambda _: True,
            mode="evaluation"
            if policy_config["in_evaluation"] else "training")
    return env
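
# Example usage (a minimal sketch, not part of the original module): wrap a
# single-agent env so that every episode gets recorded. The env name, paths,
# and the "in_evaluation" value here are illustrative.
#
#     import gym
#     env = record_env_wrapper(
#         env=gym.make("CartPole-v0"),
#         record_env="videos",  # relative -> joined with log_dir
#         log_dir="/tmp/rllib_logs",
#         policy_config={"in_evaluation": False})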