utils.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858
  1. '''
  2. Copyright 2019 The Microsoft DeepSpeed Team
  3. Copyright NVIDIA/Megatron
  4. Helper functions and classes from multiple sources.
  5. '''
  6. from deepspeed.moe.utils import is_moe_param, split_params_into_shared_and_expert_params
  7. import os
  8. import psutil
  9. import gc
  10. from math import ceil, sqrt
  11. from math import floor
  12. from bisect import bisect_left, bisect_right
  13. import torch
  14. from torch._six import inf
  15. import torch.distributed as dist
  16. from deepspeed.utils import groups, logger
  17. from numpy import prod
  18. # pt-1.9 deprecations
  19. if hasattr(torch.cuda, "memory_reserved"):
  20. torch_memory_reserved = torch.cuda.memory_reserved
  21. else:
  22. torch_memory_reserved = torch.cuda.memory_allocated
  23. if hasattr(torch.cuda, "max_memory_reserved"):
  24. torch_max_memory_reserved = torch.cuda.max_memory_reserved
  25. else:
  26. torch_max_memory_reserved = torch.cuda.memory_cached
  27. class DummyOptim():
  28. """
  29. Dummy optimizer presents model parameters as a param group, this is
  30. primarily used to allow ZeRO-3 without an optimizer
  31. """
  32. def __init__(self, params):
  33. self.param_groups = []
  34. self.param_groups.append({'params': params})
  35. def noop_decorator(func):
  36. return func
  37. def ensure_directory_exists(filename):
  38. """Create the directory path to ``filename`` if it does not already exist.
  39. Args:
  40. filename (str): A file path.
  41. """
  42. dirname = os.path.dirname(filename)
  43. os.makedirs(dirname, exist_ok=True)
  44. def set_random_seed(seed):
  45. """Set the random seed for common PRNGs used during training: random, numpy, and torch.
  46. Args:
  47. seed (int): the seed to use
  48. """
  49. import numpy
  50. import random
  51. random.seed(seed)
  52. numpy.random.seed(seed)
  53. torch.manual_seed(seed)
  54. def is_model_parallel_parameter(p) -> bool:
  55. return hasattr(p, 'model_parallel') and p.model_parallel
  56. def bwc_tensor_model_parallel_rank(mpu=None):
  57. """Backwards-compatible way of querying the tensor model parallel rank from
  58. an ``mpu`` object.
  59. *Tensor* model parallelism means that tensors are physically split across
  60. processes. This contrasts with *pipeline* model parallelism, in which the
  61. layers are partitioned but tensors left intact.
  62. The API for tensor model parallelism has changed across versions and this
  63. helper provides a best-effort implementation across versions of ``mpu``
  64. objects. The preferred mechanism is
  65. ``mpu.get_tensor_model_parallel_rank()``.
  66. This should "just work" with both Megatron-LM and DeepSpeed's pipeline
  67. parallelism.
  68. Args:
  69. mpu (model parallel unit, optional): The tensor model parallel rank.
  70. If ``mpu=None``, returns 0. Defaults to ``None``.
  71. Returns:
  72. int: the rank
  73. """
  74. if mpu is None:
  75. # No model parallelism in easy :)
  76. return 0
  77. if hasattr(mpu, 'get_tensor_model_parallel_rank'):
  78. # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
  79. return mpu.get_tensor_model_parallel_rank()
  80. elif hasattr(mpu, 'get_slice_parallel_rank'):
  81. # Some DeepSpeed + pipeline parallelism versions
  82. return mpu.get_slice_parallel_rank()
  83. else:
  84. # Deprecated Megatron and DeepSpeed convention
  85. return mpu.get_model_parallel_rank()
  86. def copy_to_device(item, device, criterion_func):
  87. """
  88. Return a copy of tensor on specified device.
  89. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
  90. Parameters:
  91. item: tensor to copy or (possibly nested) container of tensors to copy.
  92. device: target device
  93. criterion_func: Function to restrict copy operation to items meet criterion
  94. Returns:
  95. None
  96. """
  97. if criterion_func(item):
  98. return item.to(device)
  99. elif isinstance(item, list):
  100. return [copy_to_device(v, device, criterion_func) for v in item]
  101. elif isinstance(item, tuple):
  102. return tuple([copy_to_device(v, device, criterion_func) for v in item])
  103. elif isinstance(item, dict):
  104. return {k: copy_to_device(v, device, criterion_func) for k, v in item.items()}
  105. else:
  106. return item
  107. def move_to_device(item, device, criterion_func):
  108. """
  109. Move tensor on to specified device by changing the storage.
  110. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
  111. Parameters:
  112. item: tensor to move or (possibly nested) container of tensors to move.
  113. device: target device
  114. criterion_func: Function to restrict move operation to items meet criterion
  115. Returns:
  116. None
  117. """
  118. if criterion_func(item):
  119. device_copy = item.to(device)
  120. item.data = device_copy.data
  121. return item
  122. elif isinstance(item, list):
  123. return [move_to_device(v, device, criterion_func) for v in item]
  124. elif isinstance(item, tuple):
  125. return tuple([move_to_device(v, device, criterion_func) for v in item])
  126. elif isinstance(item, dict):
  127. return {k: move_to_device(v, device, criterion_func) for k, v in item.items()}
  128. else:
  129. return item
  130. class CheckOverflow(object):
  131. '''Checks for overflow in gradient across parallel process'''
  132. def __init__(self,
  133. param_groups=None,
  134. mpu=None,
  135. zero_reduce_scatter=False,
  136. deepspeed=None):
  137. self.mpu = mpu
  138. self.params = [] if param_groups else None
  139. self.zero_reduce_scatter = zero_reduce_scatter
  140. self.deepspeed = deepspeed
  141. self.has_moe_params = False
  142. if param_groups:
  143. for group in param_groups:
  144. for param in group:
  145. self.params.append(param)
  146. if is_moe_param(param):
  147. self.has_moe_params = True
  148. def check_using_norm(self, norm_group, reduce_overflow=True):
  149. # TODO: I don't think reduce_overflow is needed if mpu is None
  150. overflow = -1 in norm_group
  151. overflow_gpu = torch.cuda.FloatTensor([overflow])
  152. if self.has_moe_params:
  153. # In this case, we need to do an all_reduce across
  154. # the expert_parallel_group, so that if there was
  155. # an overflow due to expert weights, we detect it
  156. dist.all_reduce(overflow_gpu,
  157. op=dist.ReduceOp.MAX,
  158. group=groups.get_expert_parallel_group())
  159. if self.mpu is not None:
  160. torch.distributed.all_reduce(overflow_gpu,
  161. op=torch.distributed.ReduceOp.MAX,
  162. group=self.mpu.get_model_parallel_group())
  163. elif reduce_overflow:
  164. dist.all_reduce(overflow_gpu, op=torch.distributed.ReduceOp.MAX)
  165. dist.barrier()
  166. overflow = overflow_gpu[0].item()
  167. return bool(overflow)
  168. def check(self, param_groups=None):
  169. params = []
  170. has_moe_params = False
  171. if param_groups is None:
  172. params = self.params
  173. has_moe_params = self.has_moe_params
  174. else:
  175. assert param_groups is not None, \
  176. "self.params and param_groups both cannot be none"
  177. for group in param_groups:
  178. for param in group:
  179. params.append(param)
  180. if is_moe_param(param):
  181. has_moe_params = True
  182. return self.has_overflow(params, has_moe_params=has_moe_params)
  183. # `params` is a list / generator of torch.Variable
  184. def has_overflow_serial(self, params):
  185. for i, p in enumerate(params):
  186. if p.grad is not None and self._has_inf_or_nan(p.grad.data, i):
  187. return True
  188. return False
  189. def has_overflow(self, params, has_moe_params=None):
  190. if has_moe_params is None:
  191. has_moe_params = self.has_moe_params
  192. overflow = self.has_overflow_serial(params)
  193. # Since each model parallel GPU carries only part of the model,
  194. # make sure overflow flag is synced across all the model parallel GPUs
  195. overflow_gpu = torch.cuda.ByteTensor([overflow])
  196. # torch.distributed.all_reduce(overflow_gpu,
  197. # op=torch.distributed.ReduceOp.MAX,
  198. # group=mpu.get_model_parallel_group())
  199. if has_moe_params:
  200. # All reduce this across expert_parallel_group, so that if an expert
  201. # overflows, we detect it here
  202. dist.all_reduce(overflow_gpu,
  203. op=dist.ReduceOp.MAX,
  204. group=groups.get_expert_parallel_group())
  205. if self.zero_reduce_scatter:
  206. torch.distributed.all_reduce(overflow_gpu,
  207. op=torch.distributed.ReduceOp.MAX,
  208. group=torch.distributed.group.WORLD)
  209. elif self.mpu is not None:
  210. if self.deepspeed is not None:
  211. using_pipeline = hasattr(self.deepspeed,
  212. 'pipeline_enable_backward_allreduce')
  213. if (using_pipeline
  214. and self.deepspeed.pipeline_enable_backward_allreduce is False
  215. ) or (not using_pipeline
  216. and self.deepspeed.enable_backward_allreduce is False):
  217. torch.distributed.all_reduce(
  218. overflow_gpu,
  219. op=torch.distributed.ReduceOp.MAX,
  220. group=self.mpu.get_data_parallel_group())
  221. torch.distributed.all_reduce(overflow_gpu,
  222. op=torch.distributed.ReduceOp.MAX,
  223. group=self.mpu.get_model_parallel_group())
  224. elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False:
  225. torch.distributed.all_reduce(overflow_gpu,
  226. op=torch.distributed.ReduceOp.MAX,
  227. group=torch.distributed.group.WORLD)
  228. overflow = overflow_gpu[0].item()
  229. return bool(overflow)
  230. # `x` is a torch.Tensor
  231. @staticmethod
  232. def _has_inf_or_nan(x, i):
  233. try:
  234. # if x is half, the .float() incurs an additional deep copy, but it's necessary if
  235. # Pytorch's .sum() creates a one-element tensor of the same type as x
  236. # (which is true for some recent version of pytorch).
  237. cpu_sum = float(x.float().sum())
  238. # More efficient version that can be used if .sum() returns a Python scalar
  239. # cpu_sum = float(x.sum())
  240. except RuntimeError as instance:
  241. # We want to check if inst is actually an overflow exception.
  242. # RuntimeError could come from a different error.
  243. # If so, we still want the exception to propagate.
  244. if "value cannot be converted" not in instance.args[0]:
  245. raise
  246. return True
  247. else:
  248. if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
  249. return True
  250. return False
  251. def _handle_overflow(cpu_sum, x, i):
  252. import math
  253. rank = torch.distributed.get_rank()
  254. if rank == 0:
  255. t_i = -1
  256. for v_i, v in enumerate(x.data.contiguous().view(-1)):
  257. if not math.isfinite(float(v)):
  258. t_i = v_i
  259. break
  260. logger.info(
  261. f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
  262. )
  263. def get_global_norm(norm_list):
  264. """ Compute total from a list of norms
  265. """
  266. total_norm = 0.0
  267. for norm in norm_list:
  268. total_norm += norm**2.0
  269. return sqrt(total_norm)
  270. def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
  271. """Clips gradient norm of an iterable of parameters.
  272. This has been adapted from Nvidia megatron. We add norm averaging
  273. to consider MoE params when calculating norm as they will result
  274. in different norms across different ranks.
  275. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
  276. added functionality to handle model parallel parameters. Note that
  277. the gradients are modified in place.
  278. Arguments:
  279. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  280. single Tensor that will have gradients normalized
  281. max_norm (float or int): max norm of the gradients
  282. norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
  283. infinity norm.
  284. Returns:
  285. Total norm of the parameters (viewed as a single vector).
  286. """
  287. if isinstance(parameters, torch.Tensor):
  288. parameters = [parameters]
  289. parameters = list(filter(lambda p: p.grad is not None, parameters))
  290. max_norm = float(max_norm)
  291. norm_type = float(norm_type)
  292. if norm_type == inf:
  293. total_norm = max(p.grad.data.abs().max() for p in parameters)
  294. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
  295. # Take max across all GPUs.
  296. if mpu is not None:
  297. torch.distributed.all_reduce(total_norm_cuda,
  298. op=torch.distributed.ReduceOp.MAX,
  299. group=mpu.get_model_parallel_group())
  300. total_norm = total_norm_cuda[0].item()
  301. else:
  302. total_norm = 0
  303. for p in parameters:
  304. if mpu is not None:
  305. if (mpu.get_model_parallel_rank()
  306. == 0) or is_model_parallel_parameter(p):
  307. param_norm = p.grad.data.norm(norm_type)
  308. total_norm += param_norm.item()**norm_type
  309. else:
  310. param_norm = p.grad.data.float().norm(norm_type)
  311. total_norm += param_norm.item()**norm_type
  312. # Sum across all model parallel GPUs.
  313. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
  314. if mpu is not None:
  315. torch.distributed.all_reduce(total_norm_cuda,
  316. op=torch.distributed.ReduceOp.SUM,
  317. group=mpu.get_model_parallel_group())
  318. total_norm = total_norm_cuda[0].item()**(1. / norm_type)
  319. # Need to average total_norm across different GPUs due to the presence of moe params
  320. pg = groups.get_data_parallel_group()
  321. scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg))
  322. scaled_norm_tensor = torch.cuda.FloatTensor([float(scaled_norm)])
  323. dist.all_reduce(scaled_norm_tensor, group=pg)
  324. total_norm = scaled_norm_tensor.item()
  325. clip_coef = max_norm / (total_norm + 1e-6)
  326. if clip_coef < 1:
  327. for p in parameters:
  328. p.grad.data.mul_(clip_coef)
  329. return total_norm
  330. def get_grad_norm(parameters, norm_type=2, mpu=None):
  331. """Get grad norm of an iterable of parameters.
  332. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
  333. added functionality to handle model parallel parameters. Note that
  334. the gradients are modified in place. Taken from Nvidia Megatron.
  335. Arguments:
  336. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  337. single Tensor that will have gradients normalized
  338. max_norm (float or int): max norm of the gradients
  339. norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
  340. infinity norm.
  341. Returns:
  342. Total norm of the parameters (viewed as a single vector).
  343. """
  344. if isinstance(parameters, torch.Tensor):
  345. parameters = [parameters]
  346. parameters = list(filter(lambda p: p.grad is not None, parameters))
  347. norm_type = float(norm_type)
  348. if norm_type == inf:
  349. total_norm = max(p.grad.data.abs().max() for p in parameters)
  350. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
  351. # Take max across all GPUs.
  352. if mpu is not None:
  353. torch.distributed.all_reduce(total_norm_cuda,
  354. op=torch.distributed.ReduceOp.MAX,
  355. group=mpu.get_model_parallel_group())
  356. total_norm = total_norm_cuda[0].item()
  357. else:
  358. total_norm = 0.
  359. tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
  360. for p in parameters:
  361. # Pipeline parallelism may replicate parameters. Avoid multi-counting.
  362. if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated:
  363. continue
  364. # Filter to avoid over-counting replicated tensors from tensor
  365. # model parallelism
  366. if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
  367. continue
  368. param_norm = p.grad.data.float().norm(norm_type)
  369. total_norm += param_norm.item()**norm_type
  370. # Sum across all model parallel GPUs.
  371. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
  372. if mpu is not None:
  373. torch.distributed.all_reduce(total_norm_cuda,
  374. op=torch.distributed.ReduceOp.SUM,
  375. group=mpu.get_model_parallel_group())
  376. total_norm = total_norm_cuda[0].item()**(1. / norm_type)
  377. if total_norm == float(
  378. 'inf') or total_norm == -float('inf') or total_norm != total_norm:
  379. total_norm = -1
  380. return total_norm
  381. def get_grad_zeros(parameters, mpu=None):
  382. """Compute the number of grads with zero values.
  383. This is adapted from get_grad_norm
  384. Arguments:
  385. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  386. single Tensor that will have gradients normalized
  387. Returns:
  388. Total number of params with zero values (viewed as a single vector).
  389. """
  390. if isinstance(parameters, torch.Tensor):
  391. parameters = [parameters]
  392. parameters = list(filter(lambda p: p.grad is not None, parameters))
  393. total_zeros = 0.
  394. tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
  395. for p in parameters:
  396. # Pipeline parallelism may replicate parameters. Avoid multi-counting.
  397. if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated:
  398. continue
  399. # Filter to avoid over-counting replicated tensors from tensor
  400. # model parallelism
  401. if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
  402. continue
  403. count_zeros = p.grad.numel() - torch.count_nonzero(p.grad)
  404. total_zeros += count_zeros.item()
  405. # Sum across all model parallel GPUs.
  406. total_zeros_cuda = torch.cuda.FloatTensor([float(total_zeros)])
  407. if mpu is not None:
  408. torch.distributed.all_reduce(total_zeros_cuda,
  409. op=torch.distributed.ReduceOp.SUM,
  410. group=mpu.get_model_parallel_group())
  411. total_zeros = total_zeros_cuda[0].item()
  412. return total_zeros
  413. def get_weight_norm(parameters, norm_type=2, mpu=None):
  414. """Get norm of an iterable of parameters.
  415. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
  416. added functionality to handle model parallel parameters. Note that
  417. the gradients are modified in place. Taken from Nvidia Megatron.
  418. Arguments:
  419. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  420. single Tensor that will have gradients normalized
  421. max_norm (float or int): max norm of the gradients
  422. norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
  423. infinity norm.
  424. Returns:
  425. Total norm of the parameters (viewed as a single vector).
  426. """
  427. if isinstance(parameters, torch.Tensor):
  428. parameters = [parameters]
  429. norm_type = float(norm_type)
  430. if norm_type == inf:
  431. total_norm = max(p.data.abs().max() for p in parameters)
  432. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
  433. # Take max across all GPUs.
  434. if mpu is not None:
  435. torch.distributed.all_reduce(total_norm_cuda,
  436. op=torch.distributed.ReduceOp.MAX,
  437. group=mpu.get_model_parallel_group())
  438. total_norm = total_norm_cuda[0].item()
  439. else:
  440. total_norm = 0.
  441. tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
  442. for p in parameters:
  443. # Pipeline parallelism may replicate parameters. Avoid multi-counting.
  444. if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated:
  445. continue
  446. # Filter to avoid over-counting replicated tensors from tensor
  447. # model parallelism
  448. if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
  449. continue
  450. param_norm = p.data.float().norm(norm_type)
  451. total_norm += param_norm**norm_type
  452. # Sum across all model parallel GPUs.
  453. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
  454. if mpu is not None:
  455. torch.distributed.all_reduce(total_norm_cuda,
  456. op=torch.distributed.ReduceOp.SUM,
  457. group=mpu.get_model_parallel_group())
  458. total_norm = total_norm_cuda[0].item()**(1. / norm_type)
  459. if total_norm == float(
  460. 'inf') or total_norm == -float('inf') or total_norm != total_norm:
  461. total_norm = -1
  462. return total_norm
  463. def is_model_parallel_parameter(p):
  464. return hasattr(p, 'model_parallel') and p.model_parallel
  465. def prefix_sum_inc(weights):
  466. """ Compute an inclusive prefix sum.
  467. Example:
  468. >>> prefix_sum_inc([3,4,5])
  469. [3, 7, 12]
  470. """
  471. weights_ = [w for w in weights]
  472. for x in range(1, len(weights_)):
  473. weights_[x] += weights_[x - 1]
  474. return weights_
  475. def partition_uniform(num_items, num_parts):
  476. parts = [0] * (num_parts + 1)
  477. # First check for the trivial edge case
  478. if num_items <= num_parts:
  479. for p in range(num_parts + 1):
  480. parts[p] = min(p, num_items)
  481. return parts
  482. chunksize = floor(num_items / num_parts)
  483. for p in range(num_parts):
  484. parts[p] = min(chunksize * p, num_items)
  485. parts[num_parts] = num_items
  486. return parts
  487. def _lprobe(weights, num_parts, bottleneck):
  488. num_items = len(weights)
  489. total_weight = weights[-1]
  490. # initialize partitioning
  491. parts = [0] * (num_parts + 1)
  492. for p in range(1, num_parts + 1):
  493. parts[p] = num_items
  494. bsum = bottleneck # running sum of target weight for pth partition
  495. chunksize = num_items // num_parts
  496. step = chunksize
  497. for p in range(1, num_parts):
  498. # Jump to the next bucket
  499. while (step < num_items) and (weights[step] < bsum):
  500. step += chunksize
  501. # Find the end index of partition p
  502. parts[p] = bisect_left(weights,
  503. bsum,
  504. lo=step - chunksize,
  505. hi=min(step,
  506. num_items))
  507. # Nothing more to partition, return early
  508. if parts[p] == num_items:
  509. # See if the current partition is overweight.
  510. part_size = weights[-1] - weights[parts[p - 1]]
  511. return parts, part_size < bottleneck
  512. # Next partition target
  513. bsum = weights[parts[p] - 1] + bottleneck
  514. return parts, bsum >= total_weight
  515. def _rb_partition_balanced(weights, num_parts, eps):
  516. total_weight = weights[-1]
  517. lower = total_weight / num_parts # best case heaviest partition
  518. upper = total_weight # worst case heaviest partition
  519. # Do a binary search for the best partitioning
  520. while upper > lower + eps:
  521. mid = lower + ((upper - lower) / 2)
  522. parts, success = _lprobe(weights, num_parts, mid)
  523. if success:
  524. upper = mid
  525. else:
  526. lower = mid + eps
  527. return upper
  528. def partition_balanced(weights, num_parts, eps=1e-3):
  529. num_items = len(weights)
  530. # First check for the trivial edge case
  531. if num_items <= num_parts:
  532. return partition_uniform(num_items, num_parts)
  533. weights_ = prefix_sum_inc(weights)
  534. # Find the smallest bottleneck (weight of heaviest partition)
  535. bottleneck = _rb_partition_balanced(weights_, num_parts, eps=eps)
  536. # Now compute that partitioning
  537. parts, success = _lprobe(weights_, num_parts, bottleneck)
  538. assert success
  539. return parts
  540. class PartitionedTensor:
  541. def __init__(self, tensor, group, partition_meta=None):
  542. super().__init__()
  543. self.group = group
  544. self.num_parts = dist.get_world_size(group=self.group)
  545. self.rank = dist.get_rank(group=self.group)
  546. self.orig_size = list(tensor.size())
  547. self.orig_device = tensor.device
  548. self.local_data, self.partition = self._partition_tensor(tensor)
  549. @classmethod
  550. def from_meta(cls, meta, local_part, group, device='cuda'):
  551. assert meta.dtype == torch.long
  552. dummy = torch.ones(dist.get_world_size(group=group))
  553. part_obj = cls(tensor=dummy, group=group)
  554. meta = meta.tolist()
  555. # [N, list0, ..., listN-1]
  556. part_obj.orig_size = meta[1:(1 + meta[0])]
  557. meta = meta[1 + meta[0]:]
  558. part_obj.orig_device = device
  559. part_obj.local_data = local_part.detach()
  560. part_obj.group = group
  561. # Partition is encoded like the rowptr of a CSR matrix:
  562. # [num_parts, rank, 0, part_1, ..., part_num_parts]
  563. # TODO: support shuffle between different partition granularities
  564. assert part_obj.num_parts == meta[0]
  565. assert part_obj.rank == meta[1]
  566. part_obj.partition = meta[2:] # length num_parts+1
  567. return part_obj
  568. def _partition_tensor(self, tensor):
  569. partition = partition_uniform(num_items=tensor.numel(), num_parts=self.num_parts)
  570. start = partition[self.rank]
  571. length = partition[self.rank + 1] - start
  572. tensor_part = tensor.detach().contiguous().view(-1).narrow(
  573. 0,
  574. start=start,
  575. length=length).clone()
  576. return tensor_part, partition
  577. def full(self, device=None):
  578. if device is None:
  579. device = self.orig_device
  580. # Allocate the full tensor as a flat buffer.
  581. full_numel = prod(self.full_size())
  582. flat_tensor = torch.zeros([full_numel],
  583. dtype=self.local_data.dtype,
  584. device=device)
  585. # Prepare all-gather buffer
  586. partition_tensors = []
  587. for part_id in range(self.num_parts):
  588. part_size = self.partition[part_id + 1] - self.partition[part_id]
  589. buf = flat_tensor.narrow(0, start=self.partition[part_id], length=part_size)
  590. if part_id == self.rank:
  591. buf.copy_(self.local_data)
  592. partition_tensors.append(buf)
  593. # Collect the full tensor
  594. dist.all_gather(partition_tensors,
  595. partition_tensors[self.rank],
  596. group=self.group)
  597. for i in range(len(partition_tensors)):
  598. partition_tensors[i].data = torch.zeros(1)
  599. partition_tensors[i] = None
  600. return flat_tensor.view(self.full_size()).clone().detach()
  601. def to_meta(self):
  602. """Returns a torch.LongTensor that encodes partitioning information.
  603. Can be used along with ``data()`` to serialize a ``PartitionedTensor`` for
  604. communication.
  605. Returns:
  606. torch.LongTensor: a tensor encoding the meta-information for the partitioning
  607. """
  608. meta = []
  609. meta.append(len(self.orig_size))
  610. meta += list(self.orig_size)
  611. meta.append(self.num_parts)
  612. meta.append(self.rank)
  613. meta += self.partition
  614. return torch.LongTensor(data=meta).to(self.orig_device)
  615. def data(self):
  616. return self.local_data
  617. def local_size(self):
  618. return self.local_data.size()
  619. def full_size(self):
  620. return self.orig_size
  621. mem_alloced = 0
  622. mem_cached = 0
  623. def memory_status(msg, print_rank=-1, reset_max=False):
  624. global mem_alloced, mem_cached
  625. rank = dist.get_rank()
  626. if print_rank != -1 and rank != print_rank:
  627. return
  628. torch.cuda.synchronize()
  629. if reset_max:
  630. torch.cuda.reset_max_memory_cached()
  631. torch.cuda.reset_max_memory_allocated()
  632. new_alloced = torch.cuda.memory_allocated()
  633. new_cached = torch.cuda.memory_cached()
  634. delta_alloced = new_alloced - mem_alloced
  635. delta_cached = new_cached - mem_cached
  636. mem_cached = new_cached
  637. mem_alloced = new_alloced
  638. max_alloced = torch.cuda.max_memory_allocated()
  639. max_cached = torch.cuda.max_memory_cached()
  640. # convert to GB for printing
  641. new_alloced /= 1024**3
  642. new_cached /= 1024**3
  643. delta_alloced /= 1024**3
  644. delta_cached /= 1024**3
  645. max_alloced /= 1024**3
  646. max_cached /= 1024**3
  647. print(
  648. f'RANK={rank} MEMSTATS',
  649. msg,
  650. f'device={torch.cuda.current_device()} '
  651. f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
  652. f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)'
  653. )
  654. def get_ma_status():
  655. if torch.distributed.is_initialized() and not torch.distributed.get_rank() == 0:
  656. return 0
  657. return torch.cuda.memory_allocated()
  658. def see_memory_usage(message, force=False):
  659. if not force:
  660. return
  661. if torch.distributed.is_initialized() and not torch.distributed.get_rank() == 0:
  662. return
  663. # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
  664. gc.collect()
  665. # Print message except when distributed but not rank 0
  666. logger.info(message)
  667. logger.info(
  668. f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024),2 )} GB \
  669. Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \
  670. CA {round(torch_memory_reserved() / (1024 * 1024 * 1024),2)} GB \
  671. Max_CA {round(torch_max_memory_reserved() / (1024 * 1024 * 1024))} GB ")
  672. vm_stats = psutil.virtual_memory()
  673. used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2)
  674. logger.info(
  675. f'CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%')
  676. # get the peak memory to report correct data, so reset the counter for the next call
  677. if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
  678. torch.cuda.reset_peak_memory_stats()
  679. def call_to_str(base, *args, **kwargs):
  680. """Construct a string representation of a call.
  681. Args:
  682. base (str): name of the call
  683. args (tuple, optional): args to ``base``
  684. kwargs (dict, optional): kwargs supplied to ``base``
  685. Returns:
  686. str: A string representation of base(*args, **kwargs)
  687. """
  688. name = f'{base}('
  689. if args:
  690. name += ', '.join(repr(arg) for arg in args)
  691. if kwargs:
  692. name += ', '
  693. if kwargs:
  694. name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items())
  695. name += ')'
  696. return name