cpu_accelerator.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
import oneccl_bindings_for_pytorch  # noqa: F401
import psutil
import os


# accelerator for Intel CPU
class CPU_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        self._name = 'cpu'
        self._communication_backend_name = 'ccl'
        # Track the peak resident set size (RSS) of this process; it serves as
        # the "device memory" high-water mark on CPU.
        self.max_mem = psutil.Process().memory_info().rss

    def is_synchronized_device(self):
        # CPU ops execute eagerly; there is no asynchronous device queue.
        return True
    # Device APIs
    def device_name(self, device_index=None):
        return 'cpu'

    def device(self, device_index=None):
        return None

    def set_device(self, device_index):
        return

    def current_device(self):
        # Cast to int: environment variables are strings, and callers expect a
        # numeric device index.
        return int(os.environ.get('LOCAL_RANK', '0'))

    def current_device_name(self):
        return 'cpu'

    def device_count(self):
        device_count = int(os.environ.get('LOCAL_SIZE', 0))
        if device_count > 0:
            return device_count
        else:
            from deepspeed.utils.numa import get_numa_cores

            # Use the NUMA node count as the number of CPU accelerators. On a
            # machine with HBM in flat mode, the HBM appears as a separate
            # NUMA node with no cores attached; ignore such nodes (and
            # repeated core lists) when counting.
            numa_core_lists = get_numa_cores()
            numa_count = 0
            prev_core_list = []
            for core_list in numa_core_lists:
                if len(core_list) > 0 and core_list != prev_core_list:
                    numa_count += 1
                    prev_core_list = core_list
            return numa_count
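
    # A minimal usage sketch for device_count() (illustrative only, not part
    # of the original file): the launcher-provided LOCAL_SIZE takes priority,
    # and the NUMA scan is only a fallback.
    #
    #   os.environ['LOCAL_SIZE'] = '2'
    #   CPU_Accelerator().device_count()  # -> 2
    #   del os.environ['LOCAL_SIZE']
    #   CPU_Accelerator().device_count()  # -> number of NUMA nodes with cores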
    def synchronize(self, device_index=None):
        return
    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        # There is a single CPU RNG state; torch.set_rng_state() takes no
        # device argument, so device_index is ignored.
        return torch.set_rng_state(new_state)

    def get_rng_state(self, device_index=None):
        return torch.get_rng_state()

    def manual_seed(self, seed):
        return torch.manual_seed(seed)

    def manual_seed_all(self, seed):
        # One CPU RNG: seeding "all" devices is the same as seeding one.
        return torch.manual_seed(seed)

    def initial_seed(self, seed):
        # torch.initial_seed() takes no arguments; the seed parameter exists
        # only to satisfy the DeepSpeedAccelerator interface.
        return torch.initial_seed()

    def default_generator(self, device_index):
        return torch.default_generator
    # Streams/Events
    # CPU has no stream/event abstraction, so these are no-op placeholders.
    @property
    def Stream(self):
        return None

    def stream(self, stream):
        from deepspeed.runtime.utils import noop_decorator
        return noop_decorator

    def current_stream(self, device_index=None):
        return None

    def default_stream(self, device_index=None):
        return None

    @property
    def Event(self):
        return None
    # Memory management
    def empty_cache(self):
        return

    def get_rss(self):
        # Read the current RSS and update the recorded peak.
        mem = psutil.Process().memory_info().rss
        if mem > self.max_mem:
            self.max_mem = mem
        return mem

    def reset_rss(self):
        # Re-base the recorded peak to the current RSS.
        mem = psutil.Process().memory_info().rss
        self.max_mem = mem
        return mem

    def memory_allocated(self, device_index=None):
        return self.get_rss()

    def max_memory_allocated(self, device_index=None):
        self.get_rss()
        return self.max_mem

    def reset_max_memory_allocated(self, device_index=None):
        self.reset_rss()
        return

    def memory_cached(self, device_index=None):
        return self.get_rss()

    def max_memory_cached(self, device_index=None):
        self.get_rss()
        return self.max_mem

    def reset_max_memory_cached(self, device_index=None):
        self.reset_rss()
        return

    def memory_stats(self, device_index=None):
        return self.get_rss()

    def reset_peak_memory_stats(self, device_index=None):
        self.reset_rss()
        return

    def memory_reserved(self, device_index=None):
        return self.get_rss()

    def max_memory_reserved(self, device_index=None):
        self.get_rss()
        return self.max_mem

    def total_memory(self, device_index=None):
        return psutil.virtual_memory().total
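
    # A minimal sketch of the RSS-based accounting above (illustrative only,
    # not part of the original file): allocations grow the process RSS, which
    # memory_allocated() reports and max_memory_allocated() tracks as a peak
    # until reset_max_memory_allocated() re-bases it.
    #
    #   acc = CPU_Accelerator()
    #   t = torch.ones(1 << 24)             # allocation grows RSS
    #   acc.memory_allocated()              # current RSS in bytes
    #   acc.max_memory_allocated()          # peak RSS seen so far
    #   del t
    #   acc.reset_max_memory_allocated()    # peak re-based to current RSS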
    # Misc
    def amp(self):
        return torch.cpu.amp

    def is_available(self):
        return True

    def range_push(self, msg):
        # TODO: ITT tracing is not supported yet
        # return torch.profiler.itt.range_push(msg)
        return

    def range_pop(self):
        # TODO: ITT tracing is not supported yet
        # return torch.profiler.itt.range_pop()
        return

    def lazy_call(self, callback):
        # CPU initialization is synchronous, so run the callback immediately.
        return callback()

    def communication_backend_name(self):
        return self._communication_backend_name
    # Data types
    def is_bf16_supported(self):
        return True

    def is_fp16_supported(self):
        return False

    def supported_dtypes(self):
        return [torch.float, torch.bfloat16]
    # Tensor operations
    @property
    def BFloat16Tensor(self):
        return torch.BFloat16Tensor

    @property
    def ByteTensor(self):
        return torch.ByteTensor

    @property
    def DoubleTensor(self):
        return torch.DoubleTensor

    @property
    def FloatTensor(self):
        return torch.FloatTensor

    @property
    def HalfTensor(self):
        return torch.HalfTensor

    @property
    def IntTensor(self):
        return torch.IntTensor

    @property
    def LongTensor(self):
        return torch.LongTensor

    def pin_memory(self, tensor):
        # Host memory needs no pinning on CPU; return the tensor unchanged.
        return tensor
    def op_builder_dir(self):
        try:
            # Is op_builder from DeepSpeed or a third-party version? This
            # import only succeeds for DeepSpeed's own op_builder, which also
            # means we are on a local install rather than the JIT compile path.
            from op_builder import __deepspeed__  # noqa: F401
            return "op_builder.cpu"
        except ImportError:
            return "deepspeed.ops.op_builder.cpu"

    def on_accelerator(self, tensor):
        return str(tensor.device).startswith('cpu')
    # create an instance of an op builder, name specified by op_name
    def create_op_builder(self, op_name):
        builder_class = self.get_op_builder(op_name)
        if builder_class is not None:
            return builder_class()
        return None

    # return an op builder class, name specified by class_name
    def get_op_builder(self, class_name):
        try:
            # Is op_builder from DeepSpeed or a third-party version? This
            # import only succeeds for DeepSpeed's own op_builder, which also
            # means we are on a local install rather than the JIT compile path.
            from op_builder import __deepspeed__  # noqa: F401
            from op_builder.cpu import CCLCommBuilder, NotImplementedBuilder
        except ImportError:
            from deepspeed.ops.op_builder.cpu import CCLCommBuilder, NotImplementedBuilder

        if class_name == "CCLCommBuilder":
            return CCLCommBuilder
        else:
            # Return NotImplementedBuilder rather than None so unit tests get
            # a clear "not implemented" error instead of a NoneType failure.
            return NotImplementedBuilder
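
    # A minimal sketch of the builder lookup above (illustrative only, not
    # part of the original file); "FusedAdamBuilder" is just an example of an
    # op this backend does not implement.
    #
    #   acc = CPU_Accelerator()
    #   acc.get_op_builder("CCLCommBuilder")      # -> CCLCommBuilder class
    #   acc.get_op_builder("FusedAdamBuilder")    # -> NotImplementedBuilder
    #   acc.create_op_builder("CCLCommBuilder")   # -> CCLCommBuilder instance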
    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension
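

# A minimal smoke test, assuming the module-level dependencies
# (oneccl_bindings_for_pytorch, psutil) are installed; not part of the
# original file.
if __name__ == "__main__":
    acc = CPU_Accelerator()
    print("name:", acc.device_name())                    # 'cpu'
    print("backend:", acc.communication_backend_name())  # 'ccl'
    print("devices:", acc.device_count())                # LOCAL_SIZE or NUMA count
    print("bf16 supported:", acc.is_bf16_supported())    # True
    print("total memory (bytes):", acc.total_memory())
    print("current RSS (bytes):", acc.memory_allocated())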