gpu_requiring_env.py 1.1 KB

123456789101112131415161718192021222324252627
  1. import ray
  2. from ray.rllib.examples.env.simple_corridor import SimpleCorridor
  3. class GPURequiringEnv(SimpleCorridor):
  4. """A dummy env that requires a GPU in order to work.
  5. The env here is a simple corridor env that additionally simulates a GPU
  6. check in its constructor via `ray.get_gpu_ids()`. If this returns an
  7. empty list, we raise an error.
  8. To make this env work, use `num_gpus_per_worker > 0` (RolloutWorkers
  9. requesting this many GPUs each) and - maybe - `num_gpus > 0` in case
  10. your local worker/driver must have an env as well. However, this is
  11. only the case if `create_env_on_driver`=True (default is False).
  12. """
  13. def __init__(self, config=None):
  14. super().__init__(config)
  15. # Fake-require some GPUs (at least one).
  16. # If your local worker's env (`create_env_on_driver`=True) does not
  17. # necessarily require a GPU, you can perform the below assertion only
  18. # if `config.worker_index != 0`.
  19. gpus_available = ray.get_gpu_ids()
  20. assert len(gpus_available) > 0, "Not enough GPUs for this env!"
  21. print("Env can see these GPUs: {}".format(gpus_available))