utils.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019
  1. '''
  2. Copyright 2019 The Microsoft DeepSpeed Team
  3. Copyright NVIDIA/Megatron
  4. Helper functions and classes from multiple sources.
  5. '''
  6. from collections.abc import Iterable
  7. from deepspeed.moe.utils import is_moe_param
  8. import os
  9. import psutil
  10. import gc
  11. from math import sqrt
  12. from math import floor
  13. from bisect import bisect_left
  14. import torch
  15. from deepspeed import comm as dist
  16. try:
  17. from torch._six import inf as inf
  18. except ModuleNotFoundError:
  19. from torch import inf as inf
  20. from deepspeed.utils import groups, logger
  21. from deepspeed.runtime.constants import PIPE_REPLICATED
  22. from numpy import prod
  23. from deepspeed.accelerator import get_accelerator
  24. torch_memory_reserved = get_accelerator().memory_reserved
  25. torch_max_memory_reserved = get_accelerator().max_memory_reserved
  26. class DummyOptim():
  27. """
  28. Dummy optimizer presents model parameters as a param group, this is
  29. primarily used to allow ZeRO-3 without an optimizer
  30. """
  31. def __init__(self, params):
  32. self.param_groups = []
  33. self.param_groups.append({'params': params})
  34. def noop_decorator(func):
  35. return func
  36. def ensure_directory_exists(filename):
  37. """Create the directory path to ``filename`` if it does not already exist.
  38. Args:
  39. filename (str): A file path.
  40. """
  41. dirname = os.path.dirname(filename)
  42. os.makedirs(dirname, exist_ok=True)
  43. def set_random_seed(seed):
  44. """Set the random seed for common PRNGs used during training: random, numpy, and torch.
  45. Args:
  46. seed (int): the seed to use
  47. """
  48. import numpy
  49. import random
  50. random.seed(seed)
  51. numpy.random.seed(seed)
  52. torch.manual_seed(seed)
  53. def is_model_parallel_parameter(p) -> bool:
  54. if hasattr(p, 'model_parallel') and p.model_parallel:
  55. return True
  56. if hasattr(p, 'tensor_model_parallel') and p.tensor_model_parallel:
  57. return True
  58. return False
  59. def bwc_tensor_model_parallel_rank(mpu=None):
  60. """Backwards-compatible way of querying the tensor model parallel rank from
  61. an ``mpu`` object.
  62. *Tensor* model parallelism means that tensors are physically split across
  63. processes. This contrasts with *pipeline* model parallelism, in which the
  64. layers are partitioned but tensors left intact.
  65. The API for tensor model parallelism has changed across versions and this
  66. helper provides a best-effort implementation across versions of ``mpu``
  67. objects. The preferred mechanism is
  68. ``mpu.get_tensor_model_parallel_rank()``.
  69. This should "just work" with both Megatron-LM and DeepSpeed's pipeline
  70. parallelism.
  71. Args:
  72. mpu (model parallel unit, optional): The tensor model parallel rank.
  73. If ``mpu=None``, returns 0. Defaults to ``None``.
  74. Returns:
  75. int: the rank
  76. """
  77. if mpu is None:
  78. # No model parallelism in easy :)
  79. return 0
  80. if hasattr(mpu, 'get_tensor_model_parallel_rank'):
  81. # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
  82. return mpu.get_tensor_model_parallel_rank()
  83. elif hasattr(mpu, 'get_slice_parallel_rank'):
  84. # Some DeepSpeed + pipeline parallelism versions
  85. return mpu.get_slice_parallel_rank()
  86. else:
  87. # Deprecated Megatron and DeepSpeed convention
  88. return mpu.get_model_parallel_rank()
  89. def copy_to_device(item, device, criterion_func):
  90. """
  91. Return a copy of tensor on specified device.
  92. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
  93. Parameters:
  94. item: tensor to copy or (possibly nested) container of tensors to copy.
  95. device: target device
  96. criterion_func: Function to restrict copy operation to items meet criterion
  97. Returns:
  98. None
  99. """
  100. if criterion_func(item):
  101. return item.to(device)
  102. elif isinstance(item, list):
  103. return [copy_to_device(v, device, criterion_func) for v in item]
  104. elif isinstance(item, tuple):
  105. return tuple([copy_to_device(v, device, criterion_func) for v in item])
  106. elif isinstance(item, dict):
  107. return {k: copy_to_device(v, device, criterion_func) for k, v in item.items()}
  108. else:
  109. return item
  110. def move_to_device(item, device, criterion_func):
  111. """
  112. Move tensor on to specified device by changing the storage.
  113. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
  114. Parameters:
  115. item: tensor to move or (possibly nested) container of tensors to move.
  116. device: target device
  117. criterion_func: Function to restrict move operation to items meet criterion
  118. Returns:
  119. None
  120. """
  121. if criterion_func(item):
  122. device_copy = item.to(device)
  123. item.data = device_copy.data
  124. return item
  125. elif isinstance(item, list):
  126. return [move_to_device(v, device, criterion_func) for v in item]
  127. elif isinstance(item, tuple):
  128. return tuple([move_to_device(v, device, criterion_func) for v in item])
  129. elif isinstance(item, dict):
  130. return {k: move_to_device(v, device, criterion_func) for k, v in item.items()}
  131. else:
  132. return item
  133. class CheckOverflow(object):
  134. '''Checks for overflow in gradient across parallel process'''
  135. def __init__(self,
  136. param_groups=None,
  137. mpu=None,
  138. zero_reduce_scatter=False,
  139. deepspeed=None):
  140. self.mpu = mpu
  141. self.params = [] if param_groups else None
  142. self.zero_reduce_scatter = zero_reduce_scatter
  143. self.deepspeed = deepspeed
  144. self.has_moe_params = False
  145. if param_groups:
  146. for group in param_groups:
  147. for param in group:
  148. self.params.append(param)
  149. if is_moe_param(param):
  150. self.has_moe_params = True
  151. def check_using_norm(self, norm_group, reduce_overflow=True):
  152. # TODO: I don't think reduce_overflow is needed if mpu is None
  153. overflow = -1 in norm_group
  154. overflow_gpu = get_accelerator().FloatTensor([overflow])
  155. if self.has_moe_params:
  156. # In this case, we need to do an all_reduce across
  157. # the expert_parallel_group, so that if there was
  158. # an overflow due to expert weights, we detect it
  159. # Only need to check groups.get_largest_expert_parallel_group()
  160. dist.all_reduce(overflow_gpu,
  161. op=dist.ReduceOp.MAX,
  162. group=groups._get_max_expert_parallel_group())
  163. if self.mpu is not None:
  164. dist.all_reduce(overflow_gpu,
  165. op=dist.ReduceOp.MAX,
  166. group=self.mpu.get_model_parallel_group())
  167. elif reduce_overflow:
  168. dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX)
  169. dist.barrier()
  170. overflow = overflow_gpu[0].item()
  171. return bool(overflow)
  172. def check(self, param_groups=None):
  173. params = []
  174. has_moe_params = False
  175. if param_groups is None:
  176. params = self.params
  177. has_moe_params = self.has_moe_params
  178. else:
  179. assert param_groups is not None, \
  180. "self.params and param_groups both cannot be none"
  181. for group in param_groups:
  182. for param in group:
  183. params.append(param)
  184. if is_moe_param(param):
  185. has_moe_params = True
  186. return self.has_overflow(params, has_moe_params=has_moe_params)
  187. # `params` is a list / generator of torch.Variable
  188. def has_overflow_serial(self, params):
  189. for i, p in enumerate(params):
  190. if p.grad is not None and self._has_inf_or_nan(p.grad.data, i):
  191. return True
  192. return False
  193. def has_overflow(self, params, has_moe_params=None):
  194. if has_moe_params is None:
  195. has_moe_params = self.has_moe_params
  196. overflow = self.has_overflow_serial(params)
  197. # Since each model parallel GPU carries only part of the model,
  198. # make sure overflow flag is synced across all the model parallel GPUs
  199. overflow_gpu = get_accelerator().ByteTensor([overflow])
  200. # deepspeeed.comm.all_reduce(overflow_gpu,
  201. # op=deepspeed.comm.ReduceOp.MAX,
  202. # group=mpu.get_model_parallel_group())
  203. if has_moe_params:
  204. # All reduce this across expert_parallel_group, so that if an expert
  205. # overflows, we detect it here
  206. dist.all_reduce(overflow_gpu,
  207. op=dist.ReduceOp.MAX,
  208. group=groups._get_max_expert_parallel_group())
  209. if self.zero_reduce_scatter:
  210. dist.all_reduce(overflow_gpu,
  211. op=dist.ReduceOp.MAX,
  212. group=dist.get_world_group())
  213. elif self.mpu is not None:
  214. if self.deepspeed is not None:
  215. using_pipeline = hasattr(self.deepspeed,
  216. 'pipeline_enable_backward_allreduce')
  217. if (using_pipeline
  218. and self.deepspeed.pipeline_enable_backward_allreduce is False
  219. ) or (not using_pipeline
  220. and self.deepspeed.enable_backward_allreduce is False):
  221. dist.all_reduce(overflow_gpu,
  222. op=dist.ReduceOp.MAX,
  223. group=self.mpu.get_data_parallel_group())
  224. dist.all_reduce(overflow_gpu,
  225. op=dist.ReduceOp.MAX,
  226. group=self.mpu.get_model_parallel_group())
  227. elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False:
  228. dist.all_reduce(overflow_gpu,
  229. op=dist.ReduceOp.MAX,
  230. group=dist.get_world_group())
  231. overflow = overflow_gpu[0].item()
  232. return bool(overflow)
  233. # `x` is a torch.Tensor
  234. @staticmethod
  235. def _has_inf_or_nan(x, i):
  236. try:
  237. # if x is half, the .float() incurs an additional deep copy, but it's necessary if
  238. # Pytorch's .sum() creates a one-element tensor of the same type as x
  239. # (which is true for some recent version of pytorch).
  240. cpu_sum = float(x.float().sum())
  241. # More efficient version that can be used if .sum() returns a Python scalar
  242. # cpu_sum = float(x.sum())
  243. except RuntimeError as instance:
  244. # We want to check if inst is actually an overflow exception.
  245. # RuntimeError could come from a different error.
  246. # If so, we still want the exception to propagate.
  247. if "value cannot be converted" not in instance.args[0]:
  248. raise
  249. return True
  250. else:
  251. if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
  252. return True
  253. return False
  254. def _handle_overflow(cpu_sum, x, i):
  255. import math
  256. rank = dist.get_rank()
  257. if rank == 0:
  258. t_i = -1
  259. for v_i, v in enumerate(x.data.contiguous().view(-1)):
  260. if not math.isfinite(float(v)):
  261. t_i = v_i
  262. break
  263. logger.info(
  264. f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
  265. )
  266. def get_global_norm(norm_list):
  267. """ Compute total from a list of norms
  268. """
  269. total_norm = 0.0
  270. for norm in norm_list:
  271. total_norm += norm**2.0
  272. return sqrt(total_norm)
  273. def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
  274. """Clips gradient norm of an iterable of parameters.
  275. This has been adapted from Nvidia megatron. We add norm averaging
  276. to consider MoE params when calculating norm as they will result
  277. in different norms across different ranks.
  278. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
  279. added functionality to handle model parallel parameters. Note that
  280. the gradients are modified in place.
  281. Arguments:
  282. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  283. single Tensor that will have gradients normalized
  284. max_norm (float or int): max norm of the gradients
  285. norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
  286. infinity norm.
  287. Returns:
  288. Total norm of the parameters (viewed as a single vector).
  289. """
  290. if isinstance(parameters, torch.Tensor):
  291. parameters = [parameters]
  292. parameters = list(filter(lambda p: p.grad is not None, parameters))
  293. max_norm = float(max_norm)
  294. norm_type = float(norm_type)
  295. if norm_type == inf:
  296. total_norm = max(p.grad.data.abs().max() for p in parameters)
  297. total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
  298. # Take max across all GPUs.
  299. if mpu is not None:
  300. dist.all_reduce(total_norm_cuda,
  301. op=dist.ReduceOp.MAX,
  302. group=mpu.get_model_parallel_group())
  303. total_norm = total_norm_cuda[0].item()
  304. else:
  305. total_norm = 0
  306. for p in parameters:
  307. if mpu is not None:
  308. if (mpu.get_model_parallel_rank()
  309. == 0) or is_model_parallel_parameter(p):
  310. param_norm = p.grad.data.norm(norm_type)
  311. total_norm += param_norm.item()**norm_type
  312. else:
  313. param_norm = p.grad.data.float().norm(norm_type)
  314. total_norm += param_norm.item()**norm_type
  315. # Sum across all model parallel GPUs.
  316. total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
  317. if mpu is not None:
  318. dist.all_reduce(total_norm_cuda,
  319. op=dist.ReduceOp.SUM,
  320. group=mpu.get_model_parallel_group())
  321. total_norm = total_norm_cuda[0].item()**(1. / norm_type)
  322. # Need to average total_norm across different GPUs due to the presence of moe params
  323. pg = groups._get_data_parallel_group()
  324. scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg))
  325. scaled_norm_tensor = get_accelerator().FloatTensor([float(scaled_norm)])
  326. dist.all_reduce(scaled_norm_tensor, group=pg)
  327. total_norm = scaled_norm_tensor.item()
  328. clip_coef = max_norm / (total_norm + 1e-6)
  329. if clip_coef < 1:
  330. for p in parameters:
  331. p.grad.data.mul_(clip_coef)
  332. return total_norm
  333. def get_grad_norm(parameters, norm_type=2, mpu=None):
  334. """Get grad norm of an iterable of parameters.
  335. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
  336. added functionality to handle model parallel parameters. Note that
  337. the gradients are modified in place. Taken from Nvidia Megatron.
  338. Arguments:
  339. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  340. single Tensor that will have gradients normalized
  341. max_norm (float or int): max norm of the gradients
  342. norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
  343. infinity norm.
  344. Returns:
  345. Total norm of the parameters (viewed as a single vector).
  346. """
  347. if isinstance(parameters, torch.Tensor):
  348. parameters = [parameters]
  349. parameters = list(filter(lambda p: p.grad is not None, parameters))
  350. norm_type = float(norm_type)
  351. if norm_type == inf:
  352. total_norm = max(p.grad.data.abs().max() for p in parameters)
  353. total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
  354. # Take max across all GPUs.
  355. if mpu is not None:
  356. dist.all_reduce(total_norm_cuda,
  357. op=dist.ReduceOp.MAX,
  358. group=mpu.get_model_parallel_group())
  359. total_norm = total_norm_cuda[0].item()
  360. else:
  361. total_norm = 0.
  362. tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
  363. for p in parameters:
  364. # Pipeline parallelism may replicate parameters. Avoid multi-counting.
  365. if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
  366. continue
  367. # Filter to avoid over-counting replicated tensors from tensor
  368. # model parallelism
  369. if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
  370. continue
  371. param_norm = p.grad.data.float().norm(norm_type)
  372. total_norm += param_norm.item()**norm_type
  373. # Sum across all model parallel GPUs.
  374. total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
  375. if mpu is not None:
  376. dist.all_reduce(total_norm_cuda,
  377. op=dist.ReduceOp.SUM,
  378. group=mpu.get_model_parallel_group())
  379. total_norm = total_norm_cuda[0].item()**(1. / norm_type)
  380. if total_norm == float(
  381. 'inf') or total_norm == -float('inf') or total_norm != total_norm:
  382. total_norm = -1
  383. return total_norm
  384. def get_grad_zeros(parameters, mpu=None):
  385. """Compute the number of grads with zero values.
  386. This is adapted from get_grad_norm
  387. Arguments:
  388. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  389. single Tensor that will have gradients normalized
  390. Returns:
  391. Total number of params with zero values (viewed as a single vector).
  392. """
  393. if isinstance(parameters, torch.Tensor):
  394. parameters = [parameters]
  395. parameters = list(filter(lambda p: p.grad is not None, parameters))
  396. total_zeros = 0.
  397. tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
  398. for p in parameters:
  399. # Pipeline parallelism may replicate parameters. Avoid multi-counting.
  400. if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
  401. continue
  402. # Filter to avoid over-counting replicated tensors from tensor
  403. # model parallelism
  404. if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
  405. continue
  406. count_zeros = p.grad.numel() - torch.count_nonzero(p.grad)
  407. total_zeros += count_zeros.item()
  408. # Sum across all model parallel GPUs.
  409. total_zeros_cuda = get_accelerator().FloatTensor([float(total_zeros)])
  410. if mpu is not None:
  411. dist.all_reduce(total_zeros_cuda,
  412. op=dist.ReduceOp.SUM,
  413. group=mpu.get_model_parallel_group())
  414. total_zeros = total_zeros_cuda[0].item()
  415. return total_zeros
  416. def get_weight_norm(parameters, norm_type=2, mpu=None):
  417. """Get norm of an iterable of parameters.
  418. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
  419. added functionality to handle model parallel parameters. Note that
  420. the gradients are modified in place. Taken from Nvidia Megatron.
  421. Arguments:
  422. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  423. single Tensor that will have gradients normalized
  424. max_norm (float or int): max norm of the gradients
  425. norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
  426. infinity norm.
  427. Returns:
  428. Total norm of the parameters (viewed as a single vector).
  429. """
  430. if isinstance(parameters, torch.Tensor):
  431. parameters = [parameters]
  432. norm_type = float(norm_type)
  433. if norm_type == inf:
  434. total_norm = max(p.data.abs().max() for p in parameters)
  435. total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
  436. # Take max across all GPUs.
  437. if mpu is not None:
  438. dist.all_reduce(total_norm_cuda,
  439. op=dist.ReduceOp.MAX,
  440. group=mpu.get_model_parallel_group())
  441. total_norm = total_norm_cuda[0].item()
  442. else:
  443. total_norm = 0.
  444. tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
  445. for p in parameters:
  446. # Pipeline parallelism may replicate parameters. Avoid multi-counting.
  447. if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
  448. continue
  449. # Filter to avoid over-counting replicated tensors from tensor
  450. # model parallelism
  451. if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
  452. continue
  453. param_norm = p.data.float().norm(norm_type)
  454. total_norm += param_norm**norm_type
  455. # Sum across all model parallel GPUs.
  456. total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
  457. if mpu is not None:
  458. dist.all_reduce(total_norm_cuda,
  459. op=dist.ReduceOp.SUM,
  460. group=mpu.get_model_parallel_group())
  461. total_norm = total_norm_cuda[0].item()**(1. / norm_type)
  462. if total_norm == float(
  463. 'inf') or total_norm == -float('inf') or total_norm != total_norm:
  464. total_norm = -1
  465. return total_norm
  466. def prefix_sum_inc(weights):
  467. """ Compute an inclusive prefix sum.
  468. Example:
  469. >>> prefix_sum_inc([3,4,5])
  470. [3, 7, 12]
  471. """
  472. weights_ = [w for w in weights]
  473. for x in range(1, len(weights_)):
  474. weights_[x] += weights_[x - 1]
  475. return weights_
  476. def partition_uniform(num_items, num_parts):
  477. parts = [0] * (num_parts + 1)
  478. # First check for the trivial edge case
  479. if num_items <= num_parts:
  480. for p in range(num_parts + 1):
  481. parts[p] = min(p, num_items)
  482. return parts
  483. chunksize = floor(num_items / num_parts)
  484. for p in range(num_parts):
  485. parts[p] = min(chunksize * p, num_items)
  486. parts[num_parts] = num_items
  487. return parts
  488. def _lprobe(weights, num_parts, bottleneck):
  489. num_items = len(weights)
  490. total_weight = weights[-1]
  491. # initialize partitioning
  492. parts = [0] * (num_parts + 1)
  493. for p in range(1, num_parts + 1):
  494. parts[p] = num_items
  495. bsum = bottleneck # running sum of target weight for pth partition
  496. chunksize = num_items // num_parts
  497. step = chunksize
  498. for p in range(1, num_parts):
  499. # Jump to the next bucket
  500. while (step < num_items) and (weights[step] < bsum):
  501. step += chunksize
  502. # Find the end index of partition p
  503. parts[p] = bisect_left(weights,
  504. bsum,
  505. lo=step - chunksize,
  506. hi=min(step,
  507. num_items))
  508. # Nothing more to partition, return early
  509. if parts[p] == num_items:
  510. # See if the current partition is overweight.
  511. part_size = weights[-1] - weights[parts[p - 1]]
  512. return parts, part_size < bottleneck
  513. # Next partition target
  514. bsum = weights[parts[p] - 1] + bottleneck
  515. return parts, bsum >= total_weight
  516. def _rb_partition_balanced(weights, num_parts, eps):
  517. total_weight = weights[-1]
  518. lower = total_weight / num_parts # best case heaviest partition
  519. upper = total_weight # worst case heaviest partition
  520. # Do a binary search for the best partitioning
  521. while upper > lower + eps:
  522. mid = lower + ((upper - lower) / 2)
  523. parts, success = _lprobe(weights, num_parts, mid)
  524. if success:
  525. upper = mid
  526. else:
  527. lower = mid + eps
  528. return upper
  529. def partition_balanced(weights, num_parts, eps=1e-3):
  530. num_items = len(weights)
  531. # First check for the trivial edge case
  532. if num_items <= num_parts:
  533. return partition_uniform(num_items, num_parts)
  534. weights_ = prefix_sum_inc(weights)
  535. # Find the smallest bottleneck (weight of heaviest partition)
  536. bottleneck = _rb_partition_balanced(weights_, num_parts, eps=eps)
  537. # Now compute that partitioning
  538. parts, success = _lprobe(weights_, num_parts, bottleneck)
  539. assert success
  540. return parts
  541. class PartitionedTensor:
  542. def __init__(self, tensor, group, partition_meta=None):
  543. super().__init__()
  544. self.group = group
  545. self.num_parts = dist.get_world_size(group=self.group)
  546. self.rank = dist.get_rank(group=self.group)
  547. self.orig_size = list(tensor.size())
  548. self.orig_device = tensor.device
  549. self.local_data, self.partition = self._partition_tensor(tensor)
  550. @classmethod
  551. def from_meta(cls, meta, local_part, group, device=get_accelerator().device_name()):
  552. assert meta.dtype == torch.long
  553. dummy = torch.ones(dist.get_world_size(group=group))
  554. part_obj = cls(tensor=dummy, group=group)
  555. meta = meta.tolist()
  556. # [N, list0, ..., listN-1]
  557. part_obj.orig_size = meta[1:(1 + meta[0])]
  558. meta = meta[1 + meta[0]:]
  559. part_obj.orig_device = device
  560. part_obj.local_data = local_part.detach()
  561. part_obj.group = group
  562. # Partition is encoded like the rowptr of a CSR matrix:
  563. # [num_parts, rank, 0, part_1, ..., part_num_parts]
  564. # TODO: support shuffle between different partition granularities
  565. assert part_obj.num_parts == meta[0]
  566. assert part_obj.rank == meta[1]
  567. part_obj.partition = meta[2:] # length num_parts+1
  568. return part_obj
  569. def _partition_tensor(self, tensor):
  570. partition = partition_uniform(num_items=tensor.numel(), num_parts=self.num_parts)
  571. start = partition[self.rank]
  572. length = partition[self.rank + 1] - start
  573. tensor_part = tensor.detach().contiguous().view(-1).narrow(
  574. 0,
  575. start=start,
  576. length=length).clone()
  577. return tensor_part, partition
  578. def full(self, device=None):
  579. if device is None:
  580. device = self.orig_device
  581. # Allocate the full tensor as a flat buffer.
  582. full_numel = prod(self.full_size())
  583. flat_tensor = torch.zeros([full_numel],
  584. dtype=self.local_data.dtype,
  585. device=device)
  586. # Prepare all-gather buffer
  587. partition_tensors = []
  588. for part_id in range(self.num_parts):
  589. part_size = self.partition[part_id + 1] - self.partition[part_id]
  590. buf = flat_tensor.narrow(0, start=self.partition[part_id], length=part_size)
  591. if part_id == self.rank:
  592. buf.copy_(self.local_data)
  593. partition_tensors.append(buf)
  594. # Collect the full tensor
  595. dist.all_gather(partition_tensors,
  596. partition_tensors[self.rank],
  597. group=self.group)
  598. for i in range(len(partition_tensors)):
  599. partition_tensors[i].data = torch.zeros(1)
  600. partition_tensors[i] = None
  601. return flat_tensor.view(self.full_size()).clone().detach()
  602. def to_meta(self):
  603. """Returns a torch.LongTensor that encodes partitioning information.
  604. Can be used along with ``data()`` to serialize a ``PartitionedTensor`` for
  605. communication.
  606. Returns:
  607. torch.LongTensor: a tensor encoding the meta-information for the partitioning
  608. """
  609. meta = []
  610. meta.append(len(self.orig_size))
  611. meta += list(self.orig_size)
  612. meta.append(self.num_parts)
  613. meta.append(self.rank)
  614. meta += self.partition
  615. return torch.LongTensor(data=meta).to(self.orig_device)
  616. def data(self):
  617. return self.local_data
  618. def local_size(self):
  619. return self.local_data.size()
  620. def full_size(self):
  621. return self.orig_size
  622. mem_alloced = 0
  623. mem_cached = 0
  624. def memory_status(msg, print_rank=-1, reset_max=False):
  625. global mem_alloced, mem_cached
  626. rank = dist.get_rank()
  627. if print_rank != -1 and rank != print_rank:
  628. return
  629. get_accelerator().synchronize()
  630. if reset_max:
  631. get_accelerator().reset_max_memory_cached()
  632. get_accelerator().reset_max_memory_allocated()
  633. new_alloced = get_accelerator().memory_allocated()
  634. new_cached = get_accelerator().memory_cached()
  635. delta_alloced = new_alloced - mem_alloced
  636. delta_cached = new_cached - mem_cached
  637. mem_cached = new_cached
  638. mem_alloced = new_alloced
  639. max_alloced = get_accelerator().max_memory_allocated()
  640. max_cached = get_accelerator().max_memory_cached()
  641. # convert to GB for printing
  642. new_alloced /= 1024**3
  643. new_cached /= 1024**3
  644. delta_alloced /= 1024**3
  645. delta_cached /= 1024**3
  646. max_alloced /= 1024**3
  647. max_cached /= 1024**3
  648. print(
  649. f'RANK={rank} MEMSTATS',
  650. msg,
  651. f'device={get_accelerator().current_device_name()} '
  652. f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
  653. f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)'
  654. )
  655. def get_ma_status():
  656. if dist.is_initialized() and not dist.get_rank() == 0:
  657. return 0
  658. return get_accelerator().memory_allocated()
  659. def empty_cache():
  660. get_accelerator().empty_cache()
  661. def see_memory_usage(message, force=False):
  662. if not force:
  663. return
  664. if dist.is_initialized() and not dist.get_rank() == 0:
  665. return
  666. # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
  667. gc.collect()
  668. # Print message except when distributed but not rank 0
  669. logger.info(message)
  670. logger.info(
  671. f"MA {round(get_accelerator().memory_allocated() / (1024 * 1024 * 1024),2 )} GB \
  672. Max_MA {round(get_accelerator().max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \
  673. CA {round(torch_memory_reserved() / (1024 * 1024 * 1024),2)} GB \
  674. Max_CA {round(torch_max_memory_reserved() / (1024 * 1024 * 1024))} GB ")
  675. vm_stats = psutil.virtual_memory()
  676. used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2)
  677. logger.info(
  678. f'CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%')
  679. # get the peak memory to report correct data, so reset the counter for the next call
  680. get_accelerator().reset_peak_memory_stats()
  681. def call_to_str(base, *args, **kwargs):
  682. """Construct a string representation of a call.
  683. Args:
  684. base (str): name of the call
  685. args (tuple, optional): args to ``base``
  686. kwargs (dict, optional): kwargs supplied to ``base``
  687. Returns:
  688. str: A string representation of base(*args, **kwargs)
  689. """
  690. name = f'{base}('
  691. if args:
  692. name += ', '.join(repr(arg) for arg in args)
  693. if kwargs:
  694. name += ', '
  695. if kwargs:
  696. name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items())
  697. name += ')'
  698. return name
  699. def get_only_unique_item(items):
  700. item_set = set(items)
  701. if len(item_set) != 1:
  702. raise RuntimeError(f"expected there to be only one unique element in {items}")
  703. unique_item, = item_set
  704. return unique_item
  705. def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, eps=1e-6):
  706. """Clip the gradient of a list of parameters.
  707. Args:
  708. parameters: List of parameters whose .grad will be clipped.
  709. global_grad_norm (float, optional): Precomputed gradient norm. Defaults to None.
  710. mpu (optional): model parallelism unit. Defaults to None.
  711. eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6
  712. Returns:
  713. float: the global gradient norm
  714. """
  715. if global_grad_norm is None:
  716. global_grad_norm = get_grad_norm(parameters, mpu=mpu)
  717. clip_coef = max_norm / (global_grad_norm + eps)
  718. if clip_coef < 1:
  719. for p in parameters:
  720. p.grad.detach().mul_(clip_coef)
  721. return global_grad_norm
  722. def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None):
  723. """Get norm of an iterable of tensors.
  724. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
  725. added functionality to handle model parallel parameters. Taken from Nvidia Megatron.
  726. Arguments:
  727. input_tensors (Iterable[Tensor]): an iterable of Tensors will have norm computed
  728. norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
  729. infinity norm.
  730. Returns:
  731. Total norm of the tensors (viewed as a single vector).
  732. """
  733. assert isinstance(input_tensors, Iterable), f'expected Iterable type not {type(input_tensors)}'
  734. assert all([torch.is_tensor(t) for t in input_tensors]), f'expected list of only tensors'
  735. norm_type = float(norm_type)
  736. if norm_type == inf:
  737. total_norm = max(t.data.abs().max() for t in input_tensors)
  738. total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
  739. if mpu is not None:
  740. dist.all_reduce(total_norm_cuda,
  741. op=dist.ReduceOp.MAX,
  742. group=mpu.get_model_parallel_group())
  743. total_norm = total_norm_cuda[0].item()
  744. else:
  745. total_norm = sum(
  746. [t.data.float().norm(norm_type).item()**norm_type for t in input_tensors])
  747. total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
  748. if mpu is not None:
  749. dist.all_reduce(total_norm_cuda,
  750. op=dist.ReduceOp.SUM,
  751. group=mpu.get_model_parallel_group())
  752. total_norm = total_norm_cuda[0].item()**(1. / norm_type)
  753. if total_norm == float(
  754. 'inf') or total_norm == -float('inf') or total_norm != total_norm:
  755. total_norm = -1
  756. return total_norm
  757. def clip_tensors_by_global_norm(input_tensors,
  758. max_norm=1.0,
  759. global_norm=None,
  760. mpu=None,
  761. eps=1e-6):
  762. """Clip list of tensors by global norm.
  763. Args:
  764. input_tensors: List of tensors to be clipped
  765. global_norm (float, optional): Precomputed norm. Defaults to None.
  766. mpu (optional): model parallelism unit. Defaults to None.
  767. eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6
  768. Returns:
  769. float: the global norm
  770. """
  771. if global_norm is None:
  772. global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu)
  773. clip_coef = max_norm / (global_norm + eps)
  774. if clip_coef < 1:
  775. for t in input_tensors:
  776. t.detach().mul_(clip_coef)
  777. return global_norm
  778. def align_dense_tensors(tensor_list, alignment):
  779. num_elements = sum(t.numel() for t in tensor_list)
  780. remaining = num_elements % alignment
  781. if remaining:
  782. elements_to_add = alignment - remaining
  783. pad_tensor = torch.zeros(elements_to_add,
  784. device=tensor_list[0].device,
  785. dtype=tensor_list[0].dtype)
  786. padded_tensor_list = tensor_list + [pad_tensor]
  787. else:
  788. padded_tensor_list = tensor_list
  789. return padded_tensor_list
  790. def all_gather_dp_groups(partitioned_param_groups,
  791. dp_process_group,
  792. start_alignment_factor,
  793. allgather_bucket_size):
  794. for group_id, partitioned_params in enumerate(partitioned_param_groups):
  795. # Sequential AllGather Best of both worlds
  796. partition_id = dist.get_rank(group=dp_process_group[group_id])
  797. dp_world_size = dist.get_world_size(group=dp_process_group[group_id])
  798. num_shards = max(
  799. 1,
  800. partitioned_params[partition_id].numel() * dp_world_size //
  801. allgather_bucket_size)
  802. shard_size = partitioned_params[partition_id].numel() // num_shards
  803. # Enforce nccl/rccl alignment of start location of each shard
  804. shard_size = shard_size - (shard_size % start_alignment_factor)
  805. num_elements = shard_size
  806. assert shard_size * num_shards <= partitioned_params[partition_id].numel()
  807. for shard_id in range(num_shards):
  808. if shard_id == (num_shards - 1):
  809. num_elements = partitioned_params[partition_id].numel(
  810. ) - shard_id * shard_size
  811. shard_list = []
  812. for dp_id in range(dp_world_size):
  813. curr_shard = partitioned_params[dp_id].narrow(0,
  814. shard_id * shard_size,
  815. num_elements).detach()
  816. shard_list.append(curr_shard)
  817. dist.all_gather(shard_list,
  818. shard_list[partition_id],
  819. dp_process_group[group_id])