utils.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. """
  5. Copyright NVIDIA/Megatron
  6. Helper functions and classes from multiple sources.
  7. """
  8. from collections.abc import Iterable
  9. from deepspeed.moe.utils import is_moe_param
  10. import os
  11. import psutil
  12. import gc
  13. from math import sqrt
  14. import torch
  15. from deepspeed import comm as dist
  16. try:
  17. from torch._six import inf
  18. except ModuleNotFoundError:
  19. from torch import inf
  20. from deepspeed.utils import groups, logger
  21. from deepspeed.utils.bwc import (bwc_tensor_model_parallel_rank, bwc_pipeline_parallel_world_size,
  22. bwc_pipeline_parallel_group)
  23. from deepspeed.runtime.constants import PIPE_REPLICATED
  24. from numpy import prod
  25. from deepspeed.accelerator import get_accelerator
  26. from deepspeed.module_inject.policy import transpose
  27. from torch.nn import functional as F
# Module-level shorthands for the active accelerator's reserved-memory queries,
# bound once at import time; used by see_memory_usage() below.
torch_memory_reserved = get_accelerator().memory_reserved
torch_max_memory_reserved = get_accelerator().max_memory_reserved
  30. class DummyOptim():
  31. """
  32. Dummy optimizer presents model parameters as a param group, this is
  33. primarily used to allow ZeRO-3 without an optimizer
  34. """
  35. def __init__(self, params):
  36. self.param_groups = []
  37. self.param_groups.append({'params': params})
# Cache of captured accelerator graphs, keyed by the wrapped function's name.
graph_cache = {}


def graph_process(replay_first_step, func, *args, **kwargs):
    # `func` should only contain operations on the GPU
    # Please ensure that the memory address of the data required by 'func' remains constant
    if func.__name__ not in graph_cache:
        # First call: warm up `func` on a side stream before capture, since
        # capture requires prior execution of the kernels involved.
        cuda_stream = get_accelerator().Stream()
        cuda_stream.wait_stream(get_accelerator().current_stream())
        with get_accelerator().stream(cuda_stream):
            func(*args, **kwargs)
        get_accelerator().current_stream().wait_stream(cuda_stream)
        # Capture `func` into a new accelerator graph and cache it.
        graph_cache[func.__name__] = get_accelerator().create_graph()
        with get_accelerator().capture_to_graph(graph_cache[func.__name__]):
            func(*args, **kwargs)
        # Optionally replay immediately after capture (capture itself does not
        # execute the work on some backends).
        if replay_first_step:
            get_accelerator().replay_graph(graph_cache[func.__name__])
    else:
        # Subsequent calls: replay the cached graph instead of re-running `func`.
        get_accelerator().replay_graph(graph_cache[func.__name__])
  55. def noop_decorator(func):
  56. return func
  57. class noop_context(object):
  58. def __init__(self):
  59. pass
  60. def __enter__(self):
  61. pass
  62. def __exit__(self, exc_type, exc_val, exc_tb):
  63. pass
  64. def ensure_directory_exists(filename):
  65. """Create the directory path to ``filename`` if it does not already exist.
  66. Args:
  67. filename (str): A file path.
  68. """
  69. dirname = os.path.dirname(filename)
  70. os.makedirs(dirname, exist_ok=True)
  71. def set_random_seed(seed):
  72. """Set the random seed for common PRNGs used during training: random, numpy, and torch.
  73. Args:
  74. seed (int): the seed to use
  75. """
  76. import numpy
  77. import random
  78. random.seed(seed)
  79. numpy.random.seed(seed)
  80. torch.manual_seed(seed)
  81. def is_model_parallel_parameter(p) -> bool:
  82. if hasattr(p, 'model_parallel') and p.model_parallel:
  83. return True
  84. if hasattr(p, 'tensor_model_parallel') and p.tensor_model_parallel:
  85. return True
  86. return False
  87. def copy_to_device(item, device, criterion_func):
  88. """
  89. Return a copy of tensor on specified device.
  90. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
  91. Parameters:
  92. item: tensor to copy or (possibly nested) container of tensors to copy.
  93. device: target device
  94. criterion_func: Function to restrict copy operation to items meet criterion
  95. Returns:
  96. None
  97. """
  98. if criterion_func(item):
  99. return item.to(device)
  100. elif isinstance(item, list):
  101. return [copy_to_device(v, device, criterion_func) for v in item]
  102. elif isinstance(item, tuple):
  103. return tuple([copy_to_device(v, device, criterion_func) for v in item])
  104. elif isinstance(item, dict):
  105. return {k: copy_to_device(v, device, criterion_func) for k, v in item.items()}
  106. else:
  107. return item
  108. def move_to_device(item, device, criterion_func):
  109. """
  110. Move tensor on to specified device by changing the storage.
  111. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
  112. Parameters:
  113. item: tensor to move or (possibly nested) container of tensors to move.
  114. device: target device
  115. criterion_func: Function to restrict move operation to items meet criterion
  116. Returns:
  117. None
  118. """
  119. if criterion_func(item):
  120. device_copy = item.to(device)
  121. item.data = device_copy.data
  122. return item
  123. elif isinstance(item, list):
  124. return [move_to_device(v, device, criterion_func) for v in item]
  125. elif isinstance(item, tuple):
  126. return tuple([move_to_device(v, device, criterion_func) for v in item])
  127. elif isinstance(item, dict):
  128. return {k: move_to_device(v, device, criterion_func) for k, v in item.items()}
  129. else:
  130. return item
def get_norm_with_moe_layers_fast(all_groups_norm, group):
    """Average ``all_groups_norm`` across the ranks of ``group``.

    Args:
        all_groups_norm (float): this rank's aggregated gradient norm.
        group: process group to average over.
    Returns:
        float: the averaged norm, identical on every rank of ``group``.
    """
    # This implementation standardizes the grad_norm across ranks. A more precise implementation can be found in 'get_norm_with_moe_layers'.
    # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
    # Pre-divide by world size so the SUM all-reduce produces the mean.
    scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=group))
    scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device_name(), dtype=torch.float)
    dist.all_reduce(scaled_norm_tensor, group=group)
    all_groups_norm = scaled_norm_tensor.item()
    #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
    return all_groups_norm
class CheckOverflow(object):
    '''Checks for overflow in gradient across parallel process'''

    def __init__(self, param_groups=None, mpu=None, zero_reduce_scatter=False, deepspeed=None):
        # mpu: model-parallel utility object (provides process groups); may be None.
        self.mpu = mpu
        # Flat list of params when groups were given; None signals "caller must
        # pass param_groups to check()".
        self.params = [] if param_groups else None
        self.zero_reduce_scatter = zero_reduce_scatter
        self.deepspeed = deepspeed
        self.has_moe_params = False
        if param_groups:
            for group in param_groups:
                for param in group:
                    self.params.append(param)
                    if is_moe_param(param):
                        self.has_moe_params = True

    def check_using_norm(self, norm_group, reduce_overflow=True):
        """Detect overflow from precomputed norms: a -1 entry in ``norm_group``
        marks an overflow; the flag is max-reduced across parallel groups."""
        # TODO: I don't think reduce_overflow is needed if mpu is None
        overflow = -1 in norm_group
        overflow_gpu = get_accelerator().FloatTensor([overflow])
        if self.has_moe_params:
            # In this case, we need to do an all_reduce across
            # the expert_parallel_group, so that if there was
            # an overflow due to expert weights, we detect it
            # Only need to check groups.get_largest_expert_parallel_group()
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=groups._get_max_expert_parallel_group())
        if self.mpu is not None:
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group())
        elif reduce_overflow:
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX)
            dist.barrier()
        overflow = overflow_gpu[0].item()
        return bool(overflow)

    def check(self, param_groups=None):
        """Check the stored params (or ``param_groups`` if given) for overflow."""
        params = []
        has_moe_params = False
        if param_groups is None:
            # Fall back to the params captured at construction time.
            params = self.params
            has_moe_params = self.has_moe_params
        else:
            assert param_groups is not None, \
                "self.params and param_groups both cannot be none"
            for group in param_groups:
                for param in group:
                    params.append(param)
                    if is_moe_param(param):
                        has_moe_params = True
        return self.has_overflow(params, has_moe_params=has_moe_params)

    # `params` is a list / generator of torch.Variable
    def has_overflow_serial(self, params):
        """Local (single-process) scan: True if any grad holds inf/NaN."""
        for i, p in enumerate(params):
            if p.grad is not None and self._has_inf_or_nan(p.grad.data, i):
                return True
        return False

    def has_overflow(self, params, has_moe_params=None):
        """Local overflow scan followed by max-reduction across the relevant
        parallel groups so every rank agrees on the result."""
        if has_moe_params is None:
            has_moe_params = self.has_moe_params
        overflow = self.has_overflow_serial(params)
        # Since each model parallel GPU carries only part of the model,
        # make sure overflow flag is synced across all the model parallel GPUs
        overflow_gpu = get_accelerator().ByteTensor([overflow])
        # deepspeed.comm.all_reduce(overflow_gpu,
        #                           op=deepspeed.comm.ReduceOp.MAX,
        #                           group=mpu.get_model_parallel_group())
        if has_moe_params:
            # All reduce this across expert_parallel_group, so that if an expert
            # overflows, we detect it here
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=groups._get_max_expert_parallel_group())
        if self.zero_reduce_scatter:
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=dist.get_world_group())
        elif self.mpu is not None:
            if self.deepspeed is not None:
                # Pipeline engines expose pipeline_enable_backward_allreduce;
                # if backward allreduce is disabled, sync across data parallel too.
                using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
                if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce is False) or (
                        not using_pipeline and self.deepspeed.enable_backward_allreduce is False):
                    dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_data_parallel_group())
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group())
        elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False:
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=dist.get_world_group())
        overflow = overflow_gpu[0].item()
        return bool(overflow)

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x, i):
        """Return True if tensor ``x`` contains inf or NaN (``i`` is only an
        index used by callers for diagnostics)."""
        try:
            # if x is half, the .float() incurs an additional deep copy, but it's necessary if
            # Pytorch's .sum() creates a one-element tensor of the same type as x
            # (which is true for some recent version of pytorch).
            cpu_sum = float(x.float().sum())
            # More efficient version that can be used if .sum() returns a Python scalar
            # cpu_sum = float(x.sum())
        except RuntimeError as instance:
            # We want to check if inst is actually an overflow exception.
            # RuntimeError could come from a different error.
            # If so, we still want the exception to propagate.
            if "value cannot be converted" not in instance.args[0]:
                raise
            return True
        else:
            # A sum containing inf stays inf; NaN fails self-comparison.
            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
                return True
            return False
def _handle_overflow(cpu_sum, x, i):
    """Debug helper: on global rank 0, log the flat index of the first
    non-finite element of tensor ``x`` (the ``i``-th tensor inspected).

    Args:
        cpu_sum (float): the (non-finite) sum that triggered the report.
        x (torch.Tensor): tensor containing the overflow.
        i (int): caller's index of ``x``, included in the log line.
    """
    import math
    rank = dist.get_rank()
    if rank == 0:
        t_i = -1
        # Linear scan for the first inf/NaN; -1 if none found.
        for v_i, v in enumerate(x.data.contiguous().view(-1)):
            if not math.isfinite(float(v)):
                t_i = v_i
                break
        logger.info(f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}")
  250. def get_global_norm(norm_list):
  251. """ Compute total from a list of norms
  252. """
  253. total_norm = 0.0
  254. for norm in norm_list:
  255. total_norm += norm**2.0
  256. # logger.info(f'norm_list = {norm_list} global = {sqrt(total_norm)}')
  257. return sqrt(total_norm)
def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
    """Clips gradient norm of an iterable of parameters.
    This has been adapted from Nvidia megatron. We add norm averaging
    to consider MoE params when calculating norm as they will result
    in different norms across different ranks.
    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. Note that
    the gradients are modified in place.
    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
    Returns:
        Total norm of the parameters (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    # Only parameters that actually received gradients participate.
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    all_norms = []
    if norm_type == inf:
        for p in parameters:
            all_norms.append(p.grad.data.abs().max().float())
        total_norm = torch.stack(all_norms).max()
        total_norm = total_norm.to(get_accelerator().current_device_name())
        # Take max across all GPUs.
        if mpu is not None:
            dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
    else:
        total_norm = 0
        for p in parameters:
            if mpu is not None:
                # Count shared (non-model-parallel) params only on TP rank 0 to
                # avoid double-counting replicas across the model-parallel group.
                if (mpu.get_model_parallel_rank() == 0) or is_model_parallel_parameter(p):
                    param_norm = p.grad.data.detach().float().norm(norm_type)
                    all_norms.append(param_norm)
            else:
                param_norm = p.grad.data.detach().float().norm(norm_type)
                all_norms.append(param_norm)
        if len(all_norms) > 0:
            total_norm = torch.stack(all_norms).square().sum().float()
        else:
            total_norm = get_accelerator().FloatTensor([0.0])
        total_norm = total_norm.to(get_accelerator().current_device_name())
        # Sum across all model parallel GPUs.
        if mpu is not None:
            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
        total_norm = total_norm.pow(1. / norm_type)
    # Need to average total_norm across different GPUs due to the presence of moe params
    pg = groups._get_data_parallel_group()
    scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg))
    scaled_norm_tensor = scaled_norm
    dist.all_reduce(scaled_norm_tensor, group=pg)
    total_norm = scaled_norm_tensor
    total_norm = total_norm.to(parameters[0].device)
    # Scale grads down by clip_coef = max_norm / (norm + eps), capped at 1.0
    # so gradients are never scaled up.
    max_norm = torch.tensor([float(max_norm)], device=total_norm.device)
    clip_coef = max_norm / (total_norm + 1e-6)
    tmp_tensor = torch.tensor([1.0], device=clip_coef.device)
    clip_coef = torch.min(tmp_tensor, clip_coef)
    for p in parameters:
        p.grad.data.mul_(clip_coef)
    return total_norm
def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=None):
    """Get grad norm of an iterable of parameters.
    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. Note that
    the gradients are modified in place. Taken from Nvidia Megatron.
    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        grad_norm_mask (List[Tensor]): A list of Tensor, where
            each Tensor is a 2D Tensor containing ranges of [start_index, end_index].
    Returns:
        Total norm of the parameters (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
        # Take max across all GPUs.
        if mpu is not None:
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()
    else:
        total_norm = 0.
        for idx, p in enumerate(parameters):
            # Use grad_norm_mask to avoid redundant computation of flattened gradient norm
            if grad_norm_mask is not None and len(grad_norm_mask[idx]) > 0:
                # A loop-free implementation to create a mask tensor based on a range list
                # which is logically equivalent to the following implementation.
                # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool)
                # # for mask_idx in grad_norm_mask[idx]:
                # #     mask_tensor_[mask_idx[0]:mask_idx[1]] = True
                # Scatter +1 at every range start and -1 at every range end, then
                # a cumulative sum marks all positions inside any [start, end) range.
                cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device_name(),
                                             dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1)
                mask_tensor = torch.zeros(p.shape[0] + 1,
                                          device=get_accelerator().current_device_name(),
                                          dtype=p.dtype)
                mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1),
                                                   cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]
                # Masked positions are zeroed before taking the norm.
                param_norm = torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type)
            else:
                param_norm = p.grad.data.float().norm(norm_type)
            total_norm += param_norm.item()**norm_type
        # Sum across all model parallel GPUs.
        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
        if mpu is not None:
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()**(1. / norm_type)
    # Non-finite norm (overflow) is signalled to callers as -1.
    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
        total_norm = -1
    return total_norm
def get_grad_zeros(parameters, mpu=None):
    """Compute the number of grads with zero values.
    This is adapted from get_grad_norm
    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
    Returns:
        Total number of params with zero values (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    total_zeros = 0.
    tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
    for p in parameters:
        # Pipeline parallelism may replicate parameters. Avoid multi-counting.
        if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
            continue
        # Filter to avoid over-counting replicated tensors from tensor
        # model parallelism
        if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
            continue
        # Zeros = total element count minus the nonzero count.
        count_zeros = p.grad.numel() - torch.count_nonzero(p.grad)
        total_zeros += count_zeros.item()
    # Sum across all model parallel GPUs.
    total_zeros_cuda = get_accelerator().FloatTensor([float(total_zeros)])
    if mpu is not None:
        dist.all_reduce(total_zeros_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
    total_zeros = total_zeros_cuda[0].item()
    return total_zeros
def get_weight_norm(parameters, norm_type=2, mpu=None):
    """Get norm of an iterable of parameters.
    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. Note that
    the gradients are modified in place. Taken from Nvidia Megatron.
    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
    Returns:
        Total norm of the parameters (viewed as a single vector).
        -1 if the norm value is NaN or Inf.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(p.data.abs().max() for p in parameters)
        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
        # Take max across all GPUs.
        if mpu is not None:
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()
    else:
        total_norm = 0.
        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
        for p in parameters:
            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
            if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
                continue
            # Filter to avoid over-counting replicated tensors from tensor
            # model parallelism
            if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
                continue
            param_norm = p.data.float().norm(norm_type)
            total_norm += param_norm**norm_type
        # Sum across all model parallel GPUs.
        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
        if mpu is not None:
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()**(1. / norm_type)
    # Non-finite results (NaN/Inf) are reported as -1 per the docstring.
    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
        total_norm = -1
    return total_norm
  451. def prefix_sum_inc(weights):
  452. """ Compute an inclusive prefix sum.
  453. Example:
  454. >>> prefix_sum_inc([3,4,5])
  455. [3, 7, 12]
  456. """
  457. weights_ = [w for w in weights]
  458. for x in range(1, len(weights_)):
  459. weights_[x] += weights_[x - 1]
  460. return weights_
  461. def partition_uniform(num_items, num_parts):
  462. import numpy
  463. parts = [0] * (num_parts + 1)
  464. # First check for the trivial edge case
  465. if num_items <= num_parts:
  466. for p in range(num_parts + 1):
  467. parts[p] = min(p, num_items)
  468. return parts
  469. chunksize = num_items // num_parts
  470. residual = num_items - (chunksize * num_parts)
  471. parts = numpy.arange(0, (num_parts + 1) * chunksize, chunksize)
  472. for i in range(residual):
  473. parts[i + 1:] += 1
  474. parts = parts.tolist()
  475. return parts
def partition_balanced(weights, num_parts):
    """
    use dynamic programming solve `The Linear Partition Problem`.
    see https://www8.cs.umu.se/kurser/TDBAfl/VT06/algorithms/BOOK/BOOK2/NODE45.HTM

    Args:
        weights: per-item weights to partition into contiguous ranges.
        num_parts (int): number of partitions.
    Returns:
        list[int]: num_parts + 1 boundary indices, starting at 0 and ending at
        len(weights); chosen to minimize (max part sum - min part sum).
    """
    import numpy as np
    n = len(weights)
    m = num_parts
    # Fewer items than parts: fall back to the trivial uniform split.
    if n <= m:
        return partition_uniform(n, m)
    # dp_max/dp_min[i, j]: largest/smallest part sum over the best split of the
    # first i items into j parts; dp_cost is their spread (the objective).
    dp_max = np.full((n + 1, m + 1), np.inf)
    dp_min = np.full((n + 1, m + 1), np.inf)
    dp_cost = np.full((n + 1, m + 1), np.inf)
    # position[i, j]: start index k of the last part in the best (i, j) split.
    position = np.zeros((n + 1, m + 1), dtype=int)
    prefix_sum = np.zeros((n + 1))
    prefix_sum[1:] = np.cumsum(weights)
    dp_max[0, 0] = 0
    dp_cost[0, 0] = 0
    # NOTE(review): dp_min[0, 0] is left at inf; min_sum for the j=1 column is
    # still finite via min(inf, prefix_sum[i]) — confirm this is intentional.
    for i in range(1, n + 1):
        for j in range(1, min(i, m) + 1):
            for k in range(i):
                # Last part is items [k, i); combine with best (k, j-1) split.
                max_sum = max(dp_max[k, j - 1], prefix_sum[i] - prefix_sum[k])
                min_sum = min(dp_min[k, j - 1], prefix_sum[i] - prefix_sum[k])
                cost = max_sum - min_sum
                if dp_cost[i, j] >= cost:
                    dp_cost[i, j] = cost
                    dp_max[i, j] = max_sum
                    dp_min[i, j] = min_sum
                    position[i, j] = k
    # Walk the position table backwards to recover the boundary indices.
    parts = [n]
    for i in reversed(range(1, m + 1)):
        parts.append(position[parts[-1], i])
    parts.reverse()
    return parts
class PartitionedTensor:
    """A tensor split into contiguous 1-D slices across the ranks of ``group``.

    Each rank keeps only its own slice (``local_data``); the full tensor can be
    collectively reassembled with :meth:`full`, and the partitioning layout can
    be serialized/restored via :meth:`to_meta` / :meth:`from_meta`.
    """

    def __init__(self, tensor, group, partition_meta=None):
        # NOTE(review): partition_meta is unused here; reconstruction from
        # metadata goes through the ``from_meta`` classmethod instead.
        super().__init__()
        self.group = group
        self.num_parts = dist.get_world_size(group=self.group)
        self.rank = dist.get_rank(group=self.group)
        self.orig_size = list(tensor.size())
        self.orig_device = tensor.device
        self.local_data, self.partition = self._partition_tensor(tensor)
        # True when numel divides evenly by world size, enabling the
        # all_gather_into_tensor fast path in full().
        self.even_split = tensor.numel() % self.num_parts == 0

    @classmethod
    def from_meta(cls, meta, local_part, group, device=get_accelerator().device_name()):
        """Rebuild a PartitionedTensor from ``to_meta()`` output and this
        rank's local slice ``local_part``."""
        assert meta.dtype == torch.long
        # Construct a throwaway instance to populate group/rank bookkeeping,
        # then overwrite its data and layout from the metadata.
        dummy = torch.ones(dist.get_world_size(group=group))
        part_obj = cls(tensor=dummy, group=group)
        meta = meta.tolist()
        # [N, list0, ..., listN-1]
        part_obj.orig_size = meta[1:(1 + meta[0])]
        meta = meta[1 + meta[0]:]
        part_obj.orig_device = device
        part_obj.local_data = local_part.detach()
        part_obj.group = group
        # Partition is encoded like the rowptr of a CSR matrix:
        # [num_parts, rank, 0, part_1, ..., part_num_parts]
        # TODO: support shuffle between different partition granularities
        assert part_obj.num_parts == meta[0]
        assert part_obj.rank == meta[1]
        part_obj.partition = meta[2:]  # length num_parts+1
        return part_obj

    def _partition_tensor(self, tensor):
        """Return (this rank's flat slice, boundary list) for ``tensor``."""
        partition = partition_uniform(num_items=tensor.numel(), num_parts=self.num_parts)
        start = partition[self.rank]
        length = partition[self.rank + 1] - start
        tensor_part = tensor.detach().contiguous().view(-1).narrow(0, start=start, length=length).clone()
        return tensor_part, partition

    def full(self, device=None):
        """Collectively reassemble the full tensor on ``device`` (defaults to
        the original device). Must be called by all ranks of the group."""
        if device is None:
            device = self.orig_device
        # Allocate the full tensor as a flat buffer.
        full_numel = prod(self.full_size())
        flat_tensor = torch.zeros([full_numel], dtype=self.local_data.dtype, device=device)
        if self.even_split:
            # Collect the full tensor
            dist.all_gather_into_tensor(flat_tensor, self.local_data, group=self.group)
        else:
            # Uneven slices: broadcast each rank's chunk into its slot.
            for part_id in range(self.num_parts):
                part_size = self.partition[part_id + 1] - self.partition[part_id]
                buf = flat_tensor.narrow(0, start=self.partition[part_id], length=part_size)
                if part_id == self.rank:
                    buf.copy_(self.local_data)
                dist.broadcast(buf, part_id, self.group)
        return flat_tensor.view(self.full_size()).clone().detach()

    def to_meta(self):
        """Returns a torch.LongTensor that encodes partitioning information.
        Can be used along with ``data()`` to serialize a ``PartitionedTensor`` for
        communication.
        Returns:
            torch.LongTensor: a tensor encoding the meta-information for the partitioning
        """
        meta = []
        meta.append(len(self.orig_size))
        meta += list(self.orig_size)
        meta.append(self.num_parts)
        meta.append(self.rank)
        meta += self.partition
        return torch.LongTensor(data=meta).to(self.orig_device)

    def data(self):
        # This rank's local slice of the partitioned tensor.
        return self.local_data

    def local_size(self):
        # Size of the local slice.
        return self.local_data.size()

    def full_size(self):
        # Size of the original (unpartitioned) tensor.
        return self.orig_size
# Baselines for memory_status(): deltas are computed against the values
# captured on the previous call.
mem_alloced = 0
mem_cached = 0


def memory_status(msg, print_rank=-1, reset_max=False):
    """Print current/delta/max accelerator memory stats tagged with ``msg``.

    Args:
        msg (str): label included in the printed line.
        print_rank (int): only this rank prints; -1 lets every rank print.
        reset_max (bool): reset the max-memory counters before reading.
    """
    global mem_alloced, mem_cached
    rank = dist.get_rank()
    if print_rank != -1 and rank != print_rank:
        return
    # Drain pending accelerator work so the readings are up to date.
    get_accelerator().synchronize()
    if reset_max:
        get_accelerator().reset_max_memory_cached()
        get_accelerator().reset_max_memory_allocated()
    new_alloced = get_accelerator().memory_allocated()
    new_cached = get_accelerator().memory_cached()
    # Deltas relative to the previous call (module-level state).
    delta_alloced = new_alloced - mem_alloced
    delta_cached = new_cached - mem_cached
    mem_cached = new_cached
    mem_alloced = new_alloced
    max_alloced = get_accelerator().max_memory_allocated()
    max_cached = get_accelerator().max_memory_cached()
    # convert to GB for printing
    new_alloced /= 1024**3
    new_cached /= 1024**3
    delta_alloced /= 1024**3
    delta_cached /= 1024**3
    max_alloced /= 1024**3
    max_cached /= 1024**3
    print(
        f'RANK={rank} MEMSTATS', msg, f'device={get_accelerator().current_device_name()} '
        f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
        f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)')
  612. def get_ma_status():
  613. if dist.is_initialized() and not dist.get_rank() == 0:
  614. return 0
  615. return get_accelerator().memory_allocated()
def empty_cache():
    """Release cached accelerator memory and reset the peak-memory counters."""
    get_accelerator().empty_cache()
    get_accelerator().reset_peak_memory_stats()
  619. def see_memory_usage(message, force=False):
  620. if not force:
  621. return
  622. if dist.is_initialized() and not dist.get_rank() == 0:
  623. return
  624. # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
  625. gc.collect()
  626. # Print message except when distributed but not rank 0
  627. logger.info(message)
  628. logger.info(f"MA {round(get_accelerator().memory_allocated() / (1024 * 1024 * 1024),2 )} GB \
  629. Max_MA {round(get_accelerator().max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \
  630. CA {round(torch_memory_reserved() / (1024 * 1024 * 1024),2)} GB \
  631. Max_CA {round(torch_max_memory_reserved() / (1024 * 1024 * 1024))} GB ")
  632. vm_stats = psutil.virtual_memory()
  633. used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2)
  634. logger.info(f'CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%')
  635. # get the peak memory to report correct data, so reset the counter for the next call
  636. get_accelerator().reset_peak_memory_stats()
  637. def call_to_str(base, *args, **kwargs):
  638. """Construct a string representation of a call.
  639. Args:
  640. base (str): name of the call
  641. args (tuple, optional): args to ``base``
  642. kwargs (dict, optional): kwargs supplied to ``base``
  643. Returns:
  644. str: A string representation of base(*args, **kwargs)
  645. """
  646. name = f'{base}('
  647. if args:
  648. name += ', '.join(repr(arg) for arg in args)
  649. if kwargs:
  650. name += ', '
  651. if kwargs:
  652. name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items())
  653. name += ')'
  654. return name
  655. def get_only_unique_item(items):
  656. item_set = set(items)
  657. if len(item_set) != 1:
  658. raise RuntimeError(f"expected there to be only one unique element in {items}")
  659. unique_item, = item_set
  660. return unique_item
def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None):
    """Get norm of an iterable of tensors.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. Taken from Nvidia Megatron.

    Arguments:
        input_tensors (Iterable[Tensor]): an iterable of Tensors will have norm computed
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        mpu (optional): model-parallel unit used to reduce the norm across the
            model-parallel (or pipeline-parallel) group.
        use_graph (bool): run the per-tensor norm loop through ``graph_process``
            (accelerator graph capture/replay) instead of eagerly.
        moe_ep_group (optional): expert-parallel group to additionally reduce
            MoE gradient norms over.

    Returns:
        Total norm of the tensors (viewed as a single vector); inf/nan norms
        are replaced by -1 (see the masked_fill_ at the end).
    """
    assert isinstance(input_tensors, Iterable), f'expected Iterable type not {type(input_tensors)}'
    assert all([torch.is_tensor(t) for t in input_tensors]), f'expected list of only tensors'

    norm_type = float(norm_type)
    all_norms = []
    if norm_type == inf:
        # Infinity norm: global max of per-tensor absolute maxima.
        for t in input_tensors:
            all_norms.append(t.data.abs().max().float())
        total_norm = torch.stack(all_norms).max()
        device_total_norm = total_norm.to(get_accelerator().current_device_name())
        # Max across model parallel
        if mpu is not None:
            # For MoE grads, max over model parallel only if MoE-TP is enabled
            if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1:
                dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
            # If MoE grads and MoE-TP disabled, max over pipeline parallel
            elif bwc_pipeline_parallel_world_size(mpu) > 1:
                dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=bwc_pipeline_parallel_group(mpu))
        # MoE grads: max across expert parallel group
        if moe_ep_group is not None:
            dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=moe_ep_group)
        total_norm = device_total_norm.to(input_tensors[0].device)
    else:
        # p-norm: accumulate sum(|t|_p ** p) into compute_buffer[0], then take
        # the p-th root after the cross-group reductions.
        # The scalar buffers are cached (and reused across calls) so the same
        # storage can be captured by an accelerator graph; rebuild only when
        # the tensor count changes.
        if 'norm_tensors_compute_buffer' not in graph_cache or len(
                graph_cache['norm_tensors_compute_buffer']) != len(input_tensors):
            graph_cache['norm_tensors_compute_buffer'] = [
                torch.empty([], dtype=torch.float, device=get_accelerator().current_device_name())
                for t in input_tensors
            ]
        compute_buffer = graph_cache['norm_tensors_compute_buffer']

        def _norm_tensors(tensor_list, _compute_buffer, _norm_type):
            # buffer[i] = |t_i|_p ** p; running total folded into buffer[0]
            for i, t in enumerate(tensor_list):
                _compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type)
                if i != 0:
                    _compute_buffer[0].data.add_(_compute_buffer[i].data)

        if use_graph:
            graph_process(False, _norm_tensors, input_tensors, compute_buffer, norm_type)
        else:
            _norm_tensors(input_tensors, compute_buffer, norm_type)

        device_total_norm = compute_buffer[0].float().detach()
        # Sum across model parallel
        if mpu is not None:
            # For MoE grads, sum over model parallel only if MoE-TP is enabled
            if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1:
                dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
            # If MoE grads and MoE-TP disabled, sum over pipeline parallel
            elif bwc_pipeline_parallel_world_size(mpu) > 1:
                dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=bwc_pipeline_parallel_group(mpu))
        # MoE grads: sum across expert parallel group
        if moe_ep_group is not None:
            dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=moe_ep_group)
        total_norm = device_total_norm.to(input_tensors[0].device).pow(1. / norm_type)

    # Replace non-finite norms with the sentinel -1 so callers can detect overflow.
    inf_or_nan = total_norm.isinf().logical_or(total_norm.isnan())
    total_norm.masked_fill_(inf_or_nan, -1)

    return total_norm
def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, mpu=None, eps=1e-6, use_graph=False):
    """Clip list of tensors by global norm.

    Tensors are scaled in place by ``max_norm / (global_norm + eps)`` whenever
    that coefficient is below 1; otherwise they are left untouched.

    Args:
        input_tensors: List of tensors to be clipped
        max_norm (float, optional): the maximum allowed global norm. Defaults to 1.0.
        global_norm (float, optional): Precomputed norm. Defaults to None.
        mpu (optional): model parallelism unit. Defaults to None.
        eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6
        use_graph (bool, optional): run the scaling loop through ``graph_process``
            (accelerator graph capture/replay). Defaults to False.

    Returns:
        float: the global norm
    """
    if global_norm is None:
        global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu, use_graph=use_graph)
    clip_coef = max_norm / (global_norm + eps)
    if clip_coef < 1:
        if use_graph:

            def clip_tensors(_tensor_list, _clip_coef_tensor):
                # in-place scale so the same storages can be graph-captured
                for t in _tensor_list:
                    t.detach().mul_(_clip_coef_tensor)

            # The coefficient lives in a cached device tensor so a captured graph
            # can reuse the same storage; only the value is refreshed per call.
            if 'clip_coef_tensor' not in graph_cache:
                # Alloc memory
                graph_cache['clip_coef_tensor'] = torch.tensor(clip_coef,
                                                               dtype=torch.float32).to(get_accelerator().device_name())
            clip_coef_tensor = graph_cache['clip_coef_tensor']
            clip_coef_tensor.copy_(torch.tensor(clip_coef, dtype=torch.float32))
            graph_process(False, clip_tensors, input_tensors, clip_coef_tensor)
        else:
            for t in input_tensors:
                t.detach().mul_(clip_coef)
    return global_norm
  755. def align_dense_tensors(tensor_list, alignment):
  756. num_elements = sum(t.numel() for t in tensor_list)
  757. remaining = num_elements % alignment
  758. if remaining:
  759. elements_to_add = alignment - remaining
  760. pad_tensor = torch.zeros(elements_to_add, device=tensor_list[0].device, dtype=tensor_list[0].dtype)
  761. padded_tensor_list = tensor_list + [pad_tensor]
  762. else:
  763. padded_tensor_list = tensor_list
  764. return padded_tensor_list
  765. def all_gather_into_tensor_dp_groups(groups_flat, partitioned_param_groups, dp_process_group):
  766. for group_id, (group_flat, partitioned_params) in enumerate(zip(groups_flat, partitioned_param_groups)):
  767. partition_id = dist.get_rank(group=dp_process_group[group_id])
  768. dp_world_size = dist.get_world_size(group=dp_process_group[group_id])
  769. if dp_world_size == 1:
  770. # no groups share optimizer states
  771. # pipeline parallel with bf16 will default call this even if dp size = 1.
  772. continue
  773. dist.all_gather_into_tensor(group_flat, partitioned_params[partition_id], dp_process_group[group_id])
def all_gather_dp_groups(groups_flat, partitioned_param_groups, dp_process_group, start_alignment_factor,
                         allgather_bucket_size):
    """All-gather each group's partitions across its data-parallel group.

    Prefers the fused ``all_gather_into_tensor`` path when the backend supports
    it; otherwise falls back to bucketed sequential ``all_gather`` calls whose
    shard starts are aligned to ``start_alignment_factor`` elements.

    Args:
        groups_flat: flat buffers, one per parameter group, receiving the gather.
        partitioned_param_groups: per-group list of per-rank partition tensors.
        dp_process_group: per-group data-parallel process group.
        start_alignment_factor (int): alignment (in elements) enforced on each
            shard's start offset (nccl/rccl requirement).
        allgather_bucket_size (int): target total element count per all_gather.
    """
    if dist.has_all_gather_into_tensor():
        return all_gather_into_tensor_dp_groups(groups_flat, partitioned_param_groups, dp_process_group)

    for group_id, partitioned_params in enumerate(partitioned_param_groups):
        # Sequential AllGather Best of both worlds
        partition_id = dist.get_rank(group=dp_process_group[group_id])
        dp_world_size = dist.get_world_size(group=dp_process_group[group_id])

        if dp_world_size == 1:
            # no groups share optimizer states
            # pipeline parallel with bf16 will default call this even if dp size = 1.
            continue

        # Number of buckets needed so each all_gather moves ~allgather_bucket_size elements.
        num_shards = max(1, partitioned_params[partition_id].numel() * dp_world_size // allgather_bucket_size)

        shard_size = partitioned_params[partition_id].numel() // num_shards

        # Enforce nccl/rccl alignment of start location of each shard
        shard_size = shard_size - (shard_size % start_alignment_factor)

        num_elements = shard_size

        assert shard_size * num_shards <= partitioned_params[partition_id].numel()

        for shard_id in range(num_shards):
            # The last shard absorbs the remainder left by the alignment rounding.
            if shard_id == (num_shards - 1):
                num_elements = partitioned_params[partition_id].numel() - shard_id * shard_size

            # Views into every rank's partition at the same offset; gathered together.
            shard_list = []
            for dp_id in range(dp_world_size):
                curr_shard = partitioned_params[dp_id].narrow(0, shard_id * shard_size, num_elements).detach()
                shard_list.append(curr_shard)

            dist.all_gather(shard_list, shard_list[partition_id], dp_process_group[group_id])
  800. class TLinear(torch.nn.Linear):
  801. def __init__(self, orig_layer, name=""):
  802. self.name = name
  803. super().__init__(orig_layer.weight.shape[1], orig_layer.weight.shape[0], bias=(orig_layer.bias is not None))
  804. self.weight.data = transpose(orig_layer.weight.data)
  805. self.bias = orig_layer.bias
  806. self._fwd_func = self._fwd_bias_add if self.bias is not None else self._fwd
  807. def _fwd(self, input):
  808. return F.linear(input, self.weight)
  809. def _fwd_bias_add(self, input):
  810. return F.linear(input, self.weight, bias=self.bias)
  811. def forward(self, input):
  812. return self._fwd_func(input)
  813. def get_inactive_params(param_list):
  814. from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
  815. return [param for param in param_list if (hasattr(param, 'ds_id') and \
  816. param.ds_status == ZeroParamStatus.NOT_AVAILABLE)]
def get_norm_with_moe_layers(non_expert_norm, mpu, expert_tensors, norm_type=2):
    """ Compute the global norm with MoE experts

    Combines the pre-computed non-expert norm with a per-expert-group norm
    (each reduced over its expert-parallel group) into one global norm.

    Inputs:
    non_expert_norm (float) : the calculated norm of the non-expert params
    mpu : model-parallel unit forwarded to get_global_norm_of_tensors
    expert_tensors (Dict[ep_name, List[Tensor]): Dictionary of expert group name to list of grad tensors
    norm_type (int): the norm to use

    Returns:
        if norm is (-/+) inf, returns -1
        otherwise the global norm (float)
    """

    def to_tensor(v):
        # NOTE(review): relies on the accelerator's FloatTensor accepting a python
        # float scalar — confirm against the accelerator API for all backends.
        return get_accelerator().FloatTensor(float(v)).detach()

    group_norms = [non_expert_norm]
    for exp_name, tensors in expert_tensors.items():
        group_norm = get_global_norm_of_tensors(input_tensors=tensors,
                                                mpu=mpu,
                                                norm_type=norm_type,
                                                use_graph=False,
                                                moe_ep_group=groups._get_expert_parallel_group(exp_name))
        group_norms.append(group_norm)

    # check if all norms are valid (-1 is the overflow/inf sentinel from
    # get_global_norm_of_tensors)
    group_norms = torch.stack([to_tensor(norm) for norm in group_norms])
    if group_norms.eq(-1).any():
        return -1

    # combine norms
    if norm_type == inf:
        total_norm = group_norms.max().item()
    else:
        # sum of p-th powers, then p-th root
        total_norm = group_norms.pow(norm_type).sum()
        total_norm = total_norm.item()**(1. / norm_type)
        if total_norm == float('inf') or total_norm == -float('inf'):
            total_norm = -1

    return total_norm