hpu_accelerator.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import pkgutil
import importlib

import torch

from .abstract_accelerator import DeepSpeedAccelerator


class HPU_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        self._name = 'hpu'
        self._communication_backend_name = 'hccl'
        try:
            import habana_frameworks.torch.hpu as hpu
            hpu.setDeterministic(True)
            self.hpu = hpu
        except ImportError as e:
            raise ValueError(
                "HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.") from e

        self.fp16_supported = None

    # Device APIs
    def is_synchronized_device(self):
        return False

    def use_host_timers(self):
        return False

    def resolves_data_dependency(self):
        return True

    def handles_memory_backpressure(self):
        return True

    def device_name(self, device_index=None):
        if device_index is None:
            return 'hpu'
        return 'hpu:{}'.format(device_index)

    def device(self, device_index=None):
        return torch.device(self.device_name(device_index))

    def set_device(self, device_index):
        self.hpu.set_device(device_index)

    def current_device(self):
        return self.hpu.current_device()

    def current_device_name(self):
        return 'hpu:{}'.format(self.current_device())

    def device_count(self):
        return self.hpu.device_count()

    def synchronize(self, device_index=None):
        return self.hpu.synchronize()
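
    # Illustrative usage sketch (not part of the original file): selecting a
    # device and formatting its name, assuming `accel = HPU_Accelerator()`:
    #     accel.set_device(0)
    #     dev = accel.device(0)                 # torch.device('hpu:0')
    #     name = accel.current_device_name()    # e.g. 'hpu:0'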

    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        self.hpu.random.set_rng_state(new_state)

    def get_rng_state(self, device_index=None):
        return self.hpu.random.get_rng_state()

    def manual_seed(self, seed):
        self.hpu.random.manual_seed(seed)

    def manual_seed_all(self, seed):
        self.hpu.random.manual_seed_all(seed)

    def initial_seed(self, seed):
        self.hpu.random.initial_seed(seed)

    def default_generator(self, device_index):
        return self.hpu.random.default_generators[device_index]

    # Streams/Events
    @property
    def Stream(self):
        return self.hpu.Stream

    def stream(self, stream):
        return self.hpu.stream(stream)

    def current_stream(self, device_index=None):
        return self.hpu.current_stream()

    def default_stream(self, device_index=None):
        return self.hpu.default_stream()

    @property
    def Event(self):
        import habana_frameworks.torch.core as htcore
        return htcore.hpu.Event

    # Memory management
    def empty_cache(self):
        return

    def memory_allocated(self, device_index=None):
        return self.hpu.memory_allocated()

    def max_memory_allocated(self, device_index=None):
        return self.hpu.max_memory_allocated()

    def reset_max_memory_allocated(self, device_index=None):
        return self.hpu.reset_max_memory_allocated()

    def memory_cached(self, device_index=None):
        return self.hpu.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        return self.hpu.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return None

    def memory_stats(self, device_index=None):
        return self.hpu.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        self.hpu.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        return self.hpu.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        return self.hpu.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        return self.memory_stats(device_index)['Limit']

    def available_memory(self, device_index=None):
        return self.total_memory(device_index) - self.memory_allocated(device_index)
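
    # Illustrative usage sketch (not part of the original file): reporting free
    # device memory, assuming `accel = HPU_Accelerator()` on an HPU system:
    #     free_bytes = accel.available_memory()   # total ('Limit' from memory_stats) minus allocated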

    # Data types
    def is_bf16_supported(self):
        return True

    def is_fp16_supported(self):
        if self.fp16_supported is None:
            import habana_frameworks.torch.utils.experimental as htexp
            self.fp16_supported = htexp._is_fp16_supported()
        return self.fp16_supported

    def supported_dtypes(self):
        supported_dtypes = [torch.float, torch.bfloat16]
        if self.is_fp16_supported():
            supported_dtypes.append(torch.half)
        return supported_dtypes
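
    # Illustrative usage sketch (not part of the original file): choosing a
    # training dtype based on hardware support, assuming `accel = HPU_Accelerator()`:
    #     dtype = torch.half if accel.is_fp16_supported() else torch.bfloat16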

    # Misc
    def amp(self):
        return None

    def is_available(self):
        return self.hpu.is_available()

    def range_push(self, msg):
        return

    def range_pop(self):
        return

    def lazy_call(self, callback):
        callback()

    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        return False

    # Graph operations
    def create_graph(self):
        return self.hpu.HPUGraph()

    def capture_to_graph(self, graph, pool=None, stream=None):
        return self.hpu.graph(graph, stream=stream)

    def replay_graph(self, graph):
        graph.replay()
        return
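
    # Illustrative capture/replay sketch (not part of the original file),
    # assuming hpu.graph() behaves as a context manager analogous to
    # torch.cuda.graph, with `accel = HPU_Accelerator()` and a hypothetical
    # callable `run_workload` to capture:
    #     graph = accel.create_graph()
    #     with accel.capture_to_graph(graph):
    #         run_workload()
    #     accel.replay_graph(graph)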

    # Tensor operations
    @property
    def BFloat16Tensor(self):
        return self.hpu.BFloat16Tensor

    @property
    def ByteTensor(self):
        return self.hpu.ByteTensor

    @property
    def DoubleTensor(self):
        return self.hpu.DoubleTensor

    @property
    def FloatTensor(self):
        return self.hpu.FloatTensor

    @property
    def HalfTensor(self):
        return self.hpu.HalfTensor

    @property
    def IntTensor(self):
        return self.hpu.IntTensor

    @property
    def LongTensor(self):
        return self.hpu.LongTensor

    def pin_memory(self, tensor, align_bytes=1):
        return tensor.pin_memory(self.device())

    def is_pinned(self, tensor):
        return tensor.is_pinned()

    def on_accelerator(self, tensor):
        device_str = str(tensor.device)
        if device_str.startswith('hpu:'):
            return True
        else:
            return False

    def op_builder_dir(self):
        try:
            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
            # if successful this also means we're doing a local install and not JIT compile path
            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
            return "op_builder.hpu"
        except ImportError:
            return "deepspeed.ops.op_builder.hpu"

    # dict that holds class name <--> class type mapping i.e.
    # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
    # this dict will be filled at init stage
    class_dict = None

    def _lazy_init_class_dict(self):
        if self.class_dict is not None:
            return
        else:
            self.class_dict = {}
            # begin initialize for create_op_builder()
            # put all valid class name <--> class type mappings into class_dict
            op_builder_dir = self.op_builder_dir()
            op_builder_module = importlib.import_module(op_builder_dir)
            op_builder_absolute_path = os.path.dirname(op_builder_module.__file__)
            for _, module_name, _ in pkgutil.iter_modules([op_builder_absolute_path]):
                # avoid self references,
                # skip sub-directories which contain ops for other backends (cpu, npu, etc.)
                if module_name != 'all_ops' and module_name != 'builder' and not os.path.isdir(
                        os.path.join(op_builder_absolute_path, module_name)):
                    module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
                    for member_name in module.__dir__():
                        # collect concrete *Builder classes, skipping the abstract base classes
                        if member_name.endswith('Builder') and member_name not in (
                                "OpBuilder", "CPUOpBuilder", "TorchCPUOpBuilder"):
                            if member_name not in self.class_dict:
                                self.class_dict[member_name] = getattr(module, member_name)
            # end initialize for create_op_builder()

    # create an instance of op builder and return, name specified by class_name
    def create_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]()
        else:
            return None

    # return an op builder class, name specified by class_name
    def get_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]
        else:
            return self.class_dict['NotImplementedBuilder'] if 'NotImplementedBuilder' in self.class_dict else None
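
    # Illustrative usage sketch (not part of the original file): fetching a
    # builder by name. 'FusedAdamBuilder' is only an example and may or may not
    # exist in the discovered op_builder.hpu package; .load() assumes the
    # standard DeepSpeed OpBuilder interface:
    #     builder = accel.create_op_builder('FusedAdamBuilder')
    #     op_module = builder.load() if builder is not None else None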

    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension

    def export_envs(self):
        return []
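

# Illustrative usage sketch (not part of the original file). In DeepSpeed the
# accelerator is normally obtained via deepspeed.accelerator.get_accelerator()
# rather than instantiated directly; direct instantiation is shown here only to
# exercise the API on a machine with habana_frameworks installed.
if __name__ == "__main__":
    accel = HPU_Accelerator()
    print(accel.is_available())                 # True only on a machine with HPUs
    print(accel.device_name(0))                 # 'hpu:0'
    print(accel.communication_backend_name())   # 'hccl'
    print(accel.supported_dtypes())             # e.g. [torch.float32, torch.bfloat16, ...]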