dm_env_wrapper.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import gym
  2. from gym import spaces
  3. import numpy as np
  4. try:
  5. from dm_env import specs
  6. except ImportError:
  7. specs = None
  8. def _convert_spec_to_space(spec):
  9. if isinstance(spec, dict):
  10. return spaces.Dict(
  11. {k: _convert_spec_to_space(v)
  12. for k, v in spec.items()})
  13. if isinstance(spec, specs.DiscreteArray):
  14. return spaces.Discrete(spec.num_values)
  15. elif isinstance(spec, specs.BoundedArray):
  16. return spaces.Box(
  17. low=np.asscalar(spec.minimum),
  18. high=np.asscalar(spec.maximum),
  19. shape=spec.shape,
  20. dtype=spec.dtype)
  21. elif isinstance(spec, specs.Array):
  22. return spaces.Box(
  23. low=-float("inf"),
  24. high=float("inf"),
  25. shape=spec.shape,
  26. dtype=spec.dtype)
  27. raise NotImplementedError(
  28. ("Could not convert `Array` spec of type {} to Gym space. "
  29. "Attempted to convert: {}").format(type(spec), spec))
  30. class DMEnv(gym.Env):
  31. """A `gym.Env` wrapper for the `dm_env` API.
  32. """
  33. metadata = {"render.modes": ["rgb_array"]}
  34. def __init__(self, dm_env):
  35. super(DMEnv, self).__init__()
  36. self._env = dm_env
  37. self._prev_obs = None
  38. if specs is None:
  39. raise RuntimeError((
  40. "The `specs` module from `dm_env` was not imported. Make sure "
  41. "`dm_env` is installed and visible in the current python "
  42. "environment."))
  43. def step(self, action):
  44. ts = self._env.step(action)
  45. reward = ts.reward
  46. if reward is None:
  47. reward = 0.
  48. return ts.observation, reward, ts.last(), {"discount": ts.discount}
  49. def reset(self):
  50. ts = self._env.reset()
  51. return ts.observation
  52. def render(self, mode="rgb_array"):
  53. if self._prev_obs is None:
  54. raise ValueError(
  55. "Environment not started. Make sure to reset before rendering."
  56. )
  57. if mode == "rgb_array":
  58. return self._prev_obs
  59. else:
  60. raise NotImplementedError(
  61. "Render mode '{}' is not supported.".format(mode))
  62. @property
  63. def action_space(self):
  64. spec = self._env.action_spec()
  65. return _convert_spec_to_space(spec)
  66. @property
  67. def observation_space(self):
  68. spec = self._env.observation_spec()
  69. return _convert_spec_to_space(spec)
  70. @property
  71. def reward_range(self):
  72. spec = self._env.reward_spec()
  73. if isinstance(spec, specs.BoundedArray):
  74. return spec.minimum, spec.maximum
  75. return -float("inf"), float("inf")