test_gpus.py

import unittest

import ray
from ray.rllib.agents.pg import PGTrainer, DEFAULT_CONFIG
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import framework_iterator
from ray import tune

torch, _ = try_import_torch()
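
# These tests exercise RLlib's GPU-related config settings (num_gpus,
# num_gpus_per_worker, _fake_gpus) against the number of GPUs torch actually
# detects, both in regular (non-local) mode and in Ray local mode.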


class TestGPUs(unittest.TestCase):
    def test_gpus_in_non_local_mode(self):
        # Non-local mode.
        ray.init(num_cpus=8)

        actual_gpus = torch.cuda.device_count()
        print(f"Actual GPUs found (by torch): {actual_gpus}")

        config = DEFAULT_CONFIG.copy()
        config["num_workers"] = 2
        config["env"] = "CartPole-v0"

        # Expect errors when we run a config w/ num_gpus>0 w/o a GPU
        # and _fake_gpus=False.
        for num_gpus in [0, 0.1, 1, actual_gpus + 4]:
            # Only allow possible num_gpus_per_worker (so test would not
            # block infinitely due to a down worker).
            per_worker = [0] if actual_gpus == 0 or actual_gpus < num_gpus \
                else [0, 0.5, 1]
            for num_gpus_per_worker in per_worker:
                for fake_gpus in [False] + ([] if num_gpus == 0 else [True]):
                    config["num_gpus"] = num_gpus
                    config["num_gpus_per_worker"] = num_gpus_per_worker
                    config["_fake_gpus"] = fake_gpus

                    print(f"\n------------\nnum_gpus={num_gpus} "
                          f"num_gpus_per_worker={num_gpus_per_worker} "
                          f"_fake_gpus={fake_gpus}")

                    frameworks = ("tf", "torch") if num_gpus > 1 else \
                        ("tf2", "tf", "torch")
                    for _ in framework_iterator(config, frameworks=frameworks):
                        # Expect that trainer creation causes a num_gpu error.
                        if actual_gpus < num_gpus + 2 * num_gpus_per_worker \
                                and not fake_gpus:
                            # "Direct" RLlib (create Trainer on the driver).
                            # Cannot run through ray.tune.run() as it would
                            # simply wait infinitely for the resources to
                            # become available.
                            print("direct RLlib")
                            self.assertRaisesRegex(
                                RuntimeError,
                                "Found 0 GPUs on your machine",
                                lambda: PGTrainer(config, env="CartPole-v0"),
                            )
                        # If actual_gpus >= num_gpus or faked,
                        # expect no error.
                        else:
                            print("direct RLlib")
                            trainer = PGTrainer(config, env="CartPole-v0")
                            trainer.stop()
                            # Cannot run through ray.tune.run() w/ fake GPUs
                            # as it would simply wait infinitely for the
                            # resources to become available (even though we
                            # wouldn't really need them).
                            if num_gpus == 0:
                                print("via ray.tune.run()")
                                tune.run(
                                    "PG",
                                    config=config,
                                    stop={"training_iteration": 0})
        ray.shutdown()

    def test_gpus_in_local_mode(self):
        # Local mode.
        ray.init(num_gpus=8, local_mode=True)

        actual_gpus_available = torch.cuda.device_count()

        config = DEFAULT_CONFIG.copy()
        config["num_workers"] = 2
        config["env"] = "CartPole-v0"

        # Expect no errors in local mode.
        for num_gpus in [0, 0.1, 1, actual_gpus_available + 4]:
            print(f"num_gpus={num_gpus}")
            for fake_gpus in [False, True]:
                print(f"_fake_gpus={fake_gpus}")
                config["num_gpus"] = num_gpus
                config["_fake_gpus"] = fake_gpus
                frameworks = ("tf", "torch") if num_gpus > 1 else \
                    ("tf2", "tf", "torch")
                for _ in framework_iterator(config, frameworks=frameworks):
                    print("direct RLlib")
                    trainer = PGTrainer(config, env="CartPole-v0")
                    trainer.stop()
                    print("via ray.tune.run()")
                    tune.run(
                        "PG", config=config, stop={"training_iteration": 0})
        ray.shutdown()


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))
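
# Usage sketch (assumes pytest is installed): a single test case can also be
# selected directly, e.g.:
#   pytest -v test_gpus.py::TestGPUs::test_gpus_in_local_mode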