xpu_accelerator.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
  6. import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
  7. import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
  class XPU_Accelerator(DeepSpeedAccelerator):
      """DeepSpeed accelerator implementation for Intel XPU devices.

      Nearly every method forwards to the ``torch.xpu`` namespace. The
      communication backend name reported to DeepSpeed is ``'ccl'``.
      """

      def __init__(self):
          self._name = 'xpu'
          self._communication_backend_name = 'ccl'
          # [begin_ptr, end_ptr] address ranges of host buffers created by
          # pin_memory(align_bytes=0); consulted by is_pinned().
          self.aligned_tensors = []
          # Async-I/O handle used by pin_memory(align_bytes=0); created lazily
          # on first use and then reused.
          self.aio_handle = None

      def is_synchronized_device(self):
          """Return False: XPU is not treated as a host-synchronous device."""
          return False

      def use_host_timers(self):
          """Mirror ``is_synchronized_device()``."""
          return self.is_synchronized_device()

      def resolves_data_dependency(self):
          """Mirror ``is_synchronized_device()``."""
          return self.is_synchronized_device()

      def handles_memory_backpressure(self):
          """Mirror ``is_synchronized_device()``."""
          return self.is_synchronized_device()

      # Device APIs
      def device_name(self, device_index=None):
          """Return ``'xpu'``, or ``'xpu:<device_index>'`` when an index is given."""
          if device_index is None:
              return 'xpu'
          return 'xpu:{}'.format(device_index)

      def device(self, device_index=None):
          """Return ``torch.xpu.device(device_index)``."""
          return torch.xpu.device(device_index)

      def set_device(self, device_index):
          """Make ``device_index`` the current XPU device."""
          torch.xpu.set_device(device_index)

      def current_device(self):
          """Return the index of the current XPU device."""
          return torch.xpu.current_device()

      def current_device_name(self):
          """Return the current device as an ``'xpu:<index>'`` string."""
          return 'xpu:{}'.format(torch.xpu.current_device())

      def device_count(self):
          """Return the number of visible XPU devices."""
          return torch.xpu.device_count()

      def synchronize(self, device_index=None):
          """Forward to ``torch.xpu.synchronize(device_index)``."""
          return torch.xpu.synchronize(device_index)

      # RNG APIs
      def random(self):
          """Return the ``torch.xpu.random`` module."""
          return torch.xpu.random

      def set_rng_state(self, new_state, device_index=None):
          """Set the RNG state, on ``device_index`` if one is given."""
          if device_index is None:
              return torch.xpu.set_rng_state(new_state)
          return torch.xpu.set_rng_state(new_state, device_index)

      def get_rng_state(self, device_index=None):
          """Return the RNG state, of ``device_index`` if one is given."""
          if device_index is None:
              return torch.xpu.get_rng_state()
          return torch.xpu.get_rng_state(device_index)

      def manual_seed(self, seed):
          """Seed the RNG of the current XPU device."""
          return torch.xpu.manual_seed(seed)

      def manual_seed_all(self, seed):
          """Seed the RNG of every XPU device."""
          return torch.xpu.manual_seed_all(seed)

      def initial_seed(self, seed=None):
          """Return the initial seed of the current XPU's default generator.

          ``torch.xpu.initial_seed`` (like ``torch.cuda.initial_seed``) takes no
          arguments; forwarding ``seed`` to it raised ``TypeError``. ``seed`` is
          still accepted, and ignored, so existing call sites keep working.
          """
          return torch.xpu.initial_seed()

      def default_generator(self, device_index):
          """Return the default generator of XPU ``device_index``."""
          return torch.xpu.default_generators[device_index]

      # Streams/Events
      @property
      def Stream(self):
          """The ``torch.xpu.Stream`` class."""
          return torch.xpu.Stream

      def stream(self, stream):
          """Return the ``torch.xpu.stream`` context manager for ``stream``."""
          return torch.xpu.stream(stream)

      def current_stream(self, device_index=None):
          """Return the current stream of ``device_index``."""
          return torch.xpu.current_stream(device_index)

      def default_stream(self, device_index=None):
          """Return a stand-in for the default stream (the current stream)."""
          # torch.xpu does not support the sync behavior of default stream as cuda
          # use current_stream as workaround
          # see https://pytorch.org/docs/stable/notes/cuda.html#cuda-streams
          return torch.xpu.current_stream(device_index)

      @property
      def Event(self):
          """The ``torch.xpu.Event`` class."""
          return torch.xpu.Event

      # Memory management
      def empty_cache(self):
          """Release cached, unoccupied allocator memory."""
          return torch.xpu.empty_cache()

      def memory_allocated(self, device_index=None):
          """Return bytes currently allocated by tensors on the device."""
          return torch.xpu.memory_allocated(device_index)

      def max_memory_allocated(self, device_index=None):
          """Return the peak of ``memory_allocated``."""
          return torch.xpu.max_memory_allocated(device_index)

      def reset_max_memory_allocated(self, device_index=None):
          """Reset the peak tracked by ``max_memory_allocated``."""
          return torch.xpu.reset_max_memory_allocated(device_index)

      def memory_cached(self, device_index=None):
          """Legacy alias: forwards to ``torch.xpu.memory_reserved``."""
          return torch.xpu.memory_reserved(device_index)

      def max_memory_cached(self, device_index=None):
          """Legacy alias: forwards to ``torch.xpu.max_memory_reserved``."""
          return torch.xpu.max_memory_reserved(device_index)

      def reset_max_memory_cached(self, device_index=None):
          """Legacy alias: forwards to ``torch.xpu.reset_max_memory_reserved``."""
          return torch.xpu.reset_max_memory_reserved(device_index)

      def memory_stats(self, device_index=None):
          """Return the allocator statistics dictionary."""
          return torch.xpu.memory_stats(device_index)

      def reset_peak_memory_stats(self, device_index=None):
          """Reset the peak values in the allocator statistics."""
          return torch.xpu.reset_peak_memory_stats(device_index)

      def memory_reserved(self, device_index=None):
          """Return bytes reserved by the caching allocator."""
          return torch.xpu.memory_reserved(device_index)

      def max_memory_reserved(self, device_index=None):
          """Return the peak of ``memory_reserved``."""
          return torch.xpu.max_memory_reserved(device_index)

      def total_memory(self, device_index=None):
          """Return ``total_memory`` from the device properties."""
          return torch.xpu.get_device_properties(device_index).total_memory

      def available_memory(self, device_index=None):
          """Return ``total_memory`` minus ``memory_allocated`` for the device."""
          return self.total_memory(device_index) - self.memory_allocated(device_index)

      # Misc
      def amp(self):
          """Return the ``torch.xpu.amp`` module."""
          return torch.xpu.amp

      def is_available(self):
          """Return whether an XPU device is available."""
          return torch.xpu.is_available()

      def range_push(self, msg):
          """No-op profiler range push."""
          # TODO: ITT ranges are not supported yet; once they are, this should
          # call torch.profiler.itt.range_push(msg).
          return

      def range_pop(self):
          """No-op profiler range pop."""
          # TODO: ITT ranges are not supported yet; once they are, this should
          # call torch.profiler.itt.range_pop().
          return

      def lazy_call(self, callback):
          """Queue ``callback`` through ``torch.xpu.lazy_init._lazy_call``."""
          return torch.xpu.lazy_init._lazy_call(callback)

      def communication_backend_name(self):
          """Return the collective-communication backend name (``'ccl'``)."""
          return self._communication_backend_name

      def is_triton_supported(self):
          """Return False: Triton is not reported as supported on XPU."""
          return False

      # Graph operations (no-ops: graph capture is not implemented here)
      def create_graph(self):
          """Return None; no graph object is created."""
          return None

      def capture_to_graph(self, graph, pool=None, stream=None):
          """Return a no-op context manager in place of graph capture."""
          from deepspeed.runtime.utils import noop_context
          return noop_context()

      def replay_graph(self, graph):
          """Do nothing; there is no captured graph to replay."""
          return

      # Data types
      def is_bf16_supported(self):
          """Return True."""
          return True

      def is_fp16_supported(self):
          """Return True."""
          return True

      def supported_dtypes(self):
          """Return the floating-point dtypes this accelerator advertises."""
          return [torch.float, torch.half, torch.bfloat16]

      # Tensor operations
      @property
      def BFloat16Tensor(self):
          return torch.xpu.BFloat16Tensor

      @property
      def ByteTensor(self):
          return torch.xpu.ByteTensor

      @property
      def DoubleTensor(self):
          return torch.xpu.DoubleTensor

      @property
      def FloatTensor(self):
          return torch.xpu.FloatTensor

      @property
      def HalfTensor(self):
          return torch.xpu.HalfTensor

      @property
      def IntTensor(self):
          return torch.xpu.IntTensor

      @property
      def LongTensor(self):
          return torch.xpu.LongTensor

      def pin_memory(self, tensor, align_bytes=1):
          """Return a page-locked host copy of ``tensor``.

          Args:
              tensor: host tensor to pin.
              align_bytes: ``1`` pins through ``Tensor.pin_memory`` for the
                  current XPU device. ``0`` allocates a locked buffer through
                  the DeepSpeed async-I/O handle, copies ``tensor`` into it,
                  and records the buffer's address range so ``is_pinned``
                  recognizes it.

          Raises:
              ValueError: for any other ``align_bytes``. The original code
                  silently returned ``None`` in this case.
          """
          if align_bytes == 1:
              return tensor.pin_memory(device=self.current_device_name())
          if align_bytes == 0:
              if self.aio_handle is None:
                  from intel_extension_for_deepspeed.op_builder.async_io import AsyncIOBuilder
                  # Build and load the op once, then reuse the handle. The
                  # original created a new handle on every call and overwrote
                  # the previous one.
                  # NOTE(review): presumably buffers from new_cpu_locked_tensor
                  # rely on their handle staying alive — confirm.
                  self.aio_handle = AsyncIOBuilder().load().aio_handle(128 * 1024, 8, False, False, False)
              numel = tensor.numel()
              aligned_t = self.aio_handle.new_cpu_locked_tensor(numel, tensor)
              aligned_t = aligned_t[:numel].copy_(tensor)
              # Record the first and last element addresses for is_pinned().
              self.aligned_tensors.append([aligned_t.data_ptr(), aligned_t[-1].data_ptr()])
              return aligned_t
          raise ValueError("pin_memory: unsupported align_bytes={} (expected 0 or 1)".format(align_bytes))

      def is_pinned(self, tensor):
          """Return True if ``tensor`` is pinned or lies in a buffer from ``pin_memory(align_bytes=0)``."""
          if tensor.is_pinned(device=self.current_device_name()):
              return True
          ptr = tensor.data_ptr()
          return any(begin <= ptr <= end for begin, end in self.aligned_tensors)

      def op_builder_dir(self):
          """Return the module path holding the XPU op builders."""
          try:
              # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
              # if successful this also means we're doing a local install and not JIT compile path
              from op_builder import __deepspeed__  # noqa: F401 # type: ignore
              return "op_builder.xpu"
          except ImportError:
              return "deepspeed.ops.op_builder.xpu"

      def on_accelerator(self, tensor):
          """Return True if ``str(tensor.device)`` starts with ``'xpu:'``."""
          return str(tensor.device).startswith('xpu:')

      # create an instance of op builder and return, name specified by class_name
      def create_op_builder(self, op_name):
          """Instantiate the builder named ``op_name``, or return None if unknown."""
          builder_class = self.get_op_builder(op_name)
          if builder_class is None:
              return None
          return builder_class()

      # return an op builder class, name specified by class_name
      def get_op_builder(self, class_name):
          """Return the XPU op-builder class named ``class_name``, or None."""
          try:
              # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
              # if successful this also means we're doing a local install and not JIT compile path
              from op_builder import __deepspeed__  # noqa: F401 # type: ignore
              from op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder
          except ImportError:
              from deepspeed.ops.op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder
          builders = {
              "AsyncIOBuilder": AsyncIOBuilder,
              "CPUAdagradBuilder": CPUAdagradBuilder,
              "CPUAdamBuilder": CPUAdamBuilder,
              "FusedAdamBuilder": FusedAdamBuilder,
          }
          return builders.get(class_name)

      def build_extension(self):
          """Return IPEX's ``DpcppBuildExtension``, trying both known module paths."""
          try:
              from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
          except ImportError:
              from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
          return DpcppBuildExtension

      def export_envs(self):
          """Return an empty list: no extra environment variables are exported."""
          return []