# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import importlib
import inspect

from .abstract_accelerator import DeepSpeedAccelerator

# During the setup stage torch may not be installed yet; swallowing the import
# error here still allows the op-builder-related APIs to be executed.
try:
    import torch.npu
except ImportError:
    pass


class NPU_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        super().__init__()
        self._name = 'npu'
        self._communication_backend_name = 'hccl'
        self._compile_backend = "inductor"
        # Dict that holds the class name <--> class type mapping, e.g.
        # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>.
        # It is filled lazily on first use (see _lazy_init_class_dict).
        self.class_dict = None
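
    # Illustrative shape of class_dict after _lazy_init_class_dict() has run.
    # The entry names below are assumptions; the real set comes from the npu
    # op_builder package, and only NotImplementedBuilder is relied on here:
    #   {'NotImplementedBuilder': <class ...>, 'InferenceBuilder': <class ...>}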

    def is_synchronized_device(self):
        return False

    def use_host_timers(self):
        return self.is_synchronized_device()

    def resolves_data_dependency(self):
        return self.is_synchronized_device()

    def handles_memory_backpressure(self):
        return self.is_synchronized_device()

    # Device APIs
    def device_name(self, device_index=None):
        if device_index is None:
            return 'npu'
        return 'npu:{}'.format(device_index)

    def device(self, device_index=None):
        return torch.npu.device(device_index)

    def set_device(self, device_index):
        torch.npu.set_device(device_index)

    def current_device(self):
        return torch.npu.current_device()

    def current_device_name(self):
        return 'npu:{}'.format(torch.npu.current_device())

    def device_count(self):
        return torch.npu.device_count()

    def synchronize(self, device_index=None):
        return torch.npu.synchronize(device_index)
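
    # Minimal usage sketch for the device APIs above (assumes torch_npu is
    # installed and at least one Ascend NPU is visible):
    #   acc = NPU_Accelerator()
    #   acc.set_device(0)
    #   with acc.device(0):
    #       pass  # ops issued here target 'npu:0'
    #   acc.synchronize()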

    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        if device_index is None:
            return torch.npu.set_rng_state(new_state)
        return torch.npu.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        if device_index is None:
            return torch.npu.get_rng_state()
        return torch.npu.get_rng_state(device_index)

    def manual_seed(self, seed):
        return torch.npu.manual_seed(seed)

    def manual_seed_all(self, seed):
        return torch.npu.manual_seed_all(seed)

    def initial_seed(self):
        return torch.npu.initial_seed()

    def default_generator(self, device_index):
        return torch.npu.default_generators[device_index]

    # Streams/Events
    @property
    def Stream(self):
        return torch.npu.Stream

    def stream(self, stream):
        return torch.npu.stream(stream)

    def current_stream(self, device_index=None):
        return torch.npu.current_stream(device_index)

    def default_stream(self, device_index=None):
        return torch.npu.default_stream(device_index)

    @property
    def Event(self):
        return torch.npu.Event

    # Memory management
    def empty_cache(self):
        return torch.npu.empty_cache()

    def memory_allocated(self, device_index=None):
        return torch.npu.memory_allocated(device_index)

    def max_memory_allocated(self, device_index=None):
        return torch.npu.max_memory_allocated(device_index)

    def reset_max_memory_allocated(self, device_index=None):
        return torch.npu.reset_max_memory_allocated(device_index)

    def memory_cached(self, device_index=None):
        return torch.npu.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        return torch.npu.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return torch.npu.reset_max_memory_cached(device_index)

    def memory_stats(self, device_index=None):
        if hasattr(torch.npu, 'memory_stats'):
            return torch.npu.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        if hasattr(torch.npu, 'reset_peak_memory_stats'):
            return torch.npu.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        if hasattr(torch.npu, 'memory_reserved'):
            return torch.npu.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        if hasattr(torch.npu, 'max_memory_reserved'):
            return torch.npu.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        return torch.npu.get_device_properties(device_index).total_memory

    def available_memory(self, device_index=None):
        return self.total_memory(device_index) - self.memory_allocated(device_index)
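
    # Usage sketch (assumption: torch_npu mirrors the CUDA memory counters):
    #   acc = NPU_Accelerator()
    #   used = acc.memory_allocated(0)
    #   free = acc.available_memory(0)  # total_memory minus memory_allocated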

    # Data types
    def is_bf16_supported(self):
        return torch.npu.is_bf16_supported()

    def is_fp16_supported(self):
        return True

    def supported_dtypes(self):
        return [torch.float, torch.half, torch.bfloat16]

    # Misc
    def amp(self):
        if hasattr(torch.npu, 'amp'):
            return torch.npu.amp
        return None

    def is_available(self):
        return torch.npu.is_available()

    def range_push(self, msg):
        # No NVTX-style profiler ranges on this accelerator; intentional no-op.
        return

    def range_pop(self):
        return

    def lazy_call(self, callback):
        return torch.npu._lazy_call(callback)

    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        return False

    # Graph operations: graph capture is not supported here, so these degrade
    # to no-ops and a no-op context manager.
    def create_graph(self):
        return None

    def capture_to_graph(self, graph, pool=None, stream=None):
        from deepspeed.runtime.utils import noop_context
        return noop_context()

    def replay_graph(self, graph):
        return

    # Tensor operations
    @property
    def BFloat16Tensor(self):
        return torch.npu.BFloat16Tensor

    @property
    def ByteTensor(self):
        return torch.npu.ByteTensor

    @property
    def DoubleTensor(self):
        return torch.npu.DoubleTensor

    @property
    def FloatTensor(self):
        return torch.npu.FloatTensor

    @property
    def HalfTensor(self):
        return torch.npu.HalfTensor

    @property
    def IntTensor(self):
        return torch.npu.IntTensor

    @property
    def LongTensor(self):
        return torch.npu.LongTensor

    def pin_memory(self, tensor, align_bytes=1):
        return tensor.pin_memory()

    def is_pinned(self, tensor):
        return tensor.is_pinned()

    def on_accelerator(self, tensor):
        return str(tensor.device).startswith('npu:')

    def op_builder_dir(self):
        try:
            # Is op_builder from deepspeed or a 3rd-party version? This import
            # should only succeed for deepspeed's own package; if it does, we
            # are on a local-install path rather than the JIT compile path.
            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
            return "op_builder.npu"
        except ImportError:
            return "deepspeed.ops.op_builder.npu"

    def _lazy_init_class_dict(self):
        if self.class_dict is not None:
            return
        op_builder_module = importlib.import_module(self.op_builder_dir())
        # Collect the op builder classes exposed by op_builder/npu/__init__.py.
        self.class_dict = {
            class_name: class_obj
            for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass)
        }

    # Create and return an instance of the op builder specified by class_name.
    def create_op_builder(self, class_name):
        builder_class = self.get_op_builder(class_name)
        return None if builder_class is None else builder_class()

    # Return the op builder class specified by class_name, falling back to
    # NotImplementedBuilder when the requested builder is unavailable.
    def get_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]
        return self.class_dict.get('NotImplementedBuilder', None)
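
    # Usage sketch (the builder name is illustrative; unknown names fall back
    # to NotImplementedBuilder when the npu op_builder package provides it):
    #   acc = NPU_Accelerator()
    #   builder_cls = acc.get_op_builder('CPUAdamBuilder')
    #   builder = acc.create_op_builder('CPUAdamBuilder')  # instance or None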

    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension

    def export_envs(self):
        return ['ASCEND', 'HCCL', 'LD_LIBRARY', 'PATH']

    def visible_devices_envs(self):
        return ['ASCEND_RT_VISIBLE_DEVICES']

    def set_visible_devices_envs(self, current_env, local_accelerator_ids):
        for env in self.visible_devices_envs():
            current_env[env] = ",".join(map(str, local_accelerator_ids))

    def get_compile_backend(self):
        return self._compile_backend

    def set_compile_backend(self, backend):
        supported_backends = torch._dynamo.list_backends(exclude_tags=())
        if backend in supported_backends:
            self._compile_backend = backend
        else:
            raise ValueError(
                f"{backend} not supported by {self.device_name()}. Supported backends are {supported_backends}")