mbmpo_env.py

import numpy as np

import gymnasium as gym

# MuJoCo may not be installed.
HalfCheetahEnv = HopperEnv = None

try:
    from gymnasium.envs.mujoco import HalfCheetahEnv, HopperEnv
except Exception:
    pass


class CartPoleWrapper(gym.Wrapper):
    """Wrapper for the CartPole-v1 environment.

    Adds an additional `reward` method for some model-based RL algos (e.g.
    MB-MPO).
    """

    # This is required by MB-MPO's model vector env so that it knows how many
    # steps to simulate the env for.
    _max_episode_steps = 500

    def __init__(self, **kwargs):
        env = gym.make("CartPole-v1", **kwargs)
        gym.Wrapper.__init__(self, env)

    def reward(self, obs, action, obs_next):
        # obs = batch * [pos, vel, angle, rotation_rate]
        x = obs_next[:, 0]
        theta = obs_next[:, 2]
        # 1.0 if we are still on, 0.0 if we are terminated due to bounds
        # (angular or x-axis) being breached.
        rew = 1.0 - (
            (x < -self.x_threshold)
            | (x > self.x_threshold)
            | (theta < -self.theta_threshold_radians)
            | (theta > self.theta_threshold_radians)
        ).astype(np.float32)
        return rew


class PendulumWrapper(gym.Wrapper):
    """Wrapper for the Pendulum-v1 environment.

    Adds an additional `reward` method for some model-based RL algos (e.g.
    MB-MPO).
    """

    # This is required by MB-MPO's model vector env so that it knows how many
    # steps to simulate the env for.
    _max_episode_steps = 200

    def __init__(self, **kwargs):
        env = gym.make("Pendulum-v1", **kwargs)
        gym.Wrapper.__init__(self, env)

    def reward(self, obs, action, obs_next):
        # obs = [cos(theta), sin(theta), dtheta/dt]
        # To get the angle back from obs: atan2(sin(theta), cos(theta)).
        theta = np.arctan2(
            np.clip(obs[:, 1], -1.0, 1.0), np.clip(obs[:, 0], -1.0, 1.0)
        )
        # Do everything in (B,) space (single theta-, action- and
        # reward values).
        a = np.clip(action, -self.max_torque, self.max_torque)[0]
        costs = (
            self.angle_normalize(theta) ** 2 + 0.1 * obs[:, 2] ** 2 + 0.001 * (a**2)
        )
        return -costs

    @staticmethod
    def angle_normalize(x):
        return ((x + np.pi) % (2 * np.pi)) - np.pi
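

# NOTE (added explanation of the existing import guard): if MuJoCo is not
# installed, HalfCheetahEnv and HopperEnv stay None (see the try/except at the
# top), so the `... or object` bases below fall back to plain `object`. The
# module then still imports, but the MuJoCo wrappers cannot be used as envs.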
class HalfCheetahWrapper(HalfCheetahEnv or object):
    """Wrapper for the MuJoCo HalfCheetah-v2 environment.

    Adds an additional `reward` method for some model-based RL algos (e.g.
    MB-MPO).
    """

    def reward(self, obs, action, obs_next):
        if obs.ndim == 2 and action.ndim == 2:
            assert obs.shape == obs_next.shape
            forward_vel = obs_next[:, 8]
            ctrl_cost = 0.1 * np.sum(np.square(action), axis=1)
            reward = forward_vel - ctrl_cost
            return np.minimum(np.maximum(-1000.0, reward), 1000.0)
        else:
            forward_vel = obs_next[8]
            ctrl_cost = 0.1 * np.square(action).sum()
            reward = forward_vel - ctrl_cost
            return np.minimum(np.maximum(-1000.0, reward), 1000.0)


class HopperWrapper(HopperEnv or object):
    """Wrapper for the MuJoCo Hopper-v2 environment.

    Adds an additional `reward` method for some model-based RL algos (e.g.
    MB-MPO).
    """

    # This is required by MB-MPO's model vector env so that it knows how many
    # steps to simulate the env for.
    _max_episode_steps = 1000

    def reward(self, obs, action, obs_next):
        alive_bonus = 1.0
        assert obs.ndim == 2 and action.ndim == 2
        assert obs.shape == obs_next.shape and action.shape[0] == obs.shape[0]
        vel = obs_next[:, 5]
        ctrl_cost = 1e-3 * np.sum(np.square(action), axis=1)
        reward = vel + alive_bonus - ctrl_cost
        return np.minimum(np.maximum(-1000.0, reward), 1000.0)


if __name__ == "__main__":
    env = PendulumWrapper()
    env.reset()
    for _ in range(100):
        env.step(env.action_space.sample())
        env.render()
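
    # Illustrative sketch (not part of the original example): model-based
    # algorithms such as MB-MPO call `reward()` on *batches* of transitions,
    # e.g. to re-score rollouts imagined by a learned dynamics model. The
    # batch size of 4 and the use of `observation_space.sample()` are
    # assumptions made purely for demonstration.
    obs_batch = np.stack([env.observation_space.sample() for _ in range(4)])
    act_batch = np.stack([env.action_space.sample() for _ in range(4)])
    next_obs_batch = np.stack([env.observation_space.sample() for _ in range(4)])
    # For PendulumWrapper this yields a (4,)-shaped array of negated costs.
    print("Batched rewards:", env.reward(obs_batch, act_batch, next_obs_batch))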