action_mask_env.py

from gymnasium.spaces import Box, Dict, Discrete
import numpy as np

from ray.rllib.examples.env.random_env import RandomEnv


class ActionMaskEnv(RandomEnv):
    """A randomly acting environment that publishes an action-mask each step."""

    def __init__(self, config):
        super().__init__(config)
        # Skip RLlib's built-in env pre-checking for this example env.
        self._skip_env_checking = True
        # Masking only works for Discrete actions.
        assert isinstance(self.action_space, Discrete)
        # Add action_mask to observations.
        self.observation_space = Dict(
            {
                "action_mask": Box(0.0, 1.0, shape=(self.action_space.n,)),
                "observations": self.observation_space,
            }
        )
        self.valid_actions = None

    def reset(self, *, seed=None, options=None):
        obs, info = super().reset()
        self._fix_action_mask(obs)
        return obs, info

    def step(self, action):
        # Check whether the action is valid.
        if not self.valid_actions[action]:
            raise ValueError(
                f"Invalid action sent to env! valid_actions={self.valid_actions}"
            )
        obs, rew, done, truncated, info = super().step(action)
        self._fix_action_mask(obs)
        return obs, rew, done, truncated, info

    def _fix_action_mask(self, obs):
        # Fix the action-mask: everything larger than 0.5 becomes 1.0,
        # everything else 0.0.
        self.valid_actions = np.round(obs["action_mask"])
        obs["action_mask"] = self.valid_actions
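

if __name__ == "__main__":
    # Minimal usage sketch. Assumption: RandomEnv reads "action_space" and
    # "observation_space" from its config dict; the concrete spaces below are
    # illustrative only.
    env = ActionMaskEnv(
        {
            "action_space": Discrete(4),
            "observation_space": Box(-1.0, 1.0, shape=(5,)),
        }
    )
    obs, info = env.reset()
    # Mask entries are 0.0/1.0 after _fix_action_mask; step only with a valid
    # action (an invalid one would raise ValueError in step()).
    valid = np.flatnonzero(obs["action_mask"])
    if valid.size > 0:
        obs, rew, done, truncated, info = env.step(int(np.random.choice(valid)))
        print("reward:", rew, "done:", done, "truncated:", truncated)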