curriculum_capable_env.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import gym
  2. import random
  3. from ray.rllib.env.apis.task_settable_env import TaskSettableEnv
  4. from ray.rllib.env.env_context import EnvContext
  5. from ray.rllib.utils.annotations import override
  6. class CurriculumCapableEnv(TaskSettableEnv):
  7. """Example of a curriculum learning capable env.
  8. This simply wraps a FrozenLake-v1 env and makes it harder with each
  9. task. Task (difficulty levels) can range from 1 to 10."""
  10. # Defining the different maps (all same size) for the different
  11. # tasks. Theme here is to move the goal further and further away and
  12. # to add more and more holes along the way.
  13. MAPS = [
  14. ["SFFFFFF", "FFFFFFF", "FFFFFFF", "HHFFFFG", "FFFFFFF", "FFFFFFF"],
  15. ["SFFFFFF", "FFFHFFF", "FFFFFFF", "HHHFFFF", "FFFFFFG", "FFFFFFF"],
  16. ["SFFFFFF", "FFHHFFF", "FFFFFFF", "HHHHFFF", "FFFFFFF", "FFFFFFG"],
  17. ["SFFFFFF", "FHHHFFF", "FFFFFFF", "HHHHHFF", "FFFFFFF", "FFFFFGF"],
  18. ["SFFFFFF", "FFFHHFF", "FHFFFFF", "HHHHHHF", "FFHFFHF", "FFFGFFF"],
  19. ]
  20. def __init__(self, config: EnvContext):
  21. self.cur_level = config.get("start_level", 1)
  22. self.max_timesteps = config.get("max_timesteps", 18)
  23. self.frozen_lake = None
  24. self._make_lake() # create self.frozen_lake
  25. self.observation_space = self.frozen_lake.observation_space
  26. self.action_space = self.frozen_lake.action_space
  27. self.switch_env = False
  28. self._timesteps = 0
  29. def reset(self):
  30. if self.switch_env:
  31. self.switch_env = False
  32. self._make_lake()
  33. self._timesteps = 0
  34. return self.frozen_lake.reset()
  35. def step(self, action):
  36. self._timesteps += 1
  37. s, r, d, i = self.frozen_lake.step(action)
  38. # Make rewards scale with the level exponentially:
  39. # Level 1: x1
  40. # Level 2: x10
  41. # Level 3: x100, etc..
  42. r *= 10**(self.cur_level - 1)
  43. if self._timesteps >= self.max_timesteps:
  44. d = True
  45. return s, r, d, i
  46. @override(TaskSettableEnv)
  47. def sample_tasks(self, n_tasks):
  48. """Implement this to sample n random tasks."""
  49. return [random.randint(1, 10) for _ in range(n_tasks)]
  50. @override(TaskSettableEnv)
  51. def get_task(self):
  52. """Implement this to get the current task (curriculum level)."""
  53. return self.cur_level
  54. @override(TaskSettableEnv)
  55. def set_task(self, task):
  56. """Implement this to set the task (curriculum level) for this env."""
  57. self.cur_level = task
  58. self.switch_env = True
  59. def _make_lake(self):
  60. self.frozen_lake = gym.make(
  61. "FrozenLake-v1",
  62. desc=self.MAPS[self.cur_level - 1],
  63. is_slippery=False)