- from gym import wrappers
- import os
- from ray.rllib.env.env_context import EnvContext
- from ray.rllib.env.multi_agent_env import MultiAgentEnv
- from ray.rllib.utils import add_mixins
- from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
def gym_env_creator(env_context: EnvContext, env_descriptor: str):
    """Tries to create a gym env given an EnvContext object and descriptor.

    Note: This function tries to construct the env from a string descriptor
    only using possibly installed RL env packages (such as gym, pybullet_envs,
    vizdoomgym, etc..). These packages are no installation requirements for
    RLlib. In case you would like to support more such env packages, add the
    necessary imports and construction logic below.

    Args:
        env_context (EnvContext): The env context object to configure the env.
            Note that this is a config dict, plus the properties:
            `worker_index`, `vector_index`, and `remote`.
        env_descriptor (str): The env descriptor, e.g. CartPole-v0,
            MsPacmanNoFrameskip-v4, VizdoomBasic-v0, or
            CartPoleContinuousBulletEnv-v0.

    Returns:
        gym.Env: The actual gym environment object.

    Raises:
        EnvError: If the env cannot be constructed from the given
            descriptor (wraps the underlying `gym.error.Error`).
    """
    import gym

    # Allow for PyBullet or VizdoomGym envs to be used as well (via string).
    # This allows for doing things like `env=CartPoleContinuousBulletEnv-v0`
    # or `env=VizdoomBasic-v0`. Importing these packages registers their envs
    # with gym as a side effect; missing packages are silently skipped.
    try:
        import pybullet_envs
        pybullet_envs.getList()
    except (ModuleNotFoundError, ImportError):
        pass
    try:
        import vizdoomgym
        vizdoomgym.__name__  # trick LINTer.
    except (ModuleNotFoundError, ImportError):
        pass

    # Try creating a gym env. If this fails, we can output a
    # decent error message.
    try:
        return gym.make(env_descriptor, **env_context)
    except gym.error.Error as e:
        # Chain the original gym error so the root cause stays visible in
        # the traceback instead of being swallowed.
        raise EnvError(
            ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_descriptor)) from e
class VideoMonitor(wrappers.Monitor):
    """Multi-agent-safe variant of gym's ``wrappers.Monitor``.

    Overrides ``_after_step`` so the StatsRecorder is bypassed: the stats
    recorder would try to add up multi-agent reward dicts, which throws
    errors.
    """

    def _after_step(self, observation, reward, done, info):
        # Monitoring disabled -> pass `done` straight through.
        if not self.enabled:
            return done

        # `done` is a multi-agent dict here, so look at its "__all__" key.
        episode_over = done["__all__"]
        if episode_over and self.env_semantics_autoreset:
            # For envs with BlockingReset wrapping VNCEnv, this observation
            # will be the first one of the new episode.
            self.reset_video_recorder()
            self.episode_id += 1
            self._flush()

        # Capture the current frame into the video.
        self.video_recorder.capture_frame()
        return done
def record_env_wrapper(env, record_env, log_dir, policy_config):
    """Optionally wrap `env` in a video-recording monitor.

    Args:
        env: The environment to (maybe) wrap.
        record_env: If truthy, enable recording. A string value is used as
            the output path; any other truthy value records into `log_dir`.
        log_dir: The worker's log directory (base for relative paths).
        policy_config: Policy config dict; its "in_evaluation" flag picks
            the monitor's "evaluation" vs. "training" mode.

    Returns:
        The monitor-wrapped env if recording is enabled, otherwise `env`
        unchanged.
    """
    # No recording requested -> return the env as-is.
    if not record_env:
        return env

    output_path = record_env if isinstance(record_env, str) else log_dir
    # Relative path: Add logdir here, otherwise, this would
    # not work for non-local workers.
    if not os.path.isabs(output_path):
        output_path = os.path.join(log_dir, output_path)
    print(f"Setting the path for recording to {output_path}")

    # Multi-agent envs need the monitor variant that skips the StatsRecorder.
    if isinstance(env, MultiAgentEnv):
        monitor_cls = VideoMonitor
    else:
        monitor_cls = wrappers.Monitor
    monitor_cls = add_mixins(monitor_cls, [MultiAgentEnv], reversed=True)

    mode = "evaluation" if policy_config["in_evaluation"] else "training"
    return monitor_cls(
        env,
        output_path,
        resume=True,
        force=True,
        video_callable=lambda _: True,
        mode=mode,
    )