# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Used to partition the activations stored for backward propagation,
thereby reducing memory consumption.
Also implements CPU checkpointing and contiguous memory checkpointing,
which reduce memory consumption and memory fragmentation.
Code for RNG checkpointing taken from NVIDIA Megatron-LM mpu/random.py
b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
"""

# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

import copy
import torch
import contextlib
from deepspeed import comm as dist
import weakref
import mmap
from torch import _C

from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER
from deepspeed.accelerator import get_accelerator

# DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False

# MP parameters
mpu = None
mp_rank = None
mp_size = None
mp_group = None

# Model Parameters
num_layers = None

# Checkpointing buffers
contiguous_data_buffers = []
data_offsets = []

contiguous_size_buffers = []
size_offsets = []

timers = None

# optimization flags
PARTITION_ACTIVATIONS = False
CPU_CHECKPOINT = False
CONTIGUOUS_CHECKPOINTING = False
SYNCHRONIZE = False
PROFILE_TIME = False

# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
transport_stream = None
cuda_device = None


def detach_variable(inputs, device=None):
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            requires_grad = inp.requires_grad

            if device is not None:
                x = inp.to(device=device)
            else:
                x = inp

            x = x.detach()
            x.requires_grad = requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError("Only tuple of tensors is supported. Got unsupported input type: ", type(inputs).__name__)


def _set_cuda_rng_state(new_state, device=-1):
    """Sets the random number generator state of the current GPU.

    Arguments:
        new_state (torch.ByteTensor): The desired state
    This function is adapted from PyTorch repo (torch.cuda.set_rng_state) #ignore-cuda
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for 4+ GPU cases.
    """
    if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
        # older PyTorch
        def cb():
            with get_accelerator().device(device):
                _C._cuda_setRNGState(new_state)
    else:
        # newer PyTorch
        if device == -1:
            device = torch.device(get_accelerator().device_name())
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(get_accelerator().device_name(), device)

        def cb():
            idx = device.index
            if idx is None:
                idx = get_accelerator().current_device()
            default_generator = get_accelerator().default_generator(idx)
            default_generator.set_state(new_state)

    get_accelerator().lazy_call(cb)


class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.

    Using the `add` method, a cuda rng state is initialized based on
    the input `seed` and is assigned to `name`. Later, by forking the
    rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for bookkeeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """Set to the initial state (no tracker)."""
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self):
        """Get rng states. Copy the dictionary so we have direct
        pointers to the states, not just a pointer to the dictionary."""
        return copy.copy(self.states_)

    def set_states(self, states):
        """Set the rng states. For efficiency purposes, we do not check
        the size of the seeds for compatibility."""
        self.states_ = states

    def add(self, name, seed):
        """Track the rng state."""
        # Check seed is not already used.
        if seed in self.seeds_:
            raise Exception('seed {} already exists'.format(seed))
        self.seeds_.add(seed)
        # Check that state is not already defined.
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))
        # Get the current rng state.
        orig_rng_state = get_accelerator().get_rng_state()
        # Set the new state and store it.
        get_accelerator().manual_seed(seed)
        self.states_[name] = get_accelerator().get_rng_state()
        # Reset rng state to what it was.
        _set_cuda_rng_state(orig_rng_state)

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
        """Fork the cuda rng state, perform operations, and exit with
        the original state."""
        # Check if we have added the state
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Store current rng state.
        orig_cuda_rng_state = get_accelerator().get_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Update the current rng state for later use.
            self.states_[name] = get_accelerator().get_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)


# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()


def get_cuda_rng_tracker():
    """Get cuda rng tracker."""
    return _CUDA_RNG_STATE_TRACKER


def model_parallel_cuda_manual_seed(seed):
    """Initialize model parallel cuda seed.

    This function should be called after the model parallel is
    initialized. Also, no get_accelerator().manual_seed should be called
    after this function. Basically, this is a replacement for that
    function.
    Two sets of RNG states are tracked:
        default state: This is for data parallelism and is the same among a
                       set of model parallel GPUs but different across
                       different model parallel groups. This is used for
                       example for dropout in the non-model-parallel regions.
        model-parallel state: This state is different among a set of model
                              parallel GPUs, but the same across data parallel
                              groups. This is used for example for dropout in
                              model parallel regions.
    """
    global mpu

    tp_rank = bwc_tensor_model_parallel_rank(mpu)

    # 2718 is just for fun and any POSITIVE value will work.
    offset = seed + 2718
    model_parallel_seed = offset + tp_rank
    # Data parallel gets the original seed.
    data_parallel_seed = seed

    if dist.get_rank() == 0:
        logger.info(
            '> initializing model parallel cuda seeds on global rank {}, '
            'model parallel rank {}, and data parallel rank {} with '
            'model parallel seed: {} and data parallel seed: {}'.format(dist.get_rank(), tp_rank,
                                                                        mpu.get_data_parallel_rank(),
                                                                        model_parallel_seed, data_parallel_seed), )
    _CUDA_RNG_STATE_TRACKER.reset()
    # Set the default state.
    get_accelerator().manual_seed(data_parallel_seed)
    # and model parallel state.
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)
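

# Illustrative sketch: typical use of the RNG tracker above for dropout inside
# a model-parallel region. Assumes distributed communication and an mpu have
# already been initialized; the seed value, dropout probability, and tensor `x`
# are arbitrary placeholders.
#
#     model_parallel_cuda_manual_seed(1234)
#     ...
#     with get_cuda_rng_tracker().fork():
#         # uses the model-parallel RNG stream, restored on exit
#         x = torch.nn.functional.dropout(x, p=0.1)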


def get_partition_start(item):
    global mp_rank, mp_size, mp_group
    size = item.numel()
    partition_size = size / mp_size
    start = partition_size * mp_rank
    return int(start)


def get_partition_size(item):
    global mp_rank, mp_size, mp_group
    size = item.numel()
    assert size % mp_size == 0, "Cannot partition activation: tensor numel is not divisible by mp size"
    partition_size = size / mp_size
    return int(partition_size)
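

# Worked example for the two helpers above, assuming mp_size == 4 and a tensor
# with 8192 elements: get_partition_size() returns 8192 / 4 = 2048, and
# get_partition_start() returns 2048 * mp_rank (e.g. 4096 on mp_rank == 2).
# Tensors whose numel is not divisible by mp_size trip the assert above.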


def gather_partitioned_activations(tensors, device=None):
    global mp_rank, mp_size, mp_group
    assert len(tensors) % 2 == 0, f'Expected even count of tensors, instead got {len(tensors)}'
    inputs = []
    num_args = int(len(tensors) / 2)
    for i in range(num_args):

        item = tensors[2 * i]
        size = tensors[2 * i + 1]

        if not is_activation_to_checkpoint(item):
            inputs.append(item)
            continue

        # don't need to do all_gather if model parallel is not enabled
        if mp_group is None or mp_size == 1:
            item = item.view(list(size.numpy()))
            inputs.append(item)
            continue

        partition_size = item.numel()
        tensor_size = partition_size * mp_size
        if device is not None:
            flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device)
        else:
            flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=item.device)
        partitions = []
        for i in range(mp_size):
            part_i = flat_tensor.narrow(0, partition_size * i, partition_size)
            if i == mp_rank:
                part_i.copy_(item)
            partitions.append(part_i)
        dist.all_gather(partitions, partitions[mp_rank], group=mp_group)
        input_tensor = flat_tensor.view(list(size.numpy()))
        item.data = input_tensor.data

        inputs.append(item)

    return tuple(inputs)


def extract_tensors(all_objects):
    """
    Separate objects in list/tuple into tensors and non-tensors and create a mapping to enable re-aggregation.
    The order of tensors and non-tensors is preserved in their respective output groups.

    Parameters:
        all_objects (list/tuple): Objects containing tensors and non-tensors to be split.

    Returns:
        tuple: Containing tensors, non-tensors, and bools of whether each position in original list/tuple was a tensor.
    """
    tensor_objects = [v for v in all_objects if torch.is_tensor(v)]
    non_tensor_objects = [v for v in all_objects if not torch.is_tensor(v)]
    tensor_flags = [torch.is_tensor(v) for v in all_objects]
    if type(all_objects) is tuple:
        return tuple(tensor_objects), tuple(non_tensor_objects), tuple(tensor_flags)
    return tensor_objects, non_tensor_objects, tensor_flags


def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags):
    """
    Merge two lists (or tuples) of tensors and non-tensors using a mapping of positions in merged list (or tuple).

    Parameters:
        tensor_objects (list/tuple): Tensors to merge.
        non_tensor_objects (list/tuple): Non-tensors to merge.
        tensor_flags (list/tuple): Indicates whether each position in output is a tensor.

    Returns:
        tuple: Merge of tensors and non-tensors
    """
    merged_objects = []
    tensor_idx = 0
    non_tensor_idx = 0

    real_tensor_flags = None

    # remove the flags that are assigned to the size of the flattened tensors
    if PARTITION_ACTIVATIONS:
        real_tensor_flags = []
        previous_flag = False
        for flag in tensor_flags:
            if previous_flag:
                previous_flag = False
                continue
            previous_flag = flag
            real_tensor_flags.append(flag)
    else:
        real_tensor_flags = tensor_flags

    for is_tensor in real_tensor_flags:
        if is_tensor:
            merged_objects.append(tensor_objects[tensor_idx])
            tensor_idx += 1
        else:
            merged_objects.append(non_tensor_objects[non_tensor_idx])
            non_tensor_idx += 1

    return tuple(merged_objects)
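

# Illustrative sketch of the extract/merge round trip used by the checkpoint
# functions below. Values are arbitrary, and PARTITION_ACTIVATIONS is assumed
# False so tensor_flags are consumed as-is by merge_tensors():
#
#     mixed = (torch.ones(2), "mask_kind", 3, torch.zeros(4))
#     tensors, non_tensors, flags = extract_tensors(all_objects=mixed)
#     restored = merge_tensors(tensor_objects=tensors,
#                              non_tensor_objects=non_tensors,
#                              tensor_flags=flags)
#     # restored has the same ordering and contents as mixed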


def is_activation_to_checkpoint(item):
    """
    Is an activation to be checkpointed
    """
    global mp_size
    return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size


def partition_activations(args, cpu_checkpoint, contiguous_checkpoint):
    global contiguous_data_buffers, data_offsets

    inputs = []
    num_non_fp_tensors = 0

    for arg_index, item in enumerate(args):
        if not is_activation_to_checkpoint(item):
            inputs.append(item)
            num_non_fp_tensors += 1
            continue

        i = arg_index - num_non_fp_tensors
        partition_size = get_partition_size(item)
        partition = item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), partition_size).clone()

        buffer_device = torch.device('cpu') if cpu_checkpoint else partition.device

        if contiguous_checkpoint:
            if i >= len(contiguous_data_buffers):
                tensor_list = [
                    torch.tensor(()).new_empty([partition_size], dtype=partition.dtype, device=buffer_device)
                    for _ in range(num_layers)
                ]
                contiguous_data_buffers.append(tensor_list)
                data_offsets.append(0)
            elif contiguous_data_buffers[i] is None:
                tensor_list = [
                    torch.tensor(()).new_empty([partition_size], dtype=partition.dtype, device=buffer_device)
                    for _ in range(num_layers)
                ]
                contiguous_data_buffers[i] = tensor_list
                data_offsets[i] = 0

            # Because the 'new_empty' returns uninitialized pages,
            # the pages need to be populated during the cudaMemcpy time
            # which increases the data copy time. To avoid this, we
            # pre-populate these pages by simply writing 0 ahead of
            # the actual cudaMemcpy operation time. Due to the
            # previously launched GPU kernels, there is a small
            # window of time here for CPUs to populate pages asynchronously.
            contiguous_data_buffers[i][data_offsets[i]].data[range(
                0, contiguous_data_buffers[i][data_offsets[i]].data.shape[0],
                int(mmap.PAGESIZE / contiguous_data_buffers[i][data_offsets[i]].data.element_size()))] = 0

            contiguous_partition = contiguous_data_buffers[i][data_offsets[i]].data.copy_(partition.data)
            data_offsets[i] = data_offsets[i] + 1
            inputs.append(contiguous_partition)
        else:
            partition = partition.cpu() if CPU_CHECKPOINT else partition
            inputs.append(partition)

    return inputs


def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint):
    global contiguous_size_buffers, size_offsets

    new_args = []
    num_non_fp_tensors = 0

    for arg_index, (arg, inp) in enumerate(zip(args, inputs)):
        size = torch.tensor(arg.size()) if torch.is_tensor(arg) else None

        if not is_activation_to_checkpoint(arg):
            new_args.append(arg)
            new_args.append(size)
            num_non_fp_tensors += 1
            continue

        arg.data = inp.data
        new_args.append(arg)

        i = arg_index - num_non_fp_tensors

        if contiguous_checkpoint:
            numel = size.numel()
            if i >= len(contiguous_size_buffers):
                tmp = torch.tensor(())
                contiguous_size_buffers.append(
                    tmp.new_empty([numel * num_layers], dtype=size.dtype, device=size.device))
                size_offsets.append(0)
            elif contiguous_size_buffers[i] is None:
                tmp = torch.tensor(())
                contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers], dtype=size.dtype, device=size.device)
                size_offsets[i] = 0

            contiguous_size = contiguous_size_buffers[i].narrow(0, size_offsets[i], numel).data.copy_(size.data)
            contiguous_size = contiguous_size.view_as(size)
            size_offsets[i] = size_offsets[i] + numel
            new_args.append(contiguous_size)
        else:
            new_args.append(size)

    return new_args


def get_cpu_activations_for_backward(args, inputs):
    new_args = []
    for i, (arg, inp) in enumerate(zip(args, inputs)):
        if not is_activation_to_checkpoint(arg):
            new_args.append(arg)
            continue

        arg.data = inp.data
        new_args.append(arg)

    return new_args


class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with the
       following main changes:
           1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` #ignore-cuda
           2) the states in the model parallel tracker are also properly
              tracked/set/reset.
           3) Activation partitioning and contiguous memory optimization
           4) CPU checkpointing
           5) Profiling of forward and backward functions
    """

    @staticmethod
    def forward(ctx, run_function, all_outputs, *args):
        global mpu, timers, SYNCHRONIZE, PROFILE_TIME

        def save_args_for_backward(*all_args):
            tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
            ctx.deepspeed_saved_tensors = tensor_args
            ctx.non_tensor_args = non_tensor_args
            ctx.tensor_flags = tensor_flags

        if SYNCHRONIZE:
            get_accelerator().synchronize()

        if timers is None and PROFILE_TIME:
            timers = Timers()

        if PROFILE_TIME:
            timers(FORWARD_GLOBAL_TIMER).start()

        ctx.run_function = run_function
        global num_layers
        global mp_rank, mp_size, mp_group
        global contiguous_data_buffers, contiguous_size_buffers
        global data_offsets, size_offsets
        if mp_rank is None:
            if mpu is not None:
                if hasattr(mpu, 'get_tensor_model_parallel_rank'):
                    mp_rank = mpu.get_tensor_model_parallel_rank()
                    mp_size = mpu.get_tensor_model_parallel_world_size()
                    mp_group = mpu.get_tensor_model_parallel_group()
                else:
                    mp_rank = mpu.get_model_parallel_rank()
                    mp_size = mpu.get_model_parallel_world_size()
                    mp_group = mpu.get_model_parallel_group()
            else:
                mp_rank = 0
                mp_size = 1
                mp_group = None

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

        if cuda_device is None:
            see_memory_usage("First Forward Beginning", force=False)
            if dist.get_rank() == 0:
                logger.info(f"Activation Checkpointing Information")
                logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}")
                logger.info(
                    f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
                logger.info(f"----Synchronization {SYNCHRONIZE}")
                logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")

            cuda_device = get_accelerator().current_device_name()
            transport_stream = get_accelerator().Stream(device=cuda_device)

        if PARTITION_ACTIVATIONS:
            inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
        elif CPU_CHECKPOINT:
            inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint)

        # just in case something funky is happening such as reuse of inputs
        inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint)

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = get_accelerator().get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        see_memory_usage("Before running forward on the layer", force=False)
        # ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*inputs_cuda)

        see_memory_usage("After running forward on the layer", force=False)
        del inputs_cuda

        if PARTITION_ACTIVATIONS:
            new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
            assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
            save_args_for_backward(*new_args)
        elif CPU_CHECKPOINT:
            new_args = get_cpu_activations_for_backward(args, inputs)
            save_args_for_backward(*new_args)
        else:
            save_args_for_backward(*args)

        if PROFILE_TIME:
            timers(FORWARD_GLOBAL_TIMER).stop()
            timers.log([FORWARD_GLOBAL_TIMER])
        if SYNCHRONIZE:
            get_accelerator().synchronize()

        # Tensors returned from forward() may not be differentiable.
        if torch.is_tensor(outputs):
            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
        else:
            non_grad_outputs = [o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()]
        ctx.mark_non_differentiable(*non_grad_outputs)

        if torch.is_tensor(outputs):
            all_outputs += [outputs]
            return outputs
        else:
            all_outputs += outputs
            outputs, _, _ = extract_tensors(all_objects=outputs)
            return tuple(outputs)

    @staticmethod
    def backward(ctx, *grads):
        global timers
        see_memory_usage("In backward", force=False)
        # removing pointers to the contiguous buffer memory
        # so that they can be garbage collected once the checkpoints
        # have been used
        if SYNCHRONIZE:
            get_accelerator().synchronize()
        if PROFILE_TIME:
            timers('backward').start()

        if CONTIGUOUS_CHECKPOINTING:
            global data_offsets, size_offsets
            global contiguous_data_buffers, contiguous_size_buffers

            for buffers in contiguous_data_buffers:
                buffers = []

            # frees up all the pointers to the checkpoints except for the ones
            # stored by save for backward
            contiguous_data_buffers = []
            contiguous_size_buffers = []
            data_offsets = []
            size_offsets = []

        see_memory_usage("In backward checkpointing code", force=False)
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS

        if PARTITION_ACTIVATIONS:
            # with get_accelerator().stream(transport_stream):
            inputs = gather_partitioned_activations(ctx.deepspeed_saved_tensors,
                                                    device=cuda_device if CPU_CHECKPOINT else None)
            detached_inputs = detach_variable(inputs)
        elif CPU_CHECKPOINT:
            inputs = move_to_device(ctx.deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint)
            detached_inputs = detach_variable(inputs)
        else:
            inputs = ctx.deepspeed_saved_tensors
            detached_inputs = detach_variable(inputs)

        # Add non tensor input args
        detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                        non_tensor_objects=ctx.non_tensor_args,
                                        tensor_flags=ctx.tensor_flags)

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = get_accelerator().get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what it used to be before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # if PARTITION_ACTIVATIONS:
        #     current_stream=get_accelerator().current_stream()
        #     current_stream.wait_stream(transport_stream)

        see_memory_usage("In backward checkpointing code before forward", force=False)
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

        see_memory_usage("In backward checkpointing code after forward", force=False)
        # Set the states back to what it was at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )

        # Filter out non tensor outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)

        # Construct arguments to autograd.backward().
        # This is usually just outputs and grads, but forward() can return tensors that
        # are not differentiable.
        output_tensors = []
        grad_tensors = []
        for out, grad in zip(outputs, grads):
            if out.requires_grad:
                output_tensors.append(out)
                grad_tensors.append(grad)

        see_memory_usage("In backward checkpointing code before backward", force=False)

        torch.autograd.backward(output_tensors, grad_tensors)

        # Force clear our stashed tensors to prevent a memory leak in certain scenarios
        ctx.deepspeed_saved_tensors = None
        ctx.non_tensor_args = None
        ctx.tensor_flags = None

        see_memory_usage("After backward checkpointing code after backward", force=False)

        if PROFILE_TIME:
            timers('backward').stop()
            timers.log(['backward'])
        if SYNCHRONIZE:
            get_accelerator().synchronize()
        ret_list = [None, None]  # first None for ctx
        for inp in detached_inputs:
            if torch.is_tensor(inp):
                ret_list.append(inp.grad)
            else:
                ret_list.append(None)

        return tuple(ret_list)


def non_reentrant_checkpoint(function, *args):
    """This function is the union of `torch.utils.checkpoint._checkpoint_without_reentrant` and `CheckpointFunction` in this module.

    It aims to solve the backpropagation error raised when all inputs require no grad.
    * This has already been implemented in PyTorch for a while; the solution is stable most of the time, except for JIT module mode.
    * It can help solve the issue that is otherwise worked around by `deepspeed.runtime.pipe.module.PipelineModule._is_checkpointable`.

    Main modifications compared to the torch implementation:
    1. adapt to the signature of the `checkpoint` function in this module
    2. resolve non-determinism via random state management consistent with the DeepSpeed `CheckpointFunction`
    3. when partitioned or CPU checkpointing is enabled, gather the activations in the unpack hook during backpropagation
    4. wrap all post-backward work in a hook which executes after the backward of all leaf nodes
    5. item 4 is inspired by `torch.autograd.graph.register_multi_grad_hook`, which is only available after 2.0.0
    """
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    deepspeed_saved_tensors = None
    non_tensor_args = None
    tensor_flags = None

    def save_args_for_backward(*all_args):
        """kept to reduce the modification from the original implementation"""
        nonlocal deepspeed_saved_tensors, non_tensor_args, tensor_flags
        tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
        deepspeed_saved_tensors = tensor_args
        non_tensor_args = non_tensor_args
        tensor_flags = tensor_flags

    if SYNCHRONIZE:
        get_accelerator().synchronize()

    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers(FORWARD_GLOBAL_TIMER).start()

    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets
    if mp_rank is None:
        if mpu is not None:
            if hasattr(mpu, 'get_tensor_model_parallel_rank'):
                mp_rank = mpu.get_tensor_model_parallel_rank()
                mp_size = mpu.get_tensor_model_parallel_world_size()
                mp_group = mpu.get_tensor_model_parallel_group()
            else:
                mp_rank = mpu.get_model_parallel_rank()
                mp_size = mpu.get_model_parallel_world_size()
                mp_group = mpu.get_model_parallel_group()
        else:
            mp_rank = 0
            mp_size = 1
            mp_group = None

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    if cuda_device is None:
        see_memory_usage("First Forward Beginning", force=False)
        if dist.get_rank() == 0:
            logger.info(f"Activation Checkpointing Information")
            logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}")
            logger.info(
                f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
            logger.info(f"----Synchronization {SYNCHRONIZE}")
            logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")

        cuda_device = get_accelerator().current_device_name()
        transport_stream = get_accelerator().Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
    elif CPU_CHECKPOINT:
        inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint)

    # just in case something funky is happening such as reuse of inputs
    inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint)

    # Copy the rng states.
    fwd_cpu_rng_state = torch.get_rng_state()
    fwd_cuda_rng_state = get_accelerator().get_rng_state()
    fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    if PARTITION_ACTIVATIONS:
        new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
        assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
        save_args_for_backward(*new_args)
    elif CPU_CHECKPOINT:
        new_args = get_cpu_activations_for_backward(args, inputs)
        save_args_for_backward(*new_args)
    else:
        save_args_for_backward(*args)

    class Holder():
        """the placeholder object stored in place of activations to save memory"""
        pass

    # weakref is used to detect tensor deletion before a whole
    # forward-backward pair has finished
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []

    leaf_tensors = []
    backward_visited_leaf_nodes = 0

    def checkpoint_pack(tensor_from_forward):
        """used to record the activation order in the `weak_holder_list`

        the activation order in the holder list is consistent between the first forward and the recomputing forward.
        * the jit compiled forward will break the order consistency *
        """
        res = Holder()
        weak_holder_list.append(weakref.ref(res))

        # if this is a leaf tensor, save it to trace backward progression;
        # leaf tensors are typically inputs or parameters, which are not
        # activations and carry no memory overhead
        if tensor_from_forward.requires_grad and tensor_from_forward.is_leaf:
            leaf_tensors.append(tensor_from_forward)
        return res

    def checkpoint_unpack(holder_from_backward):
        """retrieve the activations from recomputation"""
        nonlocal deepspeed_saved_tensors, non_tensor_args, tensor_flags

        # if this is the first step of backpropagation, recompute the graph and save
        # all the activations in the same order as `checkpoint_pack` recorded them
        if len(storage) == 0:
            unpack_counter = 0

            def replay_pack(tensor_from_replay):
                """save recomputed activations"""
                nonlocal unpack_counter
                unpack_counter += 1

                if weak_holder_list[unpack_counter - 1]() is None:
                    return

                detached_activations = tensor_from_replay.detach()
                storage[weak_holder_list[unpack_counter - 1]()] = detached_activations

                return

            def replay_unpack(none_value):
                """the recomputation graph should never be backpropagated through"""
                raise RuntimeError("You are calling backwards on a tensor that is never exposed.")

            global timers
            see_memory_usage("In backward", force=False)
            # removing pointers to the contiguous buffer memory
            # so that they can be garbage collected once the checkpoints
            # have been used
            if SYNCHRONIZE:
                get_accelerator().synchronize()
            if PROFILE_TIME:
                timers('backward').start()

            if CONTIGUOUS_CHECKPOINTING:
                global data_offsets, size_offsets
                global contiguous_data_buffers, contiguous_size_buffers

                for buffers in contiguous_data_buffers:
                    buffers = []

                # frees up all the pointers to the checkpoints except for the ones
                # stored by save for backward
                contiguous_data_buffers = []
                contiguous_size_buffers = []
                data_offsets = []
                size_offsets = []

            see_memory_usage("In backward checkpointing code", force=False)
            if not torch.autograd._is_checkpoint_valid():
                raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                   "please use .backward() if possible")

            global cuda_device, transport_stream, PARTITION_ACTIVATIONS

            # gather inputs that were partitioned or offloaded to CPU before the first forward
            if PARTITION_ACTIVATIONS:
                # with get_accelerator().stream(transport_stream):
                inputs = gather_partitioned_activations(deepspeed_saved_tensors,
                                                        device=cuda_device if CPU_CHECKPOINT else None)
                detached_inputs = detach_variable(inputs)
            elif CPU_CHECKPOINT:
                inputs = move_to_device(deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint)
                detached_inputs = detach_variable(inputs)
            else:
                inputs = deepspeed_saved_tensors
                detached_inputs = detach_variable(inputs)

            # Add non tensor input args
            detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                            non_tensor_objects=non_tensor_args,
                                            tensor_flags=tensor_flags)

            # Store the current states.
            bwd_cpu_rng_state = torch.get_rng_state()
            bwd_cuda_rng_state = get_accelerator().get_rng_state()
            bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

            # Set the states to what it used to be before the forward pass.
            torch.set_rng_state(fwd_cpu_rng_state)
            _set_cuda_rng_state(fwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(fwd_cuda_rng_state_tracker)

            see_memory_usage("In backward checkpointing code before forward", force=False)
            with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(replay_pack, replay_unpack):
                _unused = function(*detached_inputs)

            see_memory_usage("In backward checkpointing code after forward", force=False)
            # Set the states back to what it was at the start of this function.
            torch.set_rng_state(bwd_cpu_rng_state)
            _set_cuda_rng_state(bwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

            deepspeed_saved_tensors = None
            non_tensor_args = None
            tensor_flags = None

        if holder_from_backward not in storage:
            raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                               " recomputation being triggered in between, this is not currently supported.")

        return storage[holder_from_backward]

    def after_backward_hook(_nonuse_grads):
        """the hook registered to all leaf tensors"""
        nonlocal leaf_tensors, backward_visited_leaf_nodes
        backward_visited_leaf_nodes += 1

        if backward_visited_leaf_nodes == len(leaf_tensors):
            see_memory_usage("After backward checkpointing code after backward", force=False)

            if PROFILE_TIME:
                timers('backward').stop()
                timers.log(['backward'])
            if SYNCHRONIZE:
                get_accelerator().synchronize()

    with torch.autograd.graph.saved_tensors_hooks(checkpoint_pack, checkpoint_unpack):
        outputs = function(*inputs_cuda)
        for leaf_tensor in leaf_tensors:
            leaf_tensor.register_hook(after_backward_hook)

    see_memory_usage("After running forward on the layer", force=False)

    if PROFILE_TIME:
        timers(FORWARD_GLOBAL_TIMER).stop()
        timers.log([FORWARD_GLOBAL_TIMER])
    if SYNCHRONIZE:
        get_accelerator().synchronize()

    all_outputs = []
    if torch.is_tensor(outputs):
        all_outputs += [outputs]
    else:
        all_outputs += outputs

    if len(all_outputs) == 1:
        return all_outputs[0]
    else:
        return tuple(all_outputs)
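

# Illustrative sketch: non_reentrant_checkpoint() takes the same (function, *args)
# signature as checkpoint() below, and targets the case where none of the inputs
# require grad. `layer_forward` and `hidden_states` are hypothetical names:
#
#     outputs = non_reentrant_checkpoint(layer_forward, hidden_states)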


def checkpoint(function, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint. """

    all_outputs = []
    CheckpointFunction.apply(function, all_outputs, *args)
    if len(all_outputs) == 1:
        return all_outputs[0]
    else:
        return tuple(all_outputs)
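

# Illustrative sketch: checkpointing one layer's forward pass inside a
# DeepSpeed-initialized training run. `layer`, `hidden_states`, and
# `attention_mask` are hypothetical names; any callable and matching positional
# tensors work, and gradients flow when .backward() is later called on a loss
# derived from `outputs`.
#
#     def custom_forward(*inputs):
#         return layer(*inputs)
#
#     outputs = checkpoint(custom_forward, hidden_states, attention_mask)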


def partition_activations_in_checkpoint(partition_activation):
    global PARTITION_ACTIVATIONS
    PARTITION_ACTIVATIONS = partition_activation
    if dist.get_rank() == 0:
        logger.info(f"**************Partition Activations {PARTITION_ACTIVATIONS}************")


def set_num_layers(nlayers):
    global num_layers
    num_layers = nlayers


def reset():
    """Resets memory buffers related to contiguous memory optimizations.
    Should be called during eval when multiple forward propagations are
    computed without any backward propagation that usually clears these
    buffers.
    Arguments:
        None

    Return:
        None
    """
    if CONTIGUOUS_CHECKPOINTING:
        global data_offsets, size_offsets
        global contiguous_data_buffers, contiguous_size_buffers

        for buffers in contiguous_data_buffers:
            buffers = []

        # frees up all the pointers to the checkpoints except for the ones
        # stored by save for backward
        contiguous_data_buffers = []
        contiguous_size_buffers = []
        data_offsets = []
        size_offsets = []


def _configure_using_config_file(config, mpu=None):
    global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
            CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME

    config = DeepSpeedConfig(config, mpu=mpu).activation_checkpointing_config
    if dist.get_rank() == 0:
        logger.info(config.repr())
    PARTITION_ACTIVATIONS = config.partition_activations
    CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
    num_layers = config.number_checkpoints
    CPU_CHECKPOINT = config.cpu_checkpointing
    SYNCHRONIZE = config.synchronize_checkpoint_boundary
    PROFILE_TIME = config.profile


def _configure_defaults():
    global mpu, num_layers, deepspeed_checkpointing_enabled

    global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
            CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME

    PARTITION_ACTIVATIONS = False
    CONTIGUOUS_CHECKPOINTING = False
    num_layers = False
    CPU_CHECKPOINT = False
    SYNCHRONIZE = False
    PROFILE_TIME = False
    deepspeed_checkpointing_enabled = True


def configure(
    mpu_,
    deepspeed_config=None,
    partition_activations=None,
    contiguous_checkpointing=None,
    num_checkpoints=None,
    checkpoint_in_cpu=None,
    synchronize=None,
    profile=None,
):
    """Configure DeepSpeed Activation Checkpointing.

    Arguments:
        mpu_: Optional: An object that implements the following methods
            get_model_parallel_rank/group/world_size, and get_data_parallel_rank/group/world_size

        deepspeed_config: Optional: DeepSpeed Config json file, which when provided will be used to
            configure DeepSpeed Activation Checkpointing

        partition_activations: Optional: Partitions activation checkpoints across model parallel
            GPUs when enabled. By default False. Will overwrite deepspeed_config if provided

        contiguous_checkpointing: Optional: Copies activation checkpoints to a contiguous memory
            buffer. Works only with homogeneous checkpoints when partition_activations is enabled.
            Must provide num_checkpoints. By default False. Will overwrite deepspeed_config if
            provided

        num_checkpoints: Optional: Number of activation checkpoints stored during the forward
            propagation of the model. Used to calculate the buffer size for contiguous_checkpointing.
            Will overwrite deepspeed_config if provided

        checkpoint_in_cpu: Optional: Moves the activation checkpoint to CPU. Only works with
            partition_activations. Default is False. Will overwrite deepspeed_config if provided

        synchronize: Optional: Performs get_accelerator().synchronize() at the beginning and end of
            each call to deepspeed.checkpointing.checkpoint for both forward and backward pass.
            By default False. Will overwrite deepspeed_config if provided

        profile: Optional: Logs the forward and backward time for each
            deepspeed.checkpointing.checkpoint invocation. Will overwrite deepspeed_config
            if provided

    Returns:
        None
    """
    global mpu, num_layers, deepspeed_checkpointing_enabled

    global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
            CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME

    _configure_defaults()

    if mpu_ is not None:
        mpu = mpu_

    if deepspeed_config is not None:
        _configure_using_config_file(deepspeed_config, mpu=mpu)

    if partition_activations is not None:
        PARTITION_ACTIVATIONS = partition_activations

    if contiguous_checkpointing is not None:
        CONTIGUOUS_CHECKPOINTING = contiguous_checkpointing

    if num_checkpoints is not None:
        num_layers = num_checkpoints

    if checkpoint_in_cpu is not None:
        CPU_CHECKPOINT = checkpoint_in_cpu

    if synchronize is not None:
        SYNCHRONIZE = synchronize

    if profile is not None:
        PROFILE_TIME = profile

    if CONTIGUOUS_CHECKPOINTING:
        assert PARTITION_ACTIVATIONS, "Contiguous Checkpointing is only available with partitioned activations. Set partitioned activations to true in deepspeed config"
    if CONTIGUOUS_CHECKPOINTING:
        assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing"
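

# Illustrative sketch: programmatic configuration without a DeepSpeed config
# file. The flag values are arbitrary assumptions; note that
# contiguous_checkpointing requires partition_activations and num_checkpoints,
# as asserted above.
#
#     import deepspeed
#     deepspeed.checkpointing.configure(mpu_=None,
#                                       partition_activations=True,
#                                       contiguous_checkpointing=True,
#                                       num_checkpoints=24,
#                                       checkpoint_in_cpu=True,
#                                       synchronize=False,
#                                       profile=True)
#     assert deepspeed.checkpointing.is_configured()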


def is_configured():
    """True if deepspeed activation checkpointing has been configured
        by calling deepspeed.checkpointing.configure, else returns False

    Arguments:
        None

    Return:
        True if configured, else False
    """
    return deepspeed_checkpointing_enabled