test_gpus.py

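"""Tests for RLlib's GPU resource options.

Covers num_gpus, num_gpus_per_worker, and _fake_gpus in both non-local and
local Ray mode.
"""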
import unittest

import ray
from ray import air, tune
from ray.rllib.algorithms.a2c.a2c import A2CConfig
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import framework_iterator
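
# try_import_torch() returns placeholders instead of raising if torch is not
# installed; the tests below nevertheless need torch to count CUDA devices.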
torch, _ = try_import_torch()


class TestGPUs(unittest.TestCase):
    def test_gpus_in_non_local_mode(self):
        # Non-local mode.
        ray.init()

        actual_gpus = torch.cuda.device_count()
        print(f"Actual GPUs found (by torch): {actual_gpus}")

        config = A2CConfig().rollouts(num_rollout_workers=2).environment("CartPole-v1")
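        # The GPU-related resource settings on this config get overwritten for
        # each test case below via config.resources().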

        # Expect errors when we run a config with num_gpus > 0 on a machine
        # without (enough) GPUs and _fake_gpus=False.
        for num_gpus in [0, 0.1, 1, actual_gpus + 4]:
            # Only allow satisfiable num_gpus_per_worker values (so the test
            # does not block forever waiting for worker resources that can
            # never be granted).
            per_worker = (
                [0] if actual_gpus == 0 or actual_gpus < num_gpus else [0, 0.5, 1]
            )
            for num_gpus_per_worker in per_worker:
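                # _fake_gpus is only meaningful together with num_gpus > 0, so
                # the True case is skipped for num_gpus=0.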
                for fake_gpus in [False] + ([] if num_gpus == 0 else [True]):
                    config.resources(
                        num_gpus=num_gpus,
                        num_gpus_per_worker=num_gpus_per_worker,
                        _fake_gpus=fake_gpus,
                    )

                    print(
                        f"\n------------\nnum_gpus={num_gpus} "
                        f"num_gpus_per_worker={num_gpus_per_worker} "
                        f"_fake_gpus={fake_gpus}"
                    )
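
                    # Note: tf2 (eager) is not exercised for the multi-GPU
                    # (num_gpus > 1) cases.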
                    frameworks = (
                        ("tf", "torch") if num_gpus > 1 else ("tf2", "tf", "torch")
                    )
                    for _ in framework_iterator(config, frameworks=frameworks):
                        # Expect that Algorithm creation causes a num_gpus
                        # error.
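                        # The driver needs num_gpus and each of the two rollout
                        # workers needs num_gpus_per_worker, hence the
                        # 2 * num_gpus_per_worker term below.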
                        if (
                            actual_gpus < num_gpus + 2 * num_gpus_per_worker
                            and not fake_gpus
                        ):
                            # "Direct" RLlib (create Algorithm on the driver).
                            # Cannot run through ray.tune.Tuner().fit() as it
                            # would simply wait infinitely for the resources to
                            # become available.
                            print("direct RLlib")
                            self.assertRaisesRegex(
                                RuntimeError,
                                "Found 0 GPUs on your machine",
                                lambda: config.build(),
                            )
                        # If actual_gpus >= num_gpus (or GPUs are faked),
                        # expect no error.
                        else:
                            print("direct RLlib")
                            algo = config.build()
                            algo.stop()
                            # Cannot run through ray.tune.Tuner().fit() with
                            # fake GPUs, as it would simply wait infinitely for
                            # the resources to become available (even though we
                            # wouldn't really need them).
                            if num_gpus == 0:
                                print("via ray.tune.Tuner().fit()")
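                                # The training_iteration=0 stop criterion ends
                                # the trial right after its first result, so
                                # this stays cheap.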
                                tune.Tuner(
                                    "A2C",
                                    param_space=config,
                                    run_config=air.RunConfig(
                                        stop={"training_iteration": 0}
                                    ),
                                ).fit()

        ray.shutdown()

    def test_gpus_in_local_mode(self):
        # Local mode.
        ray.init(local_mode=True)

        actual_gpus_available = torch.cuda.device_count()

        config = A2CConfig().rollouts(num_rollout_workers=2).environment("CartPole-v1")

        # Expect no errors in local mode.
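        # (Local mode runs everything serially in the driver process, so GPU
        # resource requests are not actually enforced.)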
        for num_gpus in [0, 0.1, 1, actual_gpus_available + 4]:
            print(f"num_gpus={num_gpus}")
            for fake_gpus in [False, True]:
                print(f"_fake_gpus={fake_gpus}")
                config.resources(num_gpus=num_gpus, _fake_gpus=fake_gpus)
                frameworks = (
                    ("tf", "torch") if num_gpus > 1 else ("tf2", "tf", "torch")
                )
                for _ in framework_iterator(config, frameworks=frameworks):
                    print("direct RLlib")
                    algo = config.build()
                    algo.stop()
                    print("via ray.tune.Tuner().fit()")
                    tune.Tuner(
                        "A2C",
                        param_space=config,
                        run_config=air.RunConfig(stop={"training_iteration": 0}),
                    ).fit()

        ray.shutdown()


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))