# real_accelerator.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import os
  5. try:
  6. # Importing logger currently requires that torch is installed, hence the try...except
  7. # TODO: Remove logger dependency on torch.
  8. from deepspeed.utils import logger as accel_logger
  9. except ImportError as e:
  10. accel_logger = None
  11. try:
  12. from accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa1
  13. except ImportError as e:
  14. dsa1 = None
  15. try:
  16. from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa2
  17. except ImportError as e:
  18. dsa2 = None
# Device-name strings of every accelerator backend this module can construct.
SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'xpu.external', 'npu', 'mps', 'hpu', 'mlu']
# Process-wide singleton; lazily created by get_accelerator() or installed by set_accelerator().
ds_accelerator = None
  21. def _validate_accelerator(accel_obj):
  22. # because abstract_accelerator has different path during
  23. # build time (accelerator.abstract_accelerator)
  24. # and run time (deepspeed.accelerator.abstract_accelerator)
  25. # and extension would import the
  26. # run time abstract_accelerator/DeepSpeedAccelerator as its base
  27. # class, so we need to compare accel_obj with both base class.
  28. # if accel_obj is instance of DeepSpeedAccelerator in one of
  29. # accelerator.abstractor_accelerator
  30. # or deepspeed.accelerator.abstract_accelerator, consider accel_obj
  31. # is a conforming object
  32. if not ((dsa1 is not None and isinstance(accel_obj, dsa1)) or (dsa2 is not None and isinstance(accel_obj, dsa2))):
  33. raise AssertionError(f"{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator")
  34. # TODO: turn off is_available test since this breaks tests
  35. # assert accel_obj.is_available(), \
  36. # f'{accel_obj.__class__.__name__} accelerator fails is_available() test'
  37. def is_current_accelerator_supported():
  38. return get_accelerator().device_name() in SUPPORTED_ACCELERATOR_LIST
  39. def get_accelerator():
  40. global ds_accelerator
  41. if ds_accelerator is not None:
  42. return ds_accelerator
  43. accelerator_name = None
  44. ds_set_method = None
  45. # 1. Detect whether there is override of DeepSpeed accelerators from environment variable.
  46. if "DS_ACCELERATOR" in os.environ.keys():
  47. accelerator_name = os.environ["DS_ACCELERATOR"]
  48. if accelerator_name == "xpu":
  49. try:
  50. import intel_extension_for_pytorch as ipex
  51. assert ipex._C._has_xpu(), "XPU_Accelerator requires an intel_extension_for_pytorch that supports XPU."
  52. except ImportError as e:
  53. raise ValueError(
  54. f"XPU_Accelerator requires intel_extension_for_pytorch, which is not installed on this system.")
  55. elif accelerator_name == "xpu.external":
  56. try:
  57. import intel_extension_for_deepspeed # noqa: F401 # type: ignore
  58. except ImportError as e:
  59. raise ValueError(
  60. f"XPU_Accelerator external requires intel_extension_for_deepspeed, which is not installed on this system."
  61. )
  62. elif accelerator_name == "cpu":
  63. pass
  64. elif accelerator_name == "npu":
  65. try:
  66. import torch_npu # noqa: F401 # type: ignore
  67. except ImportError as e:
  68. raise ValueError(f"NPU_Accelerator requires torch_npu, which is not installed on this system.")
  69. pass
  70. elif accelerator_name == "mps":
  71. try:
  72. import torch.mps
  73. # should use torch.mps.is_available() if it exists someday but this is used as proxy
  74. torch.mps.current_allocated_memory()
  75. except (RuntimeError, ImportError) as e:
  76. raise ValueError(f"MPS_Accelerator requires torch.mps, which is not installed on this system.")
  77. elif accelerator_name == "hpu":
  78. try:
  79. import habana_frameworks.torch.hpu # noqa: F401
  80. except ImportError as e:
  81. raise ValueError(
  82. f"HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.")
  83. elif accelerator_name == "mlu":
  84. try:
  85. import torch_mlu # noqa: F401
  86. except ImportError as e:
  87. raise ValueError(f"MLU_Accelerator requires torch_mlu, which is not installed on this system.")
  88. elif accelerator_name not in SUPPORTED_ACCELERATOR_LIST:
  89. raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. '
  90. f'Value "{accelerator_name}" is not supported')
  91. ds_set_method = "override"
  92. # 2. If no override, detect which accelerator to use automatically
  93. if accelerator_name is None:
  94. # We need a way to choose among different accelerator types.
  95. # Currently we detect which accelerator extension is installed
  96. # in the environment and use it if the installing answer is True.
  97. # An alternative might be detect whether CUDA device is installed on
  98. # the system but this comes with two pitfalls:
  99. # 1. the system may not have torch pre-installed, so
  100. # get_accelerator().is_available() may not work.
  101. # 2. Some scenario like install on login node (without CUDA device)
  102. # and run on compute node (with CUDA device) may cause mismatch
  103. # between installation time and runtime.
  104. try:
  105. from intel_extension_for_deepspeed import XPU_Accelerator # noqa: F401,F811 # type: ignore
  106. accelerator_name = "xpu.external"
  107. except ImportError as e:
  108. pass
  109. if accelerator_name is None:
  110. try:
  111. import intel_extension_for_pytorch as ipex
  112. if ipex._C._has_xpu():
  113. accelerator_name = "xpu"
  114. else:
  115. accelerator_name = "cpu"
  116. except ImportError as e:
  117. pass
  118. if accelerator_name is None:
  119. try:
  120. import torch_npu # noqa: F401,F811 # type: ignore
  121. accelerator_name = "npu"
  122. except ImportError as e:
  123. pass
  124. if accelerator_name is None:
  125. try:
  126. import torch.mps
  127. # should use torch.mps.is_available() if it exists someday but this is used as proxy
  128. torch.mps.current_allocated_memory()
  129. accelerator_name = "mps"
  130. except (RuntimeError, ImportError) as e:
  131. pass
  132. if accelerator_name is None:
  133. try:
  134. import habana_frameworks.torch.hpu # noqa: F401,F811
  135. accelerator_name = "hpu"
  136. except ImportError as e:
  137. pass
  138. if accelerator_name is None:
  139. try:
  140. import torch_mlu # noqa: F401,F811
  141. accelerator_name = "mlu"
  142. except ImportError as e:
  143. pass
  144. if accelerator_name is None:
  145. # borrow this log from PR#5084
  146. try:
  147. import torch
  148. # Determine if we are on a GPU or x86 CPU with torch.
  149. if torch.cuda.is_available(): #ignore-cuda
  150. accelerator_name = "cuda"
  151. else:
  152. if accel_logger is not None:
  153. accel_logger.warn(
  154. "Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it."
  155. )
  156. accelerator_name = "cpu"
  157. except (RuntimeError, ImportError) as e:
  158. # TODO need a more decent way to detect which accelerator to use, consider using nvidia-smi command for detection
  159. accelerator_name = "cuda"
  160. pass
  161. ds_set_method = "auto detect"
  162. # 3. Set ds_accelerator accordingly
  163. if accelerator_name == "cuda":
  164. from .cuda_accelerator import CUDA_Accelerator
  165. ds_accelerator = CUDA_Accelerator()
  166. elif accelerator_name == "cpu":
  167. from .cpu_accelerator import CPU_Accelerator
  168. ds_accelerator = CPU_Accelerator()
  169. elif accelerator_name == "xpu.external":
  170. # XPU_Accelerator is already imported in detection stage
  171. ds_accelerator = XPU_Accelerator()
  172. elif accelerator_name == "xpu":
  173. from .xpu_accelerator import XPU_Accelerator
  174. ds_accelerator = XPU_Accelerator()
  175. elif accelerator_name == "npu":
  176. from .npu_accelerator import NPU_Accelerator
  177. ds_accelerator = NPU_Accelerator()
  178. elif accelerator_name == "mps":
  179. from .mps_accelerator import MPS_Accelerator
  180. ds_accelerator = MPS_Accelerator()
  181. elif accelerator_name == 'hpu':
  182. from .hpu_accelerator import HPU_Accelerator
  183. ds_accelerator = HPU_Accelerator()
  184. elif accelerator_name == 'mlu':
  185. from .mlu_accelerator import MLU_Accelerator
  186. ds_accelerator = MLU_Accelerator()
  187. _validate_accelerator(ds_accelerator)
  188. if accel_logger is not None:
  189. accel_logger.info(f"Setting ds_accelerator to {ds_accelerator._name} ({ds_set_method})")
  190. return ds_accelerator
  191. def set_accelerator(accel_obj):
  192. global ds_accelerator
  193. _validate_accelerator(accel_obj)
  194. if accel_logger is not None:
  195. accel_logger.info(f"Setting ds_accelerator to {accel_obj._name} (model specified)")
  196. ds_accelerator = accel_obj
  197. """
  198. -----------[code] test_get.py -----------
  199. from deepspeed.accelerator import get_accelerator
  200. my_accelerator = get_accelerator()
  201. logger.info(f'{my_accelerator._name=}')
  202. logger.info(f'{my_accelerator._communication_backend=}')
  203. logger.info(f'{my_accelerator.HalfTensor().device=}')
  204. logger.info(f'{my_accelerator.total_memory()=}')
  205. -----------[code] test_get.py -----------
  206. ---[output] python test_get.py---------
  207. my_accelerator.name()='cuda'
  208. my_accelerator.communication_backend='nccl'
  209. my_accelerator.HalfTensor().device=device(type='cuda', index=0)
  210. my_accelerator.total_memory()=34089730048
  211. ---[output] python test_get.py---------
  212. **************************************************************************
  213. -----------[code] test_set.py -----------
  214. from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator
  215. cu_accel = CUDA_Accelerator()
  216. logger.info(f'{id(cu_accel)=}')
  217. from deepspeed.accelerator import set_accelerator, get_accelerator
  218. set_accelerator(cu_accel)
  219. my_accelerator = get_accelerator()
  220. logger.info(f'{id(my_accelerator)=}')
  221. logger.info(f'{my_accelerator._name=}')
  222. logger.info(f'{my_accelerator._communication_backend=}')
  223. logger.info(f'{my_accelerator.HalfTensor().device=}')
  224. logger.info(f'{my_accelerator.total_memory()=}')
  225. -----------[code] test_set.py -----------
  226. ---[output] python test_set.py---------
  227. id(cu_accel)=139648165478304
  228. my_accelerator=<deepspeed.accelerator.cuda_accelerator.CUDA_Accelerator object at 0x7f025f4bffa0>
  229. my_accelerator.name='cuda'
  230. my_accelerator.communication_backend='nccl'
  231. my_accelerator.HalfTensor().device=device(type='cuda', index=0)
  232. my_accelerator.total_memory()=34089730048
  233. ---[output] python test_set.py---------
  234. """