# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os

import psutil
import torch

from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator

# Imported for its side effect: registers the oneCCL ('ccl') backend with torch.distributed.
import oneccl_bindings_for_pytorch  # noqa: F401


# Accelerator implementation for Intel CPU
class CPU_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        self._name = 'cpu'
        self._communication_backend_name = 'ccl'
        # Baseline RSS of this process; updated by get_rss() to track the peak,
        # so that the max_memory_* APIs can be emulated with psutil.
        self.max_mem = psutil.Process().memory_info().rss

    def is_synchronized_device(self):
        # CPU execution is synchronous; there are no asynchronous queues to drain.
        return True

    # Device APIs
    def device_name(self, device_index=None):
        return 'cpu'

    def device(self, device_index=None):
        return None

    def set_device(self, device_index):
        return

    def current_device(self):
        # Cast to int for consistency with the other accelerators
        # (os.environ values are strings).
        return int(os.environ.get('LOCAL_RANK', 0))

    def current_device_name(self):
        return 'cpu'

    def device_count(self):
        device_count = int(os.environ.get('LOCAL_SIZE', 0))
        if device_count > 0:
            return device_count
        else:
            from deepspeed.utils.numa import get_numa_cores
            # Use the number of NUMA nodes as the number of CPU accelerators.
            # On machines with HBM in flat mode, the HBM appears as a separate
            # NUMA node with no cores attached; ignore those nodes.
            numa_core_lists = get_numa_cores()
            numa_count = 0
            for core_list in numa_core_lists:
                if len(core_list) > 0:
                    numa_count += 1
            return numa_count
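
    # Worked example (hypothetical layout): on a dual-socket machine with HBM
    # in flat mode, get_numa_cores() might return
    #   [[0, ..., 31], [32, ..., 63], [], []]
    # i.e. two core-bearing nodes plus two core-less HBM nodes, so with
    # LOCAL_SIZE unset, device_count() would report 2 CPU accelerators.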

    def synchronize(self, device_index=None):
        return

    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        # The CPU has a single global generator, so device_index is ignored;
        # torch.set_rng_state() does not accept a device argument.
        return torch.set_rng_state(new_state)

    def get_rng_state(self, device_index=None):
        return torch.get_rng_state()

    def manual_seed(self, seed):
        return torch.manual_seed(seed)

    def manual_seed_all(self, seed):
        # There is only one CPU device, so this is the same as manual_seed().
        return torch.manual_seed(seed)

    def initial_seed(self, seed):
        # torch.initial_seed() takes no arguments; the seed parameter is kept
        # only to match the DeepSpeedAccelerator interface.
        return torch.initial_seed()

    def default_generator(self, device_index):
        return torch.default_generator
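
    # Sketch of the intended round-trip behavior (mirrors the torch.cuda RNG API):
    #   accel = CPU_Accelerator()
    #   accel.manual_seed(123)
    #   state = accel.get_rng_state()
    #   a = torch.randn(2)
    #   accel.set_rng_state(state)
    #   b = torch.randn(2)   # equal to a: the same generator state was restored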

    # Streams/Events: CPU execution is synchronous, so streams and events
    # degrade to no-ops.
    @property
    def Stream(self):
        return None

    def stream(self, stream):
        # Return a decorator that leaves the wrapped function unchanged.
        from deepspeed.runtime.utils import noop_decorator
        return noop_decorator

    def current_stream(self, device_index=None):
        return None

    def default_stream(self, device_index=None):
        return None

    @property
    def Event(self):
        return None

    # Memory management: there is no device allocator to query on CPU, so the
    # torch.cuda-style memory APIs are emulated with the process RSS reported
    # by psutil.
    def empty_cache(self):
        return

    def get_rss(self):
        mem = psutil.Process().memory_info().rss
        if mem > self.max_mem:
            self.max_mem = mem
        return mem

    def reset_rss(self):
        mem = psutil.Process().memory_info().rss
        self.max_mem = mem
        return mem

    def memory_allocated(self, device_index=None):
        return self.get_rss()

    def max_memory_allocated(self, device_index=None):
        self.get_rss()
        return self.max_mem

    def reset_max_memory_allocated(self, device_index=None):
        self.reset_rss()
        return

    def memory_cached(self, device_index=None):
        return self.get_rss()

    def max_memory_cached(self, device_index=None):
        self.get_rss()
        return self.max_mem

    def reset_max_memory_cached(self, device_index=None):
        self.reset_rss()
        return

    def memory_stats(self, device_index=None):
        return self.get_rss()

    def reset_peak_memory_stats(self, device_index=None):
        self.reset_rss()
        return

    def memory_reserved(self, device_index=None):
        return self.get_rss()

    def max_memory_reserved(self, device_index=None):
        self.get_rss()
        return self.max_mem

    def total_memory(self, device_index=None):
        return psutil.virtual_memory().total
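
    # Sketch of how the RSS emulation behaves (sizes are illustrative):
    #   accel.reset_max_memory_allocated()   # peak := current RSS
    #   big = torch.empty(1 << 28)
    #   accel.memory_allocated()             # samples RSS while `big` is alive
    #   del big
    #   accel.max_memory_allocated()         # reports the peak RSS sampled above
    # The peak is only as accurate as the sampling: RSS is recorded whenever one
    # of these APIs calls get_rss(), and RSS is process-wide, so allocations
    # other than tensors move the numbers as well.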

    # Misc
    def amp(self):
        return torch.cpu.amp

    def is_available(self):
        return True

    def range_push(self, msg):
        # TODO: itt (Intel Instrumentation and Tracing Technology) is not supported yet
        # return torch.profiler.itt.range_push(msg)
        return

    def range_pop(self):
        # TODO: itt is not supported yet
        # return torch.profiler.itt.range_pop()
        return

    def lazy_call(self, callback):
        # CPU needs no deferred device initialization; run the callback immediately.
        return callback()

    def communication_backend_name(self):
        return self._communication_backend_name

    # Data types
    def is_bf16_supported(self):
        return True

    def is_fp16_supported(self):
        return True

    # Tensor operations
    @property
    def BFloat16Tensor(self):
        return torch.BFloat16Tensor

    @property
    def ByteTensor(self):
        return torch.ByteTensor

    @property
    def DoubleTensor(self):
        return torch.DoubleTensor

    @property
    def FloatTensor(self):
        return torch.FloatTensor

    @property
    def HalfTensor(self):
        return torch.HalfTensor

    @property
    def IntTensor(self):
        return torch.IntTensor

    @property
    def LongTensor(self):
        return torch.LongTensor

    def pin_memory(self, tensor):
        # Pinning is meaningless for host memory; return the tensor unchanged.
        return tensor

    def op_builder_dir(self):
        try:
            # Is op_builder from deepspeed or a 3p version? This should only succeed if it's deepspeed.
            # If successful, this also means we're on the local-install path rather than the JIT compile path.
            from op_builder import __deepspeed__  # noqa: F401
            return "op_builder.cpu"
        except ImportError:
            return "deepspeed.ops.op_builder.cpu"

    def on_accelerator(self, tensor):
        # Any tensor whose device string starts with 'cpu' lives on this accelerator.
        return str(tensor.device).startswith('cpu')

    # Create and return an op builder instance; the builder is selected by op_name.
    def create_op_builder(self, op_name):
        builder_class = self.get_op_builder(op_name)
        if builder_class is not None:
            return builder_class()
        return None

    # Return an op builder class, selected by class_name.
    def get_op_builder(self, class_name):
        try:
            # Is op_builder from deepspeed or a 3p version? This should only succeed if it's deepspeed.
            # If successful, this also means we're on the local-install path rather than the JIT compile path.
            from op_builder import __deepspeed__  # noqa: F401
            from op_builder.cpu import CCLCommBuilder, NotImplementedBuilder
        except ImportError:
            from deepspeed.ops.op_builder.cpu import CCLCommBuilder, NotImplementedBuilder

        if class_name == "CCLCommBuilder":
            return CCLCommBuilder
        else:
            # Return NotImplementedBuilder instead of None so unit tests don't
            # hit a "'NoneType' object is not callable" error for unsupported ops.
            return NotImplementedBuilder

    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension
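

# A minimal usage sketch (not part of the original module), assuming DeepSpeed
# with CPU support and oneccl_bindings_for_pytorch are installed.
if __name__ == "__main__":
    accel = CPU_Accelerator()
    print(accel.device_name())                 # 'cpu'
    print(accel.device_count())                # LOCAL_SIZE, or core-bearing NUMA nodes
    print(accel.communication_backend_name())  # 'ccl'
    print(accel.total_memory())                # total system RAM in bytes
    x = torch.ones(4)
    print(accel.on_accelerator(x))             # True: torch tensors default to CPU
    # get_op_builder returns a class; unknown names map to NotImplementedBuilder
    # rather than None.
    print(accel.get_op_builder("CCLCommBuilder"))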