# cartpole.py
  1. from copy import deepcopy
  2. import gym
  3. import numpy as np
  4. from gym.spaces import Discrete, Dict, Box
  5. class CartPole:
  6. """
  7. Wrapper for gym CartPole environment where the reward
  8. is accumulated to the end
  9. """
  10. def __init__(self, config=None):
  11. self.env = gym.make("CartPole-v0")
  12. self.action_space = Discrete(2)
  13. self.observation_space = Dict({
  14. "obs": self.env.observation_space,
  15. "action_mask": Box(low=0, high=1, shape=(self.action_space.n, ))
  16. })
  17. self.running_reward = 0
  18. def reset(self):
  19. self.running_reward = 0
  20. return {
  21. "obs": self.env.reset(),
  22. "action_mask": np.array([1, 1], dtype=np.float32)
  23. }
  24. def step(self, action):
  25. obs, rew, done, info = self.env.step(action)
  26. self.running_reward += rew
  27. score = self.running_reward if done else 0
  28. return {
  29. "obs": obs,
  30. "action_mask": np.array([1, 1], dtype=np.float32)
  31. }, score, done, info
  32. def set_state(self, state):
  33. self.running_reward = state[1]
  34. self.env = deepcopy(state[0])
  35. obs = np.array(list(self.env.unwrapped.state))
  36. return {"obs": obs, "action_mask": np.array([1, 1], dtype=np.float32)}
  37. def get_state(self):
  38. return deepcopy(self.env), self.running_reward