mbmpo_env.py

from gym.envs.classic_control import PendulumEnv, CartPoleEnv
import numpy as np

# MuJoCo may not be installed.
HalfCheetahEnv = HopperEnv = None

try:
    from gym.envs.mujoco import HalfCheetahEnv, HopperEnv
except Exception:
    pass


class CartPoleWrapper(CartPoleEnv):
    """Wrapper for the CartPole-v0 environment.

    Adds an additional `reward` method for some model-based RL algos (e.g.
    MB-MPO).
    """

    def reward(self, obs, action, obs_next):
        # obs = batch * [pos, vel, angle, rotation_rate]
        x = obs_next[:, 0]
        theta = obs_next[:, 2]

        # 1.0 if we are still on, 0.0 if we are terminated due to bounds
        # (angular or x-axis) being breached.
        rew = 1.0 - ((x < -self.x_threshold) | (x > self.x_threshold) |
                     (theta < -self.theta_threshold_radians) |
                     (theta > self.theta_threshold_radians)).astype(np.float32)

        return rew
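

# Example sketch (not part of the original file; names and values are
# illustrative): the batched `reward` method lets transitions generated by a
# learned dynamics model be scored without stepping the real environment.
def _example_cartpole_reward_batch():
    env = CartPoleWrapper()
    # Batch of 4 transitions; obs columns are [pos, vel, angle, rotation_rate].
    obs = np.zeros((4, 4), dtype=np.float32)
    obs_next = np.zeros((4, 4), dtype=np.float32)
    obs_next[3, 2] = 1.0  # last next-state has the pole past the ~0.21 rad limit
    action = np.zeros((4, 1), dtype=np.float32)
    # Expected: [1.0, 1.0, 1.0, 0.0] -> reward is 1.0 while within bounds.
    return env.reward(obs, action, obs_next)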


class PendulumWrapper(PendulumEnv):
    """Wrapper for the Pendulum-v1 environment.

    Adds an additional `reward` method for some model-based RL algos (e.g.
    MB-MPO).
    """

    def reward(self, obs, action, obs_next):
        # obs = [cos(theta), sin(theta), dtheta/dt]
        # To get the angle back from obs: atan2(sin(theta), cos(theta)).
        theta = np.arctan2(
            np.clip(obs[:, 1], -1.0, 1.0), np.clip(obs[:, 0], -1.0, 1.0))
        # Do everything in (B,) space (single theta-, action- and
        # reward values).
        a = np.clip(action, -self.max_torque, self.max_torque)[0]
        costs = self.angle_normalize(theta) ** 2 + \
            0.1 * obs[:, 2] ** 2 + 0.001 * (a ** 2)
        return -costs

    @staticmethod
    def angle_normalize(x):
        return (((x + np.pi) % (2 * np.pi)) - np.pi)
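

# Example sketch (not part of the original file; values are illustrative): the
# angle is recovered from the [cos(theta), sin(theta)] part of the observation
# via atan2, then the standard Pendulum cost is applied per batch element.
def _example_pendulum_reward_batch():
    env = PendulumWrapper()
    # Two observations: upright at rest (cost ~0) and hanging down (cost ~pi**2).
    obs = np.array([[1.0, 0.0, 0.0],
                    [-1.0, 0.0, 0.0]])
    action = np.zeros(1)  # zero torque -> no control cost
    # Expected (approximately): [0.0, -9.87]
    return env.reward(obs, action, obs_next=obs)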


class HalfCheetahWrapper(HalfCheetahEnv or object):
    """Wrapper for the MuJoCo HalfCheetah-v2 environment.

    Adds an additional `reward` method for some model-based RL algos (e.g.
    MB-MPO).
    """

    def reward(self, obs, action, obs_next):
        if obs.ndim == 2 and action.ndim == 2:
            assert obs.shape == obs_next.shape
            forward_vel = obs_next[:, 8]
            ctrl_cost = 0.1 * np.sum(np.square(action), axis=1)
            reward = forward_vel - ctrl_cost
            return np.minimum(np.maximum(-1000.0, reward), 1000.0)
        else:
            forward_vel = obs_next[8]
            ctrl_cost = 0.1 * np.square(action).sum()
            reward = forward_vel - ctrl_cost
            return np.minimum(np.maximum(-1000.0, reward), 1000.0)
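

# Example sketch (not part of the original file; shapes assume HalfCheetah-v2
# with 17-dim observations and 6-dim actions). Guarded, since MuJoCo may not
# be installed.
def _example_halfcheetah_reward_batch():
    if HalfCheetahEnv is None:
        return None
    env = HalfCheetahWrapper()
    obs = np.zeros((2, 17))
    obs_next = np.zeros((2, 17))
    obs_next[:, 8] = [1.0, 2.0]  # forward velocities of the two next-states
    action = np.zeros((2, 6))  # zero actions -> no control cost
    # Expected: [1.0, 2.0] -> forward_vel - 0.1 * ||action||^2, clipped to +/-1000.
    return env.reward(obs, action, obs_next)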


class HopperWrapper(HopperEnv or object):
    """Wrapper for the MuJoCo Hopper-v2 environment.

    Adds an additional `reward` method for some model-based RL algos (e.g.
    MB-MPO).
    """

    def reward(self, obs, action, obs_next):
        alive_bonus = 1.0
        assert obs.ndim == 2 and action.ndim == 2
        assert (obs.shape == obs_next.shape
                and action.shape[0] == obs.shape[0])
        vel = obs_next[:, 5]
        ctrl_cost = 1e-3 * np.sum(np.square(action), axis=1)
        reward = vel + alive_bonus - ctrl_cost
        return np.minimum(np.maximum(-1000.0, reward), 1000.0)
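

# Example sketch (not part of the original file; shapes assume Hopper-v2 with
# 11-dim observations and 3-dim actions). Guarded, since MuJoCo may not be
# installed.
def _example_hopper_reward_batch():
    if HopperEnv is None:
        return None
    env = HopperWrapper()
    obs = np.zeros((2, 11))
    obs_next = np.zeros((2, 11))
    obs_next[:, 5] = [0.5, 1.5]  # forward velocities of the two next-states
    action = np.zeros((2, 3))  # zero actions -> no control cost
    # Expected: [1.5, 2.5] -> vel + alive_bonus (1.0) - 1e-3 * ||action||^2.
    return env.reward(obs, action, obs_next)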


if __name__ == "__main__":
    env = PendulumWrapper()
    env.reset()
    for _ in range(100):
        env.step(env.action_space.sample())
        env.render()