# parametric_actions_cartpole.py
  1. import gym
  2. from gym.spaces import Box, Dict, Discrete
  3. import numpy as np
  4. import random
  5. class ParametricActionsCartPole(gym.Env):
  6. """Parametric action version of CartPole.
  7. In this env there are only ever two valid actions, but we pretend there are
  8. actually up to `max_avail_actions` actions that can be taken, and the two
  9. valid actions are randomly hidden among this set.
  10. At each step, we emit a dict of:
  11. - the actual cart observation
  12. - a mask of valid actions (e.g., [0, 0, 1, 0, 0, 1] for 6 max avail)
  13. - the list of action embeddings (w/ zeroes for invalid actions) (e.g.,
  14. [[0, 0],
  15. [0, 0],
  16. [-0.2322, -0.2569],
  17. [0, 0],
  18. [0, 0],
  19. [0.7878, 1.2297]] for max_avail_actions=6)
  20. In a real environment, the actions embeddings would be larger than two
  21. units of course, and also there would be a variable number of valid actions
  22. per step instead of always [LEFT, RIGHT].
  23. """
  24. def __init__(self, max_avail_actions):
  25. # Use simple random 2-unit action embeddings for [LEFT, RIGHT]
  26. self.left_action_embed = np.random.randn(2)
  27. self.right_action_embed = np.random.randn(2)
  28. self.action_space = Discrete(max_avail_actions)
  29. self.wrapped = gym.make("CartPole-v0")
  30. self.observation_space = Dict({
  31. "action_mask": Box(
  32. 0, 1, shape=(max_avail_actions, ), dtype=np.float32),
  33. "avail_actions": Box(-10, 10, shape=(max_avail_actions, 2)),
  34. "cart": self.wrapped.observation_space,
  35. })
  36. def update_avail_actions(self):
  37. self.action_assignments = np.array(
  38. [[0., 0.]] * self.action_space.n, dtype=np.float32)
  39. self.action_mask = np.array(
  40. [0.] * self.action_space.n, dtype=np.float32)
  41. self.left_idx, self.right_idx = random.sample(
  42. range(self.action_space.n), 2)
  43. self.action_assignments[self.left_idx] = self.left_action_embed
  44. self.action_assignments[self.right_idx] = self.right_action_embed
  45. self.action_mask[self.left_idx] = 1
  46. self.action_mask[self.right_idx] = 1
  47. def reset(self):
  48. self.update_avail_actions()
  49. return {
  50. "action_mask": self.action_mask,
  51. "avail_actions": self.action_assignments,
  52. "cart": self.wrapped.reset(),
  53. }
  54. def step(self, action):
  55. if action == self.left_idx:
  56. actual_action = 0
  57. elif action == self.right_idx:
  58. actual_action = 1
  59. else:
  60. raise ValueError(
  61. "Chosen action was not one of the non-zero action embeddings",
  62. action, self.action_assignments, self.action_mask,
  63. self.left_idx, self.right_idx)
  64. orig_obs, rew, done, info = self.wrapped.step(actual_action)
  65. self.update_avail_actions()
  66. self.action_mask = self.action_mask.astype(np.float32)
  67. obs = {
  68. "action_mask": self.action_mask,
  69. "avail_actions": self.action_assignments,
  70. "cart": orig_obs,
  71. }
  72. return obs, rew, done, info
  73. class ParametricActionsCartPoleNoEmbeddings(gym.Env):
  74. """Same as the above ParametricActionsCartPole.
  75. However, action embeddings are not published inside observations,
  76. but will be learnt by the model.
  77. At each step, we emit a dict of:
  78. - the actual cart observation
  79. - a mask of valid actions (e.g., [0, 0, 1, 0, 0, 1] for 6 max avail)
  80. - action embeddings (w/ "dummy embedding" for invalid actions) are
  81. outsourced in the model and will be learned.
  82. """
  83. def __init__(self, max_avail_actions):
  84. # Randomly set which two actions are valid and available.
  85. self.left_idx, self.right_idx = random.sample(
  86. range(max_avail_actions), 2)
  87. self.valid_avail_actions_mask = np.array(
  88. [0.] * max_avail_actions, dtype=np.float32)
  89. self.valid_avail_actions_mask[self.left_idx] = 1
  90. self.valid_avail_actions_mask[self.right_idx] = 1
  91. self.action_space = Discrete(max_avail_actions)
  92. self.wrapped = gym.make("CartPole-v0")
  93. self.observation_space = Dict({
  94. "valid_avail_actions_mask": Box(0, 1, shape=(max_avail_actions, )),
  95. "cart": self.wrapped.observation_space,
  96. })
  97. def reset(self):
  98. return {
  99. "valid_avail_actions_mask": self.valid_avail_actions_mask,
  100. "cart": self.wrapped.reset(),
  101. }
  102. def step(self, action):
  103. if action == self.left_idx:
  104. actual_action = 0
  105. elif action == self.right_idx:
  106. actual_action = 1
  107. else:
  108. raise ValueError(
  109. "Chosen action was not one of the non-zero action embeddings",
  110. action, self.valid_avail_actions_mask, self.left_idx,
  111. self.right_idx)
  112. orig_obs, rew, done, info = self.wrapped.step(actual_action)
  113. obs = {
  114. "valid_avail_actions_mask": self.valid_avail_actions_mask,
  115. "cart": orig_obs,
  116. }
  117. return obs, rew, done, info