model_vector_env.py

import logging

import numpy as np
from gym.spaces import Discrete

from ray.rllib.utils.annotations import override
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.utils.typing import EnvType

logger = logging.getLogger(__name__)
def model_vector_env(env: EnvType) -> BaseEnv:
    """Returns a VectorizedEnv wrapper around the given environment.

    To obtain worker configs, one can call get_global_worker().

    Args:
        env (EnvType): The input environment (of any supported environment
            type) to be converted to a _VectorizedModelGymEnv (wrapped as
            an RLlib BaseEnv).

    Returns:
        BaseEnv: The BaseEnv converted input `env`.
    """
    worker = get_global_worker()
    worker_index = worker.worker_index
    if worker_index:
        env = _VectorizedModelGymEnv(
            make_env=worker.make_sub_env_fn,
            existing_envs=[env],
            num_envs=worker.num_envs,
            observation_space=env.observation_space,
            action_space=env.action_space,
        )
    return BaseEnv.to_base_env(
        env,
        make_env=worker.make_sub_env_fn,
        num_envs=worker.num_envs,
        remote_envs=False,
        remote_env_batch_wait_ms=0)
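

# Hedged usage note (not part of the original module): `model_vector_env` is
# not meant to be called from user code. In the MB-MPO setup this file belongs
# to, it is wired in as the rollout worker's env-wrapping hook via the
# algorithm's default config. The exact key name below is an assumption based
# on the MB-MPO default config and may differ across RLlib versions:
#
#     config = {
#         # Wrap each sub-env so steps come from the learned dynamics model.
#         "custom_vector_env": model_vector_env,
#         "num_envs_per_worker": 20,
#     }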


class _VectorizedModelGymEnv(VectorEnv):
    """Vectorized Environment Wrapper for MB-MPO.

    Primary change is in the `vector_step` method, which calls the dynamics
    models for next_obs "calculation" (instead of the actual env). Also, the
    actual envs need to implement two extra batched methods:
    `reward(obs, action, next_obs)` and (optionally) `done(next_obs)`. If
    `done` is not implemented, we will assume that episodes in the env do not
    terminate, ever. A minimal example env providing these hooks is sketched
    at the end of this file.
    """

    def __init__(self,
                 make_env=None,
                 existing_envs=None,
                 num_envs=1,
                 *,
                 observation_space=None,
                 action_space=None,
                 env_config=None):
        self.make_env = make_env
        self.envs = existing_envs
        self.num_envs = num_envs
        # Initialize here so the reset-check in `vector_step` raises the
        # intended ValueError (not an AttributeError) before the first reset.
        self.cur_obs = None
        while len(self.envs) < num_envs:
            self.envs.append(self.make_env(len(self.envs)))

        super().__init__(
            observation_space=observation_space
            or self.envs[0].observation_space,
            action_space=action_space or self.envs[0].action_space,
            num_envs=num_envs)

        # Fetch the worker's dynamics model (and its device) from the policy
        # so `vector_step` can query it directly.
        worker = get_global_worker()
        self.model, self.device = worker.foreach_policy(
            lambda x, y: (x.dynamics_model, x.device))[0]

    @override(VectorEnv)
    def vector_reset(self):
        """Override parent to store actual env obs for upcoming predictions.
        """
        self.cur_obs = [e.reset() for e in self.envs]
        return self.cur_obs

    @override(VectorEnv)
    def reset_at(self, index):
        """Override parent to store actual env obs for upcoming predictions.
        """
        obs = self.envs[index].reset()
        self.cur_obs[index] = obs
        return obs

    @override(VectorEnv)
    def vector_step(self, actions):
        if self.cur_obs is None:
            raise ValueError("Need to reset env first")

        # If discrete, one-hot the actions. Use the action space's size (not
        # the max action in the batch) so the one-hot width is always correct.
        if isinstance(self.action_space, Discrete):
            act = np.array(actions)
            new_act = np.zeros((act.size, self.action_space.n))
            new_act[np.arange(act.size), act] = 1
            actions = new_act.astype("float32")

        # Batch the TD-model prediction.
        obs_batch = np.stack(self.cur_obs, axis=0)
        action_batch = np.stack(actions, axis=0)
        # Predict the next observation, given previous a) real obs
        # (after a reset), b) predicted obs (any other time).
        next_obs_batch = self.model.predict_model_batches(
            obs_batch, action_batch, device=self.device)
        next_obs_batch = np.clip(next_obs_batch, -1000, 1000)

        # Call the env's reward function.
        # Note: Each actual env must implement one to output exact rewards.
        rew_batch = self.envs[0].reward(obs_batch, action_batch,
                                        next_obs_batch)

        # If the env has a `done` method, use it.
        if hasattr(self.envs[0], "done"):
            dones_batch = self.envs[0].done(next_obs_batch)
        # Otherwise, assume the episode does not end.
        else:
            dones_batch = np.asarray([False for _ in range(self.num_envs)])

        info_batch = [{} for _ in range(self.num_envs)]

        self.cur_obs = next_obs_batch

        return list(next_obs_batch), list(rew_batch), list(
            dones_batch), info_batch

    @override(VectorEnv)
    def get_sub_environments(self):
        return self.envs
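

# ---------------------------------------------------------------------------
# Hedged example (not part of the original module): a minimal sketch of what
# an actual env must provide so that `_VectorizedModelGymEnv.vector_step` can
# compute rewards (and, optionally, dones) from model-predicted observations.
# The class name, spaces, and reward/termination rules below are illustrative
# assumptions only; the import sits here to keep the sketch self-contained.
# ---------------------------------------------------------------------------
import gym


class _ToyModelReadyEnv(gym.Env):
    """1-D point env whose reward/done are pure functions of the transition."""

    def __init__(self, config=None):
        self.observation_space = gym.spaces.Box(-10.0, 10.0, shape=(1,))
        self.action_space = gym.spaces.Box(-1.0, 1.0, shape=(1,))
        self.pos = np.zeros(1, dtype=np.float32)

    def reset(self):
        self.pos = np.zeros(1, dtype=np.float32)
        return self.pos

    def step(self, action):
        self.pos = np.clip(self.pos + action, -10.0, 10.0).astype(np.float32)
        return self.pos, -float(np.abs(self.pos)[0]), False, {}

    def reward(self, obs_batch, action_batch, next_obs_batch):
        # Batched reward hook used by `vector_step`: one scalar per sub-env,
        # computed purely from the (possibly model-predicted) transition.
        return -np.abs(next_obs_batch).sum(axis=-1)

    def done(self, next_obs_batch):
        # Optional batched termination hook; omit it to get never-ending
        # episodes (see the `_VectorizedModelGymEnv` docstring above).
        return np.abs(next_obs_batch).sum(axis=-1) > 9.0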