npu_accelerator.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import importlib
  5. import inspect
  6. from .abstract_accelerator import DeepSpeedAccelerator
  7. # During the setup stage torch may not be installed; skipping the import when torch is
  8. # missing still allows the op-builder related APIs to be executed.
  9. try:
  10. import torch.npu
  11. except ImportError:
  12. pass
  class NPU_Accelerator(DeepSpeedAccelerator):
      """DeepSpeed accelerator backed by Ascend NPU devices through ``torch.npu``.

      Device names use the ``'npu'`` prefix and the communication backend is
      ``'hccl'``. Most methods are thin wrappers around the ``torch.npu``
      function of the same name.
      """

      def __init__(self):
          super().__init__()
          self._name = 'npu'
          self._communication_backend_name = 'hccl'
          # Maps op-builder class name -> class object, e.g.
          # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>.
          # Filled lazily by _lazy_init_class_dict(); None means "not loaded yet".
          self.class_dict = None

      def is_synchronized_device(self):
          """Report whether this is a synchronized device (always False for NPU)."""
          return False

      def use_host_timers(self):
          """Use host timers only on synchronized devices; mirrors is_synchronized_device()."""
          return self.is_synchronized_device()

      def resolves_data_dependency(self):
          """Mirror is_synchronized_device() (False for NPU)."""
          return self.is_synchronized_device()

      def handles_memory_backpressure(self):
          """Mirror is_synchronized_device() (False for NPU)."""
          return self.is_synchronized_device()

      # Device APIs
      def device_name(self, device_index=None):
          """Return ``'npu'``, or ``'npu:<device_index>'`` when an index is given."""
          if device_index is None:
              return 'npu'
          return 'npu:{}'.format(device_index)

      def device(self, device_index=None):
          """Return ``torch.npu.device(device_index)``."""
          return torch.npu.device(device_index)

      def set_device(self, device_index):
          """Select ``device_index`` as the current NPU via ``torch.npu.set_device``."""
          torch.npu.set_device(device_index)

      def current_device(self):
          """Return ``torch.npu.current_device()``."""
          return torch.npu.current_device()

      def current_device_name(self):
          """Return the current device formatted as ``'npu:<index>'``."""
          return self.device_name(torch.npu.current_device())

      def device_count(self):
          """Return ``torch.npu.device_count()``."""
          return torch.npu.device_count()

      def synchronize(self, device_index=None):
          """Forward to ``torch.npu.synchronize(device_index)``."""
          return torch.npu.synchronize(device_index)

      # RNG APIs
      def random(self):
          """Return the ``torch.random`` module (shared with other backends)."""
          return torch.random

      def set_rng_state(self, new_state, device_index=None):
          """Set the NPU RNG state, targeting ``device_index`` when one is given."""
          # Only pass device_index when set so torch.npu applies its own default.
          if device_index is None:
              return torch.npu.set_rng_state(new_state)
          return torch.npu.set_rng_state(new_state, device_index)

      def get_rng_state(self, device_index=None):
          """Return the NPU RNG state, for ``device_index`` when one is given."""
          # Only pass device_index when set so torch.npu applies its own default.
          if device_index is None:
              return torch.npu.get_rng_state()
          return torch.npu.get_rng_state(device_index)

      def manual_seed(self, seed):
          """Forward ``seed`` to ``torch.npu.manual_seed``."""
          return torch.npu.manual_seed(seed)

      def manual_seed_all(self, seed):
          """Forward ``seed`` to ``torch.npu.manual_seed_all``."""
          return torch.npu.manual_seed_all(seed)

      def initial_seed(self, seed=None):
          """Return ``torch.npu.initial_seed()``.

          Args:
              seed: Ignored. Accepted only for backward compatibility with
                  callers that pass it; ``torch.npu.initial_seed`` (like
                  ``torch.cuda.initial_seed``) takes no arguments, so
                  forwarding it raised ``TypeError``.
          """
          return torch.npu.initial_seed()

      def default_generator(self, device_index):
          """Return ``torch.npu.default_generators[device_index]``."""
          return torch.npu.default_generators[device_index]

      # Streams/Events
      @property
      def Stream(self):
          """The ``torch.npu.Stream`` class."""
          return torch.npu.Stream

      def stream(self, stream):
          """Return ``torch.npu.stream(stream)``."""
          return torch.npu.stream(stream)

      def current_stream(self, device_index=None):
          """Return ``torch.npu.current_stream(device_index)``."""
          return torch.npu.current_stream(device_index)

      def default_stream(self, device_index=None):
          """Return ``torch.npu.default_stream(device_index)``."""
          return torch.npu.default_stream(device_index)

      @property
      def Event(self):
          """The ``torch.npu.Event`` class."""
          return torch.npu.Event

      # Memory management
      def empty_cache(self):
          """Forward to ``torch.npu.empty_cache()``."""
          return torch.npu.empty_cache()

      def memory_allocated(self, device_index=None):
          """Return ``torch.npu.memory_allocated(device_index)``."""
          return torch.npu.memory_allocated(device_index)

      def max_memory_allocated(self, device_index=None):
          """Return ``torch.npu.max_memory_allocated(device_index)``."""
          return torch.npu.max_memory_allocated(device_index)

      def reset_max_memory_allocated(self, device_index=None):
          """Forward to ``torch.npu.reset_max_memory_allocated(device_index)``."""
          return torch.npu.reset_max_memory_allocated(device_index)

      def memory_cached(self, device_index=None):
          """Return ``torch.npu.memory_cached(device_index)``."""
          return torch.npu.memory_cached(device_index)

      def max_memory_cached(self, device_index=None):
          """Return ``torch.npu.max_memory_cached(device_index)``."""
          return torch.npu.max_memory_cached(device_index)

      def reset_max_memory_cached(self, device_index=None):
          """Forward to ``torch.npu.reset_max_memory_cached(device_index)``."""
          return torch.npu.reset_max_memory_cached(device_index)

      # The next four helpers are best-effort: if the installed torch.npu does
      # not provide the function, they return None instead of raising.
      def memory_stats(self, device_index=None):
          """Return ``torch.npu.memory_stats(device_index)``, or None if unavailable."""
          if hasattr(torch.npu, 'memory_stats'):
              return torch.npu.memory_stats(device_index)
          return None

      def reset_peak_memory_stats(self, device_index=None):
          """Forward to ``torch.npu.reset_peak_memory_stats``, or return None if unavailable."""
          if hasattr(torch.npu, 'reset_peak_memory_stats'):
              return torch.npu.reset_peak_memory_stats(device_index)
          return None

      def memory_reserved(self, device_index=None):
          """Return ``torch.npu.memory_reserved(device_index)``, or None if unavailable."""
          if hasattr(torch.npu, 'memory_reserved'):
              return torch.npu.memory_reserved(device_index)
          return None

      def max_memory_reserved(self, device_index=None):
          """Return ``torch.npu.max_memory_reserved(device_index)``, or None if unavailable."""
          if hasattr(torch.npu, 'max_memory_reserved'):
              return torch.npu.max_memory_reserved(device_index)
          return None

      def total_memory(self, device_index=None):
          """Return ``total_memory`` from ``torch.npu.get_device_properties(device_index)``."""
          return torch.npu.get_device_properties(device_index).total_memory

      def available_memory(self, device_index=None):
          """Approximate free memory as ``total_memory - memory_allocated``.

          NOTE(review): only memory_allocated is subtracted; confirm whether
          reserved/cached memory should be subtracted as well.
          """
          return self.total_memory(device_index) - self.memory_allocated(device_index)

      # Data types
      def is_bf16_supported(self):
          """Return ``torch.npu.is_bf16_supported()``."""
          return torch.npu.is_bf16_supported()

      def is_fp16_supported(self):
          """Report fp16 support (always True for NPU)."""
          return True

      def supported_dtypes(self):
          """Return the dtypes this accelerator lists as supported: float32, float16, bfloat16."""
          return [torch.float, torch.half, torch.bfloat16]

      # Misc
      def amp(self):
          """Return ``torch.npu.amp`` when it exists, otherwise None."""
          if hasattr(torch.npu, 'amp'):
              return torch.npu.amp
          return None

      def is_available(self):
          """Return ``torch.npu.is_available()``."""
          return torch.npu.is_available()

      def range_push(self, msg):
          """No-op on NPU; ``msg`` is ignored."""
          return

      def range_pop(self):
          """No-op on NPU."""
          return

      def lazy_call(self, callback):
          """Forward ``callback`` to ``torch.npu._lazy_call`` (a private torch_npu API)."""
          return torch.npu._lazy_call(callback)

      def communication_backend_name(self):
          """Return the communication backend name (``'hccl'``)."""
          return self._communication_backend_name

      def is_triton_supported(self):
          """Report Triton support (always False for NPU)."""
          return False

      # Graph operations: graph capture is not supported here, so these are stubs.
      def create_graph(self):
          """Return None: there is no graph object to create."""
          return None

      def capture_to_graph(self, graph, pool=None, stream=None):
          """Return a no-op context manager; ``graph``, ``pool`` and ``stream`` are ignored."""
          from deepspeed.runtime.utils import noop_context
          return noop_context()

      def replay_graph(self, graph):
          """No-op: nothing was captured, so there is nothing to replay."""
          return

      # Tensor operations: NPU tensor-type classes exposed by torch.npu.
      @property
      def BFloat16Tensor(self):
          """``torch.npu.BFloat16Tensor``."""
          return torch.npu.BFloat16Tensor

      @property
      def ByteTensor(self):
          """``torch.npu.ByteTensor``."""
          return torch.npu.ByteTensor

      @property
      def DoubleTensor(self):
          """``torch.npu.DoubleTensor``."""
          return torch.npu.DoubleTensor

      @property
      def FloatTensor(self):
          """``torch.npu.FloatTensor``."""
          return torch.npu.FloatTensor

      @property
      def HalfTensor(self):
          """``torch.npu.HalfTensor``."""
          return torch.npu.HalfTensor

      @property
      def IntTensor(self):
          """``torch.npu.IntTensor``."""
          return torch.npu.IntTensor

      @property
      def LongTensor(self):
          """``torch.npu.LongTensor``."""
          return torch.npu.LongTensor

      def pin_memory(self, tensor, align_bytes=1):
          """Return ``tensor.pin_memory()``; ``align_bytes`` is accepted but ignored."""
          return tensor.pin_memory()

      def is_pinned(self, tensor):
          """Return ``tensor.is_pinned()``."""
          return tensor.is_pinned()

      def on_accelerator(self, tensor):
          """Return True when ``str(tensor.device)`` starts with ``'npu:'``."""
          return str(tensor.device).startswith('npu:')

      def op_builder_dir(self):
          """Return the dotted module path of the NPU op-builder package."""
          try:
              # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
              # if successful this also means we're doing a local install and not JIT compile path
              from op_builder import __deepspeed__  # noqa: F401 # type: ignore
              return "op_builder.npu"
          except ImportError:
              return "deepspeed.ops.op_builder.npu"

      def _lazy_init_class_dict(self):
          """Populate ``self.class_dict`` once with the classes exported by the op-builder package."""
          # Compare against None (not truthiness) so an empty mapping is not
          # re-imported and re-inspected on every call.
          if self.class_dict is not None:
              return
          # get op builder classes from op_builder/npu/__init__.py
          op_builder_module = importlib.import_module(self.op_builder_dir())
          # Assign in one step so a failure above leaves class_dict as None and is retried.
          self.class_dict = dict(inspect.getmembers(op_builder_module, inspect.isclass))

      def create_op_builder(self, class_name):
          """Instantiate the op builder named ``class_name``; return None if no class is found."""
          builder_class = self.get_op_builder(class_name)
          return None if builder_class is None else builder_class()

      def get_op_builder(self, class_name):
          """Return the op-builder class named ``class_name``.

          Falls back to ``NotImplementedBuilder`` when the name is unknown, and
          to None if that fallback is not exported either.
          """
          self._lazy_init_class_dict()
          if class_name in self.class_dict:
              return self.class_dict[class_name]
          return self.class_dict.get('NotImplementedBuilder')

      def build_extension(self):
          """Return torch's ``BuildExtension`` class used to build C++ extensions."""
          from torch.utils.cpp_extension import BuildExtension
          return BuildExtension

      def export_envs(self):
          """Return the environment-variable names to export.

          Entries such as 'LD_LIBRARY' look like name prefixes; presumably the
          launcher prefix-matches them — confirm against the caller.
          """
          return ['ASCEND', 'HCCL', 'LD_LIBRARY', 'PATH']