env_using_remote_actor.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. """
  2. Example of an environment that uses a named remote actor as parameter
  3. server.
  4. """
  5. from gym.envs.classic_control.cartpole import CartPoleEnv
  6. from gym.utils import seeding
  7. import ray
  8. @ray.remote
  9. class ParameterStorage:
  10. def get_params(self, rng):
  11. return {
  12. "MASSCART": rng.uniform(low=0.5, high=2.0),
  13. }
  14. class CartPoleWithRemoteParamServer(CartPoleEnv):
  15. """CartPoleMassEnv varies the weights of the cart and the pole.
  16. """
  17. def __init__(self, env_config):
  18. self.env_config = env_config
  19. super().__init__()
  20. # Get our param server (remote actor) by name.
  21. self._handler = ray.get_actor(
  22. env_config.get("param_server", "param-server"))
  23. self.rng_seed = None
  24. self.np_random, _ = seeding.np_random(self.rng_seed)
  25. def seed(self, rng_seed: int = None):
  26. if not rng_seed:
  27. return
  28. print(f"Seeding env (worker={self.env_config.worker_index}) "
  29. f"with {rng_seed}")
  30. self.rng_seed = rng_seed
  31. self.np_random, _ = seeding.np_random(rng_seed)
  32. def reset(self):
  33. # Pass in our RNG to guarantee no race conditions.
  34. # If `self._handler` had its own RNG, this may clash with other
  35. # envs trying to use the same param-server.
  36. params = ray.get(self._handler.get_params.remote(self.np_random))
  37. # IMPORTANT: Advance the state of our RNG (self._rng was passed
  38. # above via ray (serialized) and thus not altered locally here!).
  39. # Or create a new RNG from another random number:
  40. # Seed the RNG with a deterministic seed if set, otherwise, create
  41. # a random one.
  42. new_seed = (self.np_random.randint(0, 1000000)
  43. if not self.rng_seed else self.rng_seed)
  44. self.np_random, _ = seeding.np_random(new_seed)
  45. print(f"Env worker-idx={self.env_config.worker_index} "
  46. f"mass={params['MASSCART']}")
  47. self.masscart = params["MASSCART"]
  48. self.total_mass = (self.masspole + self.masscart)
  49. self.polemass_length = (self.masspole * self.length)
  50. return super().reset()