exception_wrapper.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637
import logging
import traceback

import gym

# Module-level logger; the wrapper in this file logs the errors it
# recovers from here instead of letting them propagate.
logger: logging.Logger = logging.getLogger(__name__)
  5. class TooManyResetAttemptsException(Exception):
  6. def __init__(self, max_attempts: int):
  7. super().__init__(
  8. f"Reached the maximum number of attempts ({max_attempts}) "
  9. f"to reset an environment.")
  10. class ResetOnExceptionWrapper(gym.Wrapper):
  11. def __init__(self, env: gym.Env, max_reset_attempts: int = 5):
  12. super().__init__(env)
  13. self.max_reset_attempts = max_reset_attempts
  14. def reset(self, **kwargs):
  15. attempt = 0
  16. while attempt < self.max_reset_attempts:
  17. try:
  18. return self.env.reset(**kwargs)
  19. except Exception:
  20. logger.error(traceback.format_exc())
  21. attempt += 1
  22. else:
  23. raise TooManyResetAttemptsException(self.max_reset_attempts)
  24. def step(self, action):
  25. try:
  26. return self.env.step(action)
  27. except Exception:
  28. logger.error(traceback.format_exc())
  29. return self.reset(), 0.0, False, {"__terminated__": True}