12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
from copy import deepcopy

import gym
import numpy as np
from gym.spaces import Discrete, Dict, Box
class CartPole:
    """Wrapper for the gym CartPole environment with an end-of-episode reward.

    The per-step rewards of the wrapped environment are summed internally and
    handed out only on the step that finishes the episode; every other step
    reports a reward of 0.

    Observations are dictionaries with two entries:

    * ``"obs"``: the observation of the wrapped environment.
    * ``"action_mask"``: a float32 vector of ones, one entry per action
      (no action is ever masked out by this wrapper).

    ``get_state`` / ``set_state`` snapshot and restore the wrapped environment
    together with the accumulated reward.
    """

    def __init__(self, config=None):
        """Create the wrapped environment and declare the spaces.

        Args:
            config: Accepted for interface compatibility; not read here.
        """
        self.env = gym.make("CartPole-v0")
        self.action_space = Discrete(2)
        self.observation_space = Dict({
            "obs": self.env.observation_space,
            # dtype is spelled out so the space matches the float32 masks
            # produced by `_wrap_obs`.
            "action_mask": Box(
                low=0,
                high=1,
                shape=(self.action_space.n, ),
                dtype=np.float32,
            ),
        })
        # Sum of the wrapped environment's rewards in the current episode.
        self.running_reward = 0

    def _wrap_obs(self, obs):
        """Pair ``obs`` with an all-ones action mask sized to the action space."""
        return {
            "obs": obs,
            "action_mask": np.ones(self.action_space.n, dtype=np.float32),
        }

    def reset(self):
        """Reset the wrapped environment and the accumulated reward.

        Returns:
            The initial observation dictionary.
        """
        self.running_reward = 0
        return self._wrap_obs(self.env.reset())

    def step(self, action):
        """Advance the wrapped environment by one step.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            A tuple ``(observation_dict, score, done, info)`` where ``score``
            is the accumulated reward when ``done`` is true and 0 otherwise.
        """
        obs, rew, done, info = self.env.step(action)
        self.running_reward += rew
        # The reward is withheld until the episode ends.
        score = self.running_reward if done else 0
        return self._wrap_obs(obs), score, done, info

    def set_state(self, state):
        """Restore a snapshot previously produced by ``get_state``.

        Args:
            state: Sequence whose first item is an environment object and
                whose second item is the accumulated reward.

        Returns:
            The observation dictionary rebuilt from the restored environment.
        """
        self.running_reward = state[1]
        # Copy so that later steps do not mutate the caller's snapshot.
        self.env = deepcopy(state[0])
        obs = np.array(list(self.env.unwrapped.state))
        return self._wrap_obs(obs)

    def get_state(self):
        """Return ``(copy of the wrapped environment, accumulated reward)``."""
        return deepcopy(self.env), self.running_reward
|