cuda_accelerator.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import functools
  5. import os
  6. import pkgutil
  7. import importlib
  8. import sys
  9. from .abstract_accelerator import DeepSpeedAccelerator
  10. # During setup stage torch may not be installed, pass on no torch will
  11. # allow op builder related API to be executed.
  12. try:
  13. import torch.cuda
  14. except ImportError:
  15. pass
  16. # Delay import pynvml to avoid import error when CUDA is not available
  17. pynvml = None
  18. class CUDA_Accelerator(DeepSpeedAccelerator):
  19. def __init__(self):
  20. self._name = 'cuda'
  21. self._communication_backend_name = 'nccl' if sys.platform != 'win32' else 'gloo'
  22. self._compile_backend = "inductor"
  23. if pynvml is None:
  24. self._init_pynvml()
  25. def _init_pynvml(self):
  26. global pynvml
  27. try:
  28. import pynvml
  29. except ImportError:
  30. return
  31. try:
  32. pynvml.nvmlInit()
  33. except pynvml.NVMLError:
  34. pynvml = None
  35. return
  36. def is_synchronized_device(self):
  37. return False
  38. def use_host_timers(self):
  39. return self.is_synchronized_device()
  40. def resolves_data_dependency(self):
  41. return self.is_synchronized_device()
  42. def handles_memory_backpressure(self):
  43. return self.is_synchronized_device()
  44. # Device APIs
  45. def device_name(self, device_index=None):
  46. if device_index is None:
  47. return 'cuda'
  48. return 'cuda:{}'.format(device_index)
  49. def device(self, device_index=None):
  50. return torch.cuda.device(device_index)
  51. def set_device(self, device_index):
  52. torch.cuda.set_device(device_index)
  53. def current_device(self):
  54. return torch.cuda.current_device()
  55. def current_device_name(self):
  56. return 'cuda:{}'.format(torch.cuda.current_device())
  57. def device_count(self):
  58. return torch.cuda.device_count()
  59. def synchronize(self, device_index=None):
  60. return torch.cuda.synchronize(device_index)
  61. # RNG APIs
  62. def random(self):
  63. return torch.random
  64. def set_rng_state(self, new_state, device_index=None):
  65. if device_index is None:
  66. return torch.cuda.set_rng_state(new_state)
  67. return torch.cuda.set_rng_state(new_state, device_index)
  68. def get_rng_state(self, device_index=None):
  69. if device_index is None:
  70. return torch.cuda.get_rng_state()
  71. return torch.cuda.get_rng_state(device_index)
  72. def manual_seed(self, seed):
  73. return torch.cuda.manual_seed(seed)
  74. def manual_seed_all(self, seed):
  75. return torch.cuda.manual_seed_all(seed)
  76. def initial_seed(self):
  77. return torch.cuda.initial_seed()
  78. def default_generator(self, device_index):
  79. return torch.cuda.default_generators[device_index]
  80. # Streams/Events
  81. @property
  82. def Stream(self):
  83. return torch.cuda.Stream
  84. def stream(self, stream):
  85. return torch.cuda.stream(stream)
  86. def current_stream(self, device_index=None):
  87. return torch.cuda.current_stream(device_index)
  88. def default_stream(self, device_index=None):
  89. return torch.cuda.default_stream(device_index)
  90. @property
  91. def Event(self):
  92. return torch.cuda.Event
  93. # Memory management
  94. def empty_cache(self):
  95. return torch.cuda.empty_cache()
  96. def memory_allocated(self, device_index=None):
  97. return torch.cuda.memory_allocated(device_index)
  98. def max_memory_allocated(self, device_index=None):
  99. return torch.cuda.max_memory_allocated(device_index)
  100. def reset_max_memory_allocated(self, device_index=None):
  101. return torch.cuda.reset_max_memory_allocated(device_index)
  102. def memory_cached(self, device_index=None):
  103. return torch.cuda.memory_cached(device_index)
  104. def max_memory_cached(self, device_index=None):
  105. return torch.cuda.max_memory_cached(device_index)
  106. def reset_max_memory_cached(self, device_index=None):
  107. return torch.cuda.reset_max_memory_cached(device_index)
  108. def memory_stats(self, device_index=None):
  109. if hasattr(torch.cuda, 'memory_stats'):
  110. return torch.cuda.memory_stats(device_index)
  111. def reset_peak_memory_stats(self, device_index=None):
  112. if hasattr(torch.cuda, 'reset_peak_memory_stats'):
  113. return torch.cuda.reset_peak_memory_stats(device_index)
  114. def memory_reserved(self, device_index=None):
  115. if hasattr(torch.cuda, 'memory_reserved'):
  116. return torch.cuda.memory_reserved(device_index)
  117. def max_memory_reserved(self, device_index=None):
  118. if hasattr(torch.cuda, 'max_memory_reserved'):
  119. return torch.cuda.max_memory_reserved(device_index)
  120. def total_memory(self, device_index=None):
  121. return torch.cuda.get_device_properties(device_index).total_memory
  122. def _get_nvml_gpu_id(self, torch_gpu_id):
  123. """
  124. credit: https://discuss.pytorch.org/t/making-pynvml-match-torch-device-ids-cuda-visible-devices/103020
  125. Remap torch device id to nvml device id, respecting CUDA_VISIBLE_DEVICES.
  126. If the latter isn't set return the same id
  127. """
  128. # if CUDA_VISIBLE_DEVICES is used automagically remap the id since pynvml ignores this env var
  129. if "CUDA_VISIBLE_DEVICES" in os.environ:
  130. ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")))
  131. return ids[torch_gpu_id] # remap
  132. else:
  133. return torch_gpu_id
  134. def available_memory(self, device_index=None):
  135. if pynvml:
  136. if device_index is None:
  137. device_index = self.current_device()
  138. handle = pynvml.nvmlDeviceGetHandleByIndex(self._get_nvml_gpu_id(device_index))
  139. info = pynvml.nvmlDeviceGetMemoryInfo(handle)
  140. return info.free
  141. else:
  142. return self.total_memory(device_index) - self.memory_allocated(device_index)
  143. # Data types
  144. def is_bf16_supported(self):
  145. if not torch.cuda.is_available():
  146. return True
  147. return torch.cuda.is_bf16_supported()
  148. def is_fp16_supported(self):
  149. if not torch.cuda.is_available():
  150. return True
  151. # See https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html#hardware-precision-matrix
  152. # FP16 on compute capability 6.x is deprecated
  153. allow_deprecated_fp16 = os.environ.get('DS_ALLOW_DEPRECATED_FP16', '0') == '1'
  154. major, _ = torch.cuda.get_device_capability()
  155. if major >= 7:
  156. return True
  157. elif major == 6 and allow_deprecated_fp16:
  158. return True
  159. else:
  160. return False
  161. def supported_dtypes(self):
  162. supported_dtypes = [torch.float]
  163. if self.is_fp16_supported():
  164. supported_dtypes.append(torch.half)
  165. if self.is_bf16_supported():
  166. supported_dtypes.append(torch.bfloat16)
  167. return supported_dtypes
  168. # Misc
  169. def amp(self):
  170. if hasattr(torch.cuda, 'amp'):
  171. return torch.cuda.amp
  172. return None
  173. def is_available(self):
  174. return torch.cuda.is_available()
  175. def range_push(self, msg):
  176. if hasattr(torch.cuda.nvtx, 'range_push'):
  177. return torch.cuda.nvtx.range_push(msg)
  178. def range_pop(self):
  179. if hasattr(torch.cuda.nvtx, 'range_pop'):
  180. return torch.cuda.nvtx.range_pop()
  181. def lazy_call(self, callback):
  182. return torch.cuda._lazy_call(callback)
  183. def communication_backend_name(self):
  184. return self._communication_backend_name
  185. def is_triton_supported(self):
  186. major, _ = torch.cuda.get_device_capability()
  187. if major >= 8:
  188. return True
  189. else:
  190. return False
  191. # Graph operations
  192. def create_graph(self):
  193. return torch.cuda.CUDAGraph()
  194. def capture_to_graph(self, graph, pool=None, stream=None):
  195. return torch.cuda.graph(graph, pool, stream)
  196. def replay_graph(self, graph):
  197. graph.replay()
  198. return
  199. # Tensor operations
  200. @property
  201. def BFloat16Tensor(self):
  202. return functools.partial(torch.tensor, dtype=torch.bfloat16, device='cuda')
  203. @property
  204. def ByteTensor(self):
  205. return functools.partial(torch.tensor, dtype=torch.uint8, device='cuda')
  206. @property
  207. def DoubleTensor(self):
  208. return functools.partial(torch.tensor, dtype=torch.double, device='cuda')
  209. @property
  210. def FloatTensor(self):
  211. return functools.partial(torch.tensor, dtype=torch.float, device='cuda')
  212. @property
  213. def HalfTensor(self):
  214. return functools.partial(torch.tensor, dtype=torch.half, device='cuda')
  215. @property
  216. def IntTensor(self):
  217. return functools.partial(torch.tensor, dtype=torch.int, device='cuda')
  218. @property
  219. def LongTensor(self):
  220. return functools.partial(torch.tensor, dtype=torch.long, device='cuda')
  221. def pin_memory(self, tensor, align_bytes=1):
  222. return tensor.pin_memory()
  223. def is_pinned(self, tensor):
  224. return tensor.is_pinned()
  225. def on_accelerator(self, tensor):
  226. device_str = str(tensor.device)
  227. if device_str.startswith('cuda:'):
  228. return True
  229. else:
  230. return False
  231. def op_builder_dir(self):
  232. try:
  233. # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
  234. # if successful this also means we're doing a local install and not JIT compile path
  235. from op_builder import __deepspeed__ # noqa: F401 # type: ignore
  236. return "op_builder"
  237. except ImportError:
  238. return "deepspeed.ops.op_builder"
  239. # dict that holds class name <--> class type mapping i.e.
  240. # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
  241. # this dict will be filled at init stage
  242. class_dict = None
  243. def _lazy_init_class_dict(self):
  244. if self.class_dict is not None:
  245. return
  246. else:
  247. self.class_dict = {}
  248. # begin initialize for create_op_builder()
  249. # put all valid class name <--> class type mapping into class_dict
  250. op_builder_dir = self.op_builder_dir()
  251. op_builder_module = importlib.import_module(op_builder_dir)
  252. op_builder_absolute_path = os.path.dirname(op_builder_module.__file__)
  253. for _, module_name, _ in pkgutil.iter_modules([op_builder_absolute_path]):
  254. # avoid self references,
  255. # skip sub_directories which contains ops for other backend(cpu, npu, etc.).
  256. if module_name != 'all_ops' and module_name != 'builder' and not os.path.isdir(
  257. os.path.join(op_builder_absolute_path, module_name)):
  258. module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
  259. for member_name in module.__dir__():
  260. if member_name.endswith(
  261. 'Builder'
  262. ) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder" and member_name != "TorchCPUOpBuilder": # avoid abstract classes
  263. if not member_name in self.class_dict:
  264. self.class_dict[member_name] = getattr(module, member_name)
  265. # end initialize for create_op_builder()
  266. # create an instance of op builder and return, name specified by class_name
  267. def create_op_builder(self, class_name):
  268. self._lazy_init_class_dict()
  269. if class_name in self.class_dict:
  270. return self.class_dict[class_name]()
  271. else:
  272. return None
  273. # return an op builder class, name specified by class_name
  274. def get_op_builder(self, class_name):
  275. self._lazy_init_class_dict()
  276. if class_name in self.class_dict:
  277. return self.class_dict[class_name]
  278. else:
  279. return None
  280. def build_extension(self):
  281. from torch.utils.cpp_extension import BuildExtension
  282. return BuildExtension
  283. def export_envs(self):
  284. return ['NCCL']
  285. def visible_devices_envs(self):
  286. return ['CUDA_VISIBLE_DEVICES']
  287. def set_visible_devices_envs(self, current_env, local_accelerator_ids):
  288. for env in self.visible_devices_envs():
  289. current_env[env] = ",".join(map(str, local_accelerator_ids))
  290. def get_compile_backend(self):
  291. return self._compile_backend
  292. def set_compile_backend(self, backend):
  293. supported_backends = torch._dynamo.list_backends(exclude_tags=())
  294. if backend in supported_backends:
  295. self._compile_backend = backend
  296. else:
  297. raise ValueError(
  298. f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")