cpu_accelerator.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. from .abstract_accelerator import DeepSpeedAccelerator
  6. try:
  7. import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
  8. oneccl_imported_p = True
  9. except ImportError as e:
  10. oneccl_imported_p = False
  11. import os
  12. # accelerator for Intel CPU
  13. class CPU_Accelerator(DeepSpeedAccelerator):
  14. def __init__(self):
  15. self._name = 'cpu'
  16. if oneccl_imported_p:
  17. self._communication_backend_name = 'ccl'
  18. else:
  19. # fallback to gloo if oneccl_binding_for_pytorch is not installed
  20. self._communication_backend_name = 'gloo'
  21. try:
  22. import psutil
  23. mem = psutil.Process().memory_info().rss
  24. self.max_mem = mem
  25. except ImportError as e:
  26. self.max_mem = 0
  27. def is_synchronized_device(self):
  28. return True
  29. def use_host_timers(self):
  30. return self.is_synchronized_device()
  31. def resolves_data_dependency(self):
  32. return self.is_synchronized_device()
  33. def handles_memory_backpressure(self):
  34. return self.is_synchronized_device()
  35. # Device APIs
  36. def device_name(self, device_index=None):
  37. return 'cpu'
  38. def device(self, device_index=None):
  39. return None
  40. def set_device(self, device_index):
  41. return
  42. def current_device(self):
  43. return os.environ.get('LOCAL_RANK', 0)
  44. def current_device_name(self):
  45. return 'cpu'
  46. def device_count(self):
  47. device_count = int(os.environ.get('LOCAL_SIZE', 0))
  48. if device_count > 0:
  49. return device_count
  50. else:
  51. from deepspeed.utils.numa import get_numa_cores
  52. # Count NUMA node for number of cpu accelerators. On machine with HBM
  53. # In flat mode, HBM is in separate NUMA node with no cores on this node.
  54. # Ignore these NUMA nodes with no cores.
  55. numa_core_lists = get_numa_cores()
  56. numa_count = 0
  57. prev_core_list = []
  58. for core_list in numa_core_lists:
  59. if len(core_list) > 0 and core_list != prev_core_list:
  60. numa_count += 1
  61. prev_core_list = core_list
  62. return numa_count
  63. def synchronize(self, device_index=None):
  64. return
  65. # RNG APIs
  66. def random(self):
  67. return torch.random
  68. def set_rng_state(self, new_state, device_index=None):
  69. if device_index is None:
  70. return torch.set_rng_state(new_state)
  71. return torch.set_rng_state(new_state, device_index)
  72. def get_rng_state(self, device_index=None):
  73. return torch.get_rng_state()
  74. def manual_seed(self, seed):
  75. return torch.manual_seed(seed)
  76. def manual_seed_all(self, seed):
  77. return torch.manual_seed(seed)
  78. def initial_seed(self, seed):
  79. return torch.initial_seed(seed)
  80. def default_generator(self, device_index):
  81. return torch.default_generator
  82. # Streams/Events
  83. @property
  84. def Stream(self):
  85. return None
  86. def stream(self, stream):
  87. from deepspeed.runtime.utils import noop_context
  88. return noop_context()
  89. def current_stream(self, device_index=None):
  90. return None
  91. def default_stream(self, device_index=None):
  92. return None
  93. @property
  94. def Event(self):
  95. return None
  96. # Memory management
  97. def empty_cache(self):
  98. return
  99. def get_rss(self):
  100. import psutil
  101. mem = psutil.Process().memory_info().rss
  102. if mem > self.max_mem:
  103. self.max_mem = mem
  104. return mem
  105. def reset_rss(self):
  106. import psutil
  107. mem = psutil.Process().memory_info().rss
  108. self.max_mem = mem
  109. return mem
  110. def memory_allocated(self, device_index=None):
  111. return self.get_rss()
  112. def max_memory_allocated(self, device_index=None):
  113. self.get_rss()
  114. return self.max_mem
  115. def reset_max_memory_allocated(self, device_index=None):
  116. self.reset_rss()
  117. return
  118. def memory_cached(self, device_index=None):
  119. return self.get_rss()
  120. def max_memory_cached(self, device_index=None):
  121. self.get_rss()
  122. return self.max_mem
  123. def reset_max_memory_cached(self, device_index=None):
  124. self.reset_rss()
  125. return
  126. def memory_stats(self, device_index=None):
  127. mem = self.get_rss()
  128. mem_stat = {}
  129. mem_stat['allocated_bytes.all.current'] = mem
  130. mem_stat['allocated_bytes.all.peak'] = self.max_mem
  131. return mem_stat
  132. def reset_peak_memory_stats(self, device_index=None):
  133. self.reset_rss()
  134. return
  135. def memory_reserved(self, device_index=None):
  136. return self.get_rss()
  137. def max_memory_reserved(self, device_index=None):
  138. self.get_rss()
  139. return self.max_mem
  140. def total_memory(self, device_index=None):
  141. import psutil
  142. return psutil.virtual_memory().total
  143. def available_memory(self, device_index=None):
  144. import psutil
  145. return psutil.virtual_memory().available
  146. # Misc
  147. def amp(self):
  148. return torch.cpu.amp
  149. def is_available(self):
  150. return True
  151. def range_push(self, msg):
  152. # TODO itt is currently not supported yet
  153. # return torch.profiler.itt.range_push(msg)
  154. return
  155. def range_pop(self):
  156. # TODO itt is currently not supported yet
  157. # return torch.profiler.itt.range_pop()
  158. return
  159. def lazy_call(self, callback):
  160. return callback()
  161. def communication_backend_name(self):
  162. return self._communication_backend_name
  163. def is_triton_supported(self):
  164. return False
  165. # Data types
  166. def is_bf16_supported(self):
  167. return True
  168. def is_fp16_supported(self):
  169. return False
  170. def supported_dtypes(self):
  171. return [torch.float, torch.bfloat16]
  172. # Graph operations
  173. def create_graph(self):
  174. return None
  175. def capture_to_graph(self, graph, pool=None, stream=None):
  176. from deepspeed.runtime.utils import noop_context
  177. return noop_context()
  178. def replay_graph(self, graph):
  179. return
  180. # Tensor operations
  181. @property
  182. def BFloat16Tensor(self):
  183. return torch.BFloat16Tensor
  184. @property
  185. def ByteTensor(self):
  186. return torch.ByteTensor
  187. @property
  188. def DoubleTensor(self):
  189. return torch.DoubleTensor
  190. @property
  191. def FloatTensor(self):
  192. return torch.FloatTensor
  193. @property
  194. def HalfTensor(self):
  195. return torch.HalfTensor
  196. @property
  197. def IntTensor(self):
  198. return torch.IntTensor
  199. @property
  200. def LongTensor(self):
  201. return torch.LongTensor
  202. def pin_memory(self, tensor, align_bytes=1):
  203. return tensor
  204. def is_pinned(self, tensor):
  205. return tensor.is_pinned()
  206. def op_builder_dir(self):
  207. try:
  208. # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
  209. # if successful this also means we're doing a local install and not JIT compile path
  210. from op_builder import __deepspeed__ # noqa: F401 # type: ignore
  211. return "op_builder.cpu"
  212. except ImportError:
  213. return "deepspeed.ops.op_builder.cpu"
  214. def on_accelerator(self, tensor):
  215. device_str = str(tensor.device)
  216. if device_str.startswith('cpu'):
  217. return True
  218. else:
  219. return False
  220. # create an instance of op builder and return, name specified by class_name
  221. def create_op_builder(self, op_name):
  222. builder_class = self.get_op_builder(op_name)
  223. if builder_class is not None:
  224. return builder_class()
  225. return None
  226. # return an op builder class, name specified by class_name
  227. def get_op_builder(self, class_name):
  228. try:
  229. # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
  230. # if successful this also means we're doing a local install and not JIT compile path
  231. from op_builder import __deepspeed__ # noqa: F401 # type: ignore
  232. from op_builder.cpu import CCLCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
  233. except ImportError:
  234. from deepspeed.ops.op_builder.cpu import CCLCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
  235. if class_name == "CCLCommBuilder":
  236. return CCLCommBuilder
  237. elif class_name == "FusedAdamBuilder":
  238. return FusedAdamBuilder
  239. elif class_name == "CPUAdamBuilder":
  240. return CPUAdamBuilder
  241. else:
  242. # return a NotImplementedBuilder to avoid get NoneType[Name] in unit tests
  243. return NotImplementedBuilder
  244. def build_extension(self):
  245. from torch.utils.cpp_extension import BuildExtension
  246. return BuildExtension
  247. def export_envs(self):
  248. return []