real_accelerator.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import os
  5. try:
  6. # Importing logger currently requires that torch is installed, hence the try...except
  7. # TODO: Remove logger dependency on torch.
  8. from deepspeed.utils import logger as accel_logger
  9. except ImportError as e:
  10. accel_logger = None
  11. try:
  12. from accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa1
  13. except ImportError as e:
  14. dsa1 = None
  15. try:
  16. from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa2
  17. except ImportError as e:
  18. dsa2 = None
  19. SUPPORTED_ACCELERATOR_LIST = ['cuda', 'xeon', 'xpu', 'xpu.external', 'npu', 'mps', 'hpu']
  20. ds_accelerator = None
  21. def _validate_accelerator(accel_obj):
  22. # because abstract_accelerator has different path during
  23. # build time (accelerator.abstract_accelerator)
  24. # and run time (deepspeed.accelerator.abstract_accelerator)
  25. # and extension would import the
  26. # run time abstract_accelerator/DeepSpeedAccelerator as its base
  27. # class, so we need to compare accel_obj with both base class.
  28. # if accel_obj is instance of DeepSpeedAccelerator in one of
  29. # accelerator.abstractor_accelerator
  30. # or deepspeed.accelerator.abstract_accelerator, consider accel_obj
  31. # is a conforming object
  32. if not ((dsa1 is not None and isinstance(accel_obj, dsa1)) or (dsa2 is not None and isinstance(accel_obj, dsa2))):
  33. raise AssertionError(f"{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator")
  34. # TODO: turn off is_available test since this breaks tests
  35. # assert accel_obj.is_available(), \
  36. # f'{accel_obj.__class__.__name__} accelerator fails is_available() test'
  37. def is_current_accelerator_supported():
  38. return get_accelerator().device_name() in SUPPORTED_ACCELERATOR_LIST
  39. def get_accelerator():
  40. global ds_accelerator
  41. if ds_accelerator is not None:
  42. return ds_accelerator
  43. accelerator_name = None
  44. ds_set_method = None
  45. # 1. Detect whether there is override of DeepSpeed accelerators from environment variable.
  46. if "DS_ACCELERATOR" in os.environ.keys():
  47. accelerator_name = os.environ["DS_ACCELERATOR"]
  48. if accelerator_name == "xpu":
  49. try:
  50. import intel_extension_for_pytorch as ipex
  51. assert ipex._C._has_xpu(), "XPU_Accelerator requires an intel_extension_for_pytorch that supports XPU."
  52. except ImportError as e:
  53. raise ValueError(
  54. f"XPU_Accelerator requires intel_extension_for_pytorch, which is not installed on this system.")
  55. elif accelerator_name == "xpu.external":
  56. try:
  57. import intel_extension_for_deepspeed # noqa: F401 # type: ignore
  58. except ImportError as e:
  59. raise ValueError(
  60. f"XPU_Accelerator external requires intel_extension_for_deepspeed, which is not installed on this system."
  61. )
  62. elif accelerator_name == "xeon":
  63. try:
  64. import intel_extension_for_pytorch # noqa: F401 # type: ignore
  65. except ImportError as e:
  66. raise ValueError(
  67. f"Xeon_Accelerator requires intel_extension_for_pytorch, which is not installed on this system.")
  68. elif accelerator_name == "npu":
  69. try:
  70. import torch_npu # noqa: F401 # type: ignore
  71. except ImportError as e:
  72. raise ValueError(f"NPU_Accelerator requires torch_npu, which is not installed on this system.")
  73. pass
  74. elif accelerator_name == "mps":
  75. try:
  76. import torch.mps
  77. # should use torch.mps.is_available() if it exists someday but this is used as proxy
  78. torch.mps.current_allocated_memory()
  79. except (RuntimeError, ImportError) as e:
  80. raise ValueError(f"MPS_Accelerator requires torch.mps, which is not installed on this system.")
  81. elif accelerator_name == "hpu":
  82. try:
  83. import habana_frameworks.torch.hpu # noqa: F401
  84. except ImportError as e:
  85. raise ValueError(
  86. f"HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.")
  87. elif accelerator_name not in SUPPORTED_ACCELERATOR_LIST:
  88. raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. '
  89. f'Value "{accelerator_name}" is not supported')
  90. ds_set_method = "override"
  91. # 2. If no override, detect which accelerator to use automatically
  92. if accelerator_name is None:
  93. # We need a way to choose among different accelerator types.
  94. # Currently we detect which accelerator extension is installed
  95. # in the environment and use it if the installing answer is True.
  96. # An alternative might be detect whether CUDA device is installed on
  97. # the system but this comes with two pitfalls:
  98. # 1. the system may not have torch pre-installed, so
  99. # get_accelerator().is_available() may not work.
  100. # 2. Some scenario like install on login node (without CUDA device)
  101. # and run on compute node (with CUDA device) may cause mismatch
  102. # between installation time and runtime.
  103. try:
  104. from intel_extension_for_deepspeed import XPU_Accelerator # noqa: F401,F811 # type: ignore
  105. accelerator_name = "xpu.external"
  106. except ImportError as e:
  107. pass
  108. if accelerator_name is None:
  109. try:
  110. import intel_extension_for_pytorch as ipex
  111. if ipex._C._has_xpu():
  112. accelerator_name = "xpu"
  113. else:
  114. accelerator_name = "xeon"
  115. except ImportError as e:
  116. pass
  117. if accelerator_name is None:
  118. try:
  119. import torch_npu # noqa: F401,F811 # type: ignore
  120. accelerator_name = "npu"
  121. except ImportError as e:
  122. pass
  123. if accelerator_name is None:
  124. try:
  125. import torch.mps
  126. # should use torch.mps.is_available() if it exists someday but this is used as proxy
  127. torch.mps.current_allocated_memory()
  128. accelerator_name = "mps"
  129. except (RuntimeError, ImportError) as e:
  130. pass
  131. if accelerator_name is None:
  132. try:
  133. import habana_frameworks.torch.hpu # noqa: F401,F811
  134. accelerator_name = "hpu"
  135. except ImportError as e:
  136. pass
  137. if accelerator_name is None:
  138. accelerator_name = "cuda"
  139. ds_set_method = "auto detect"
  140. # 3. Set ds_accelerator accordingly
  141. if accelerator_name == "cuda":
  142. from .cuda_accelerator import CUDA_Accelerator
  143. ds_accelerator = CUDA_Accelerator()
  144. elif accelerator_name == "xeon":
  145. from .xeon_accelerator import Xeon_Accelerator
  146. ds_accelerator = Xeon_Accelerator()
  147. elif accelerator_name == "xpu.external":
  148. # XPU_Accelerator is already imported in detection stage
  149. ds_accelerator = XPU_Accelerator()
  150. elif accelerator_name == "xpu":
  151. from .xpu_accelerator import XPU_Accelerator
  152. ds_accelerator = XPU_Accelerator()
  153. elif accelerator_name == "npu":
  154. from .npu_accelerator import NPU_Accelerator
  155. ds_accelerator = NPU_Accelerator()
  156. elif accelerator_name == "mps":
  157. from .mps_accelerator import MPS_Accelerator
  158. ds_accelerator = MPS_Accelerator()
  159. elif accelerator_name == 'hpu':
  160. from .hpu_accelerator import HPU_Accelerator
  161. ds_accelerator = HPU_Accelerator()
  162. _validate_accelerator(ds_accelerator)
  163. if accel_logger is not None:
  164. accel_logger.info(f"Setting ds_accelerator to {ds_accelerator._name} ({ds_set_method})")
  165. return ds_accelerator
  166. def set_accelerator(accel_obj):
  167. global ds_accelerator
  168. _validate_accelerator(accel_obj)
  169. if accel_logger is not None:
  170. accel_logger.info(f"Setting ds_accelerator to {accel_obj._name} (model specified)")
  171. ds_accelerator = accel_obj
  172. """
  173. -----------[code] test_get.py -----------
from deepspeed.utils import logger
from deepspeed.accelerator import get_accelerator
  175. my_accelerator = get_accelerator()
  176. logger.info(f'{my_accelerator._name=}')
  177. logger.info(f'{my_accelerator._communication_backend=}')
  178. logger.info(f'{my_accelerator.HalfTensor().device=}')
  179. logger.info(f'{my_accelerator.total_memory()=}')
  180. -----------[code] test_get.py -----------
  181. ---[output] python test_get.py---------
my_accelerator._name='cuda'
my_accelerator._communication_backend='nccl'
  184. my_accelerator.HalfTensor().device=device(type='cuda', index=0)
  185. my_accelerator.total_memory()=34089730048
  186. ---[output] python test_get.py---------
  187. **************************************************************************
  188. -----------[code] test_set.py -----------
from deepspeed.utils import logger
from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator
  190. cu_accel = CUDA_Accelerator()
  191. logger.info(f'{id(cu_accel)=}')
  192. from deepspeed.accelerator import set_accelerator, get_accelerator
  193. set_accelerator(cu_accel)
  194. my_accelerator = get_accelerator()
  195. logger.info(f'{id(my_accelerator)=}')
  196. logger.info(f'{my_accelerator._name=}')
  197. logger.info(f'{my_accelerator._communication_backend=}')
  198. logger.info(f'{my_accelerator.HalfTensor().device=}')
  199. logger.info(f'{my_accelerator.total_memory()=}')
  200. -----------[code] test_set.py -----------
  201. ---[output] python test_set.py---------
  202. id(cu_accel)=139648165478304
id(my_accelerator)=139648165478304
my_accelerator._name='cuda'
my_accelerator._communication_backend='nccl'
  206. my_accelerator.HalfTensor().device=device(type='cuda', index=0)
  207. my_accelerator.total_memory()=34089730048
  208. ---[output] python test_set.py---------
  209. """