# custom_gym_env.py
# __rllib-custom-gym-env-begin__
import gymnasium as gym
import ray
from ray.rllib.algorithms.ppo import PPOConfig
  5. class SimpleCorridor(gym.Env):
  6. def __init__(self, config):
  7. self.end_pos = config["corridor_length"]
  8. self.cur_pos = 0
  9. self.action_space = gym.spaces.Discrete(2) # right/left
  10. self.observation_space = gym.spaces.Discrete(self.end_pos)
  11. def reset(self, *, seed=None, options=None):
  12. self.cur_pos = 0
  13. return self.cur_pos, {}
  14. def step(self, action):
  15. if action == 0 and self.cur_pos > 0: # move right (towards goal)
  16. self.cur_pos -= 1
  17. elif action == 1: # move left (towards start)
  18. self.cur_pos += 1
  19. if self.cur_pos >= self.end_pos:
  20. return 0, 1.0, True, True, {}
  21. else:
  22. return self.cur_pos, -0.1, False, False, {}
# Start (or connect to) a local Ray runtime.
ray.init()

# Configure PPO on SimpleCorridor; `env_config` is passed through to the
# env constructor as its `config` argument.
config = PPOConfig().environment(SimpleCorridor, env_config={"corridor_length": 5})

# Build the algorithm, run three training iterations (printing each
# iteration's result dict), then release the algorithm's resources.
algo = config.build()
for _ in range(3):
    print(algo.train())
algo.stop()
# __rllib-custom-gym-env-end__