debug_counter_env.py

import gymnasium as gym
import numpy as np

from ray.rllib.env.multi_agent_env import MultiAgentEnv


class DebugCounterEnv(gym.Env):
    """Simple Env that yields a ts counter as observation (0-based).

    Actions have no effect. The episode length is always 15.
    Reward is always: current ts % 3.
    """

    def __init__(self, config=None):
        config = config or {}
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(0, 100, (1,), dtype=np.float32)
        self.start_at_t = int(config.get("start_at_t", 0))
        self.i = self.start_at_t

    def reset(self, *, seed=None, options=None):
        self.i = self.start_at_t
        return self._get_obs(), {}

    def step(self, action):
        self.i += 1
        terminated = False
        truncated = self.i >= 15 + self.start_at_t
        return self._get_obs(), float(self.i % 3), terminated, truncated, {}

    def _get_obs(self):
        return np.array([self.i], dtype=np.float32)
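

# --- Usage sketch (illustrative, not part of the upstream example) -----------
# The helper below is a hypothetical smoke test showing how DebugCounterEnv
# behaves: observations count up from `start_at_t`, actions are ignored, the
# reward is ts % 3, and the episode truncates after 15 steps. The name
# `_demo_debug_counter_env` is an assumption made for this sketch only.
def _demo_debug_counter_env():
    env = DebugCounterEnv({"start_at_t": 0})
    obs, _ = env.reset()
    truncated = False
    while not truncated:
        # The sampled action has no effect on the env's dynamics.
        obs, reward, terminated, truncated, _ = env.step(env.action_space.sample())
        print(f"obs={obs} reward={reward} truncated={truncated}")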


class MultiAgentDebugCounterEnv(MultiAgentEnv):
    """Multi-agent version of DebugCounterEnv with per-agent ts counters."""

    def __init__(self, config):
        super().__init__()
        self.num_agents = config["num_agents"]
        self.base_episode_len = config.get("base_episode_len", 103)

        # Actions are always: (episodeID, envID) as floats.
        self.action_space = gym.spaces.Box(-float("inf"), float("inf"), shape=(2,))

        # Observation dims:
        # 0=agent ID.
        # 1=episode ID (0.0 for obs after reset).
        # 2=env ID (0.0 for obs after reset).
        # 3=ts (of the agent).
        self.observation_space = gym.spaces.Box(float("-inf"), float("inf"), (4,))

        self.timesteps = [0] * self.num_agents
        self.terminateds = set()
        self.truncateds = set()

        self._skip_env_checking = True

    def reset(self, *, seed=None, options=None):
        self.timesteps = [0] * self.num_agents
        self.terminateds = set()
        self.truncateds = set()
        return {
            i: np.array([i, 0.0, 0.0, 0.0], dtype=np.float32)
            for i in range(self.num_agents)
        }, {}

    def step(self, action_dict):
        obs, rew, terminated, truncated = {}, {}, {}, {}
        for i, action in action_dict.items():
            self.timesteps[i] += 1
            obs[i] = np.array([i, action[0], action[1], self.timesteps[i]])
            rew[i] = self.timesteps[i] % 3
            terminated[i] = False
            truncated[i] = self.timesteps[i] > self.base_episode_len + i
            if terminated[i]:
                self.terminateds.add(i)
            if truncated[i]:
                self.truncateds.add(i)
        terminated["__all__"] = len(self.terminateds) == self.num_agents
        truncated["__all__"] = len(self.truncateds) == self.num_agents
        return obs, rew, terminated, truncated, {}
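

# --- Usage sketch (illustrative, not part of the upstream example) -----------
# Hypothetical driver loop for MultiAgentDebugCounterEnv: it resets the env,
# feeds each agent a random (episodeID, envID)-shaped action, and prints the
# per-agent observations, rewards, and truncation flags. `num_agents=2`,
# `base_episode_len=10`, and the 5-step loop are arbitrary values chosen only
# for this sketch.
if __name__ == "__main__":
    ma_env = MultiAgentDebugCounterEnv({"num_agents": 2, "base_episode_len": 10})
    obs, infos = ma_env.reset()
    for _ in range(5):
        actions = {agent_id: ma_env.action_space.sample() for agent_id in obs}
        obs, rewards, terminateds, truncateds, infos = ma_env.step(actions)
        print(obs, rewards, truncateds)
        if truncateds["__all__"]:
            break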