hpu_accelerator.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
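
"""DeepSpeed accelerator implementation for Habana (HPU) devices.

Wraps ``habana_frameworks.torch.hpu`` behind the ``DeepSpeedAccelerator``
interface: device, RNG, stream, and memory queries, HPU graph capture,
op-builder discovery, and torch.compile backend selection ('hpu_backend'
by default).
"""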

import functools
import os
import pkgutil
import importlib

import torch

from .abstract_accelerator import DeepSpeedAccelerator


class HPU_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        self._name = 'hpu'
        self._communication_backend_name = 'hccl'
        self._compile_backend = "hpu_backend"
        self.apply_hpu_workarounds()
        try:
            import habana_frameworks.torch.hpu as hpu
            hpu.setDeterministic(True)
            self.hpu = hpu
        except ImportError as e:
            raise ValueError(
                "HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.") from e

        self.fp16_supported = None

    def apply_hpu_workarounds(self):
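        # Apply default workaround env vars for the Habana runtime, but only
        # when the user has not already set them explicitly.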

        def update_wa_env_var(key, value):
            if key not in os.environ.keys():
                os.environ[key] = value

        update_wa_env_var("PT_HPU_LAZY_ACC_PAR_MODE", "0")
        update_wa_env_var("PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES", "0")

    # Device APIs
    def is_synchronized_device(self):
        return False

    def use_host_timers(self):
        return False
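
    # These two flags advertise that the HPU backend resolves data dependencies
    # and handles memory backpressure on its own, so DeepSpeed does not need to
    # insert explicit synchronization for those cases.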
    def resolves_data_dependency(self):
        return True

    def handles_memory_backpressure(self):
        return True

    def device_name(self, device_index=None):
        # ignoring device_index.
        return 'hpu'

    def device(self, device_index=None):
        return torch.device(self.device_name(device_index))

    def set_device(self, device_index):
        self.hpu.set_device(device_index)

    def current_device(self):
        return self.hpu.current_device()

    def current_device_name(self):
        return 'hpu:{}'.format(self.current_device())

    def device_count(self):
        return self.hpu.device_count()

    def synchronize(self, device_index=None):
        return self.hpu.synchronize()

    # RNG APIs
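    # RNG state management is delegated to habana_frameworks' torch.hpu.random
    # module; random() returns the torch.random module for API compatibility.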
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        self.hpu.random.set_rng_state(new_state)

    def get_rng_state(self, device_index=None):
        return self.hpu.random.get_rng_state()

    def manual_seed(self, seed):
        return self.hpu.random.manual_seed(seed)

    def manual_seed_all(self, seed):
        self.hpu.random.manual_seed_all(seed)

    def initial_seed(self):
        return self.hpu.random.initial_seed()

    def default_generator(self, device_index):
        return self.hpu.random.default_generators[device_index]

    # Streams/Events
    @property
    def Stream(self):
        return self.hpu.Stream

    def stream(self, stream):
        return self.hpu.stream(stream)

    def current_stream(self, device_index=None):
        return self.hpu.current_stream()

    def default_stream(self, device_index=None):
        return self.hpu.default_stream()

    @property
    def Event(self):
        import habana_frameworks.torch.core as htcore
        return htcore.hpu.Event

    # Memory management
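    # empty_cache() is a no-op for this backend; the remaining methods delegate
    # to habana_frameworks' memory query APIs.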
    def empty_cache(self):
        return

    def memory_allocated(self, device_index=None):
        return self.hpu.memory_allocated()

    def max_memory_allocated(self, device_index=None):
        return self.hpu.max_memory_allocated()

    def reset_max_memory_allocated(self, device_index=None):
        return self.hpu.reset_max_memory_allocated()

    def memory_cached(self, device_index=None):
        return self.hpu.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        return self.hpu.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return None

    def memory_stats(self, device_index=None):
        return self.hpu.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        self.hpu.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        return self.hpu.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        return self.hpu.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        return self.memory_stats(device_index)['Limit']

    def available_memory(self, device_index=None):
        return self.total_memory(device_index) - self.memory_allocated(device_index)

    # Data types
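    # bf16 is always reported as supported; fp16 support is queried once from
    # habana_frameworks' experimental API and cached in self.fp16_supported.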
    def is_bf16_supported(self):
        return True

    def is_fp16_supported(self):
        if self.fp16_supported is None:
            import habana_frameworks.torch.utils.experimental as htexp
            self.fp16_supported = htexp._is_fp16_supported()
        return self.fp16_supported

    def supported_dtypes(self):
        supported_dtypes = [torch.float, torch.bfloat16]
        if self.is_fp16_supported():
            supported_dtypes.append(torch.half)
        return supported_dtypes

    # Misc
    def amp(self):
        return None

    def is_available(self):
        return self.hpu.is_available()

    def range_push(self, msg):
        return

    def range_pop(self):
        return

    def lazy_call(self, callback):
        callback()

    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        return False

    # Graph operations
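    # HPU graph capture/replay. Note that capture_to_graph() accepts a `pool`
    # argument for interface parity but does not forward it to hpu.graph().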
    def create_graph(self):
        return self.hpu.HPUGraph()

    def capture_to_graph(self, graph, pool=None, stream=None):
        return self.hpu.graph(graph, stream=stream)

    def replay_graph(self, graph):
        graph.replay()
        return

    # Tensor operations
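    # The *Tensor properties below return factory callables (functools.partial
    # around torch.tensor) pinned to a dtype and the 'hpu' device.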
    @property
    def BFloat16Tensor(self):
        return functools.partial(torch.tensor, dtype=torch.bfloat16, device='hpu')

    @property
    def ByteTensor(self):
        return functools.partial(torch.tensor, dtype=torch.uint8, device='hpu')

    @property
    def DoubleTensor(self):
        return functools.partial(torch.tensor, dtype=torch.double, device='hpu')

    @property
    def FloatTensor(self):
        return functools.partial(torch.tensor, dtype=torch.float, device='hpu')

    @property
    def HalfTensor(self):
        return functools.partial(torch.tensor, dtype=torch.half, device='hpu')

    @property
    def IntTensor(self):
        return functools.partial(torch.tensor, dtype=torch.int, device='hpu')

    @property
    def LongTensor(self):
        return functools.partial(torch.tensor, dtype=torch.long, device='hpu')

    def pin_memory(self, tensor, align_bytes=1):
        return tensor.pin_memory(self.device())

    def is_pinned(self, tensor):
        return tensor.is_pinned()

    def on_accelerator(self, tensor):
        device_str = str(tensor.device)
        if device_str.startswith('hpu:'):
            return True
        else:
            return False

    def op_builder_dir(self):
        try:
            # Is op_builder from deepspeed or a 3rd-party version? This should only succeed if it's deepspeed.
            # If successful, this also means we're doing a local install and not the JIT compile path.
            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
            return "op_builder.hpu"
        except ImportError:
            return "deepspeed.ops.op_builder.hpu"

    # dict that holds class name <--> class type mapping i.e.
    # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
    # this dict will be filled at init stage
    class_dict = None

    def _lazy_init_class_dict(self):
        if self.class_dict is not None:
            return
        else:
            self.class_dict = {}
            # begin initialize for create_op_builder()
            # put all valid class name <--> class type mappings into class_dict
            op_builder_dir = self.op_builder_dir()
            op_builder_module = importlib.import_module(op_builder_dir)
            op_builder_absolute_path = os.path.dirname(op_builder_module.__file__)
            for _, module_name, _ in pkgutil.iter_modules([op_builder_absolute_path]):
                # avoid self-references,
                # skip sub-directories which contain ops for other backends (cpu, npu, etc.).
                if module_name != 'all_ops' and module_name != 'builder' and not os.path.isdir(
                        os.path.join(op_builder_absolute_path, module_name)):
                    module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
                    for member_name in module.__dir__():
                        if member_name.endswith(
                                'Builder'
                        ) and member_name != "OpBuilder" and member_name != "CPUOpBuilder" and member_name != "TorchCPUOpBuilder":  # avoid abstract classes
                            if member_name not in self.class_dict:
                                self.class_dict[member_name] = getattr(module, member_name)
            # end initialize for create_op_builder()

    # create an instance of op builder and return it, name specified by class_name
    def create_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]()
        else:
            return None

    # return an op builder class, name specified by class_name
    def get_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]
        else:
            return self.class_dict['NotImplementedBuilder'] if 'NotImplementedBuilder' in self.class_dict else None

    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension

    def export_envs(self):
        return []

    def visible_devices_envs(self):
        # The way DeepSpeed currently sets this env var is not applicable to all HPU instances.
        # Users have to follow the instructions in:
        # https://docs.habana.ai/en/latest/PyTorch/Reference/PT_Multiple_Tenants_on_HPU/Multiple_Workloads_Single_Docker.html
        # Keeping CUDA_VISIBLE_DEVICES for now instead of HABANA_VISIBLE_MODULES.
        return ['CUDA_VISIBLE_DEVICES']  #['HABANA_VISIBLE_MODULES']

    def set_visible_devices_envs(self, current_env, local_accelerator_ids):
        for env in self.visible_devices_envs():
            current_env[env] = ",".join(map(str, local_accelerator_ids))
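
    # torch.compile backend selection: defaults to 'hpu_backend'; a new backend
    # is only accepted if it appears in torch._dynamo.list_backends().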
    def get_compile_backend(self):
        return self._compile_backend

    def set_compile_backend(self, backend):
        supported_backends = torch._dynamo.list_backends(exclude_tags=())
        if backend in supported_backends:
            self._compile_backend = backend
        else:
            raise ValueError(
                f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")