12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394 |
- import gym
- from gym import spaces
- import numpy as np
- try:
- from dm_env import specs
- except ImportError:
- specs = None
def _convert_spec_to_space(spec):
    """Convert a `dm_env` spec (or a dict of specs) into a Gym space.

    Args:
        spec: A `specs.DiscreteArray`, `specs.BoundedArray` or `specs.Array`
            instance, or a dict mapping keys to such specs. Dict values are
            converted recursively, so nested dicts are supported.

    Returns:
        A `spaces.Dict` for dict input, a `spaces.Discrete` for
        `DiscreteArray` specs, and a `spaces.Box` for `BoundedArray` and
        plain `Array` specs (the latter with infinite bounds).

    Raises:
        NotImplementedError: If `spec` is none of the supported types.
    """
    if isinstance(spec, dict):
        return spaces.Dict(
            {k: _convert_spec_to_space(v)
             for k, v in spec.items()})
    # NOTE: the order of the checks matters. In `dm_env`, `DiscreteArray`
    # derives from `BoundedArray`, which derives from `Array`, so the most
    # specific type has to be tested first.
    if isinstance(spec, specs.DiscreteArray):
        return spaces.Discrete(spec.num_values)
    elif isinstance(spec, specs.BoundedArray):
        minimum = np.asarray(spec.minimum)
        maximum = np.asarray(spec.maximum)
        if minimum.size == 1 and maximum.size == 1:
            # Scalar bounds: `ndarray.item()` replaces `np.asscalar`, which
            # was deprecated in NumPy 1.16 and later removed.
            return spaces.Box(
                low=minimum.item(),
                high=maximum.item(),
                shape=spec.shape,
                dtype=spec.dtype)
        # Per-element bounds: expand both to the full spec shape and let
        # `Box` infer its shape from them. `astype` also yields a writable
        # copy (`broadcast_to` returns a read-only view).
        return spaces.Box(
            low=np.broadcast_to(minimum, spec.shape).astype(spec.dtype),
            high=np.broadcast_to(maximum, spec.shape).astype(spec.dtype),
            dtype=spec.dtype)
    elif isinstance(spec, specs.Array):
        return spaces.Box(
            low=-float("inf"),
            high=float("inf"),
            shape=spec.shape,
            dtype=spec.dtype)
    raise NotImplementedError(
        ("Could not convert `Array` spec of type {} to Gym space. "
         "Attempted to convert: {}").format(type(spec), spec))
class DMEnv(gym.Env):
    """A `gym.Env` wrapper for the `dm_env` API.

    Forwards `reset`/`step` to the wrapped `dm_env` environment and exposes
    its action, observation and reward specs as Gym spaces / reward range.
    The most recent observation is remembered so that `render` can return
    it.
    """

    metadata = {"render.modes": ["rgb_array"]}

    def __init__(self, dm_env):
        """Wrap `dm_env`.

        Args:
            dm_env: The environment to wrap. It must provide `step`,
                `reset`, `action_spec`, `observation_spec` and
                `reward_spec`.

        Raises:
            RuntimeError: If the `dm_env.specs` module could not be imported.
        """
        # Fail fast: without `specs` none of the spec conversions can work.
        if specs is None:
            raise RuntimeError((
                "The `specs` module from `dm_env` was not imported. Make sure "
                "`dm_env` is installed and visible in the current python "
                "environment."))
        super(DMEnv, self).__init__()
        self._env = dm_env
        # Last observation returned by `reset`/`step`; `None` until the
        # environment has been reset for the first time.
        self._prev_obs = None

    def step(self, action):
        """Advance the wrapped environment by one step.

        Args:
            action: The action passed through to the wrapped environment.

        Returns:
            A tuple `(observation, reward, done, info)`, where `done` is
            `ts.last()` and `info` carries the time step's discount under
            the key "discount". A `None` reward is reported as `0.`.
        """
        ts = self._env.step(action)
        # Remember the observation so `render` has something to return.
        self._prev_obs = ts.observation
        reward = ts.reward
        if reward is None:
            reward = 0.
        return ts.observation, reward, ts.last(), {"discount": ts.discount}

    def reset(self):
        """Reset the wrapped environment and return its first observation."""
        ts = self._env.reset()
        # Remember the observation so `render` has something to return.
        self._prev_obs = ts.observation
        return ts.observation

    def render(self, mode="rgb_array"):
        """Return the most recent observation.

        NOTE(review): this presumably expects the observation itself to be
        an image -- confirm against the wrapped environments in use.

        Args:
            mode: Render mode; only "rgb_array" is supported.

        Returns:
            The observation from the latest `reset` or `step` call.

        Raises:
            ValueError: If the environment has not been reset yet.
            NotImplementedError: If `mode` is not "rgb_array".
        """
        if self._prev_obs is None:
            raise ValueError(
                "Environment not started. Make sure to reset before rendering."
            )
        if mode == "rgb_array":
            return self._prev_obs
        else:
            raise NotImplementedError(
                "Render mode '{}' is not supported.".format(mode))

    @property
    def action_space(self):
        """Gym space converted from the wrapped env's `action_spec()`."""
        spec = self._env.action_spec()
        return _convert_spec_to_space(spec)

    @property
    def observation_space(self):
        """Gym space converted from the wrapped env's `observation_spec()`."""
        spec = self._env.observation_spec()
        return _convert_spec_to_space(spec)

    @property
    def reward_range(self):
        """`(min, max)` reward bounds taken from the env's `reward_spec()`.

        Falls back to `(-inf, inf)` when the reward spec is not a
        `specs.BoundedArray`.
        """
        spec = self._env.reward_spec()
        if isinstance(spec, specs.BoundedArray):
            return spec.minimum, spec.maximum
        return -float("inf"), float("inf")
|