
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Used to partition the activations stored for backward propagation, thereby reducing
memory consumption. Also implements CPU checkpointing and contiguous memory
checkpointing, which reduce memory consumption and memory fragmentation.

Code for rng checkpointing taken from NVIDIA Megatron-LM mpu/random.py
b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
"""

# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

import copy
import torch
import contextlib
from deepspeed import comm as dist
import weakref
import mmap
from torch import _C

from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime import compiler

# DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False

# MP parameters
mpu = None
mp_rank = None
mp_size = None
mp_group = None

# Model Parameters
num_layers = None

# Checkpointing buffers
contiguous_data_buffers = []
data_offsets = []
contiguous_size_buffers = []
size_offsets = []

timers = None

# optimization flags
PARTITION_ACTIVATIONS = False
CPU_CHECKPOINT = False
CONTIGUOUS_CHECKPOINTING = False
SYNCHRONIZE = False
PROFILE_TIME = False

# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'

transport_stream = None
cuda_device = None


def detach_variable(inputs, device=None):
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            requires_grad = inp.requires_grad

            if device is not None:
                x = inp.to(device=device)
            else:
                x = inp

            x = x.detach()
            x.requires_grad = requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError("Only tuple of tensors is supported. Got unsupported input type: ", type(inputs).__name__)


def _set_cuda_rng_state(new_state, device=-1):
    """Sets the random number generator state of the current GPU.

    Arguments:
        new_state (torch.ByteTensor): The desired state

    This function is adapted from PyTorch repo (torch.cuda.set_rng_state) #ignore-cuda
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for +4 GPU cases.
    """
    if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
        # older PyTorch
        def cb():
            with get_accelerator().device(device):
                _C._cuda_setRNGState(new_state)
    else:
        # newer PyTorch
        if device == -1:
            device = torch.device(get_accelerator().device_name())
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(get_accelerator().device_name(), device)

        def cb():
            idx = device.index
            if idx is None:
                idx = get_accelerator().current_device()
            default_generator = get_accelerator().default_generator(idx)
            default_generator.set_state(new_state)

    get_accelerator().lazy_call(cb)


class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.

    Using the `add` method, a cuda rng state is initialized based on
    the input `seed` and is assigned to `name`. Later, by forking the
    rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for book keeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """Set to the initial state (no tracker)."""
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self):
        """Get rng states. Copy the dictionary so we have direct
        pointers to the states, not just a pointer to the dictionary."""
        return copy.copy(self.states_)

    def set_states(self, states):
        """Set the rng states. For efficiency purposes, we do not check
        the size of seed for compatibility."""
        self.states_ = states

    def add(self, name, seed):
        """Track the rng state."""
        # Check seed is not already used.
        if seed in self.seeds_:
            raise Exception('seed {} already exists'.format(seed))
        self.seeds_.add(seed)
        # Check that state is not already defined.
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))
        # Get the current rng state.
        orig_rng_state = get_accelerator().get_rng_state()
        # Set the new state and store it.
        get_accelerator().manual_seed(seed)
        self.states_[name] = get_accelerator().get_rng_state()
        # Reset rng state to what it was.
        _set_cuda_rng_state(orig_rng_state)

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
        """Fork the cuda rng state, perform operations, and exit with
        the original state."""
        # Check if we have added the state
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Store current rng state.
        orig_cuda_rng_state = get_accelerator().get_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Update the current rng state for later use.
            self.states_[name] = get_accelerator().get_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)


# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()


def get_cuda_rng_tracker():
    """Get cuda rng tracker."""
    return _CUDA_RNG_STATE_TRACKER
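

# Usage sketch: fork the model-parallel RNG stream around a stochastic op such as
# dropout so it draws from the tracked state instead of the default state. This is a
# hypothetical helper for illustration only; it assumes model_parallel_cuda_manual_seed()
# (defined below) has already been called so the tracked state exists.
def _example_fork_model_parallel_rng(x, p=0.1):
    with get_cuda_rng_tracker().fork():
        return torch.nn.functional.dropout(x, p=p)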


def model_parallel_cuda_manual_seed(seed):
    """Initialize model parallel cuda seed.

    This function should be called after the model parallel is
    initialized. Also, no get_accelerator().manual_seed should be called
    after this function. Basically, this is a replacement for that
    function.

    Two sets of RNG states are tracked:
        default state: This is for data parallelism and is the same among a
                       set of model parallel GPUs but different across
                       different model parallel groups. This is used for
                       example for dropout in the non-model-parallel regions.
        model-parallel state: This state is different among a set of model
                              parallel GPUs, but the same across data parallel
                              groups. This is used for example for dropout in
                              model parallel regions.
    """
    global mpu

    tp_rank = bwc_tensor_model_parallel_rank(mpu)

    # 2718 is just for fun and any POSITIVE value will work.
    offset = seed + 2718
    model_parallel_seed = offset + tp_rank
    # Data parallel gets the original seed.
    data_parallel_seed = seed

    if dist.get_rank() == 0:
        logger.info(
            '> initializing model parallel cuda seeds on global rank {}, '
            'model parallel rank {}, and data parallel rank {} with '
            'model parallel seed: {} and data parallel seed: {}'.format(dist.get_rank(), tp_rank,
                                                                        mpu.get_data_parallel_rank(),
                                                                        model_parallel_seed, data_parallel_seed), )
    _CUDA_RNG_STATE_TRACKER.reset()
    # Set the default state.
    get_accelerator().manual_seed(data_parallel_seed)
    # and model parallel state.
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)


def model_parallel_reconfigure_tp_seed(seed):
    global mpu

    tp_rank = bwc_tensor_model_parallel_rank(mpu)
    model_parallel_seed = seed + 2718 + tp_rank
    with _CUDA_RNG_STATE_TRACKER.fork():
        get_accelerator().manual_seed(model_parallel_seed)
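

# Usage sketch: seeding is done once after the model-parallel groups are created; every
# tensor-parallel rank derives a distinct model-parallel seed (seed + 2718 + tp_rank)
# while sharing the data-parallel seed. Hypothetical call sequence for illustration only.
#
#     model_parallel_cuda_manual_seed(1234)
#     # tp_rank 0 -> model-parallel seed 3952, tp_rank 1 -> 3953, ...;
#     # all ranks keep 1234 as the default (data-parallel) seed.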


def get_partition_start(item):
    global mp_rank, mp_size, mp_group
    size = item.numel()
    partition_size = size / mp_size
    start = partition_size * mp_rank
    return int(start)


def get_partition_size(item):
    global mp_rank, mp_size, mp_group
    size = item.numel()
    assert size % mp_size == 0, "Cannot partition activation when item is not evenly divisible by mp size"
    partition_size = size / mp_size
    return int(partition_size)
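

# Worked example: with mp_size = 4, a flattened activation of 1024 elements is split into
# partitions of 1024 / 4 = 256 elements, and rank r keeps elements [256 * r, 256 * (r + 1)).
# get_partition_size asserts numel % mp_size == 0, so ragged splits are rejected up front.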


def gather_partitioned_activations(tensors, device=None):
    global mp_rank, mp_size, mp_group
    assert len(tensors) % 2 == 0, f'Expected even count of tensors, instead got {len(tensors)}'
    inputs = []
    num_args = int(len(tensors) / 2)
    for i in range(num_args):

        item = tensors[2 * i]
        size = tensors[2 * i + 1]

        if not is_activation_to_checkpoint(item):
            inputs.append(item)
            continue

        # don't need to do all_gather if model parallel is not enabled
        if mp_group is None or mp_size == 1:
            item = item.view(list(size.numpy()))
            if device is not None:
                item = item.to(device)
            inputs.append(item)
            continue

        partition_size = item.numel()
        tensor_size = partition_size * mp_size
        if device is not None:
            flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device)
        else:
            flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=item.device)
        part = flat_tensor.narrow(0, partition_size * mp_rank, partition_size)
        part.copy_(item)
        dist.all_gather_into_tensor(flat_tensor, part, group=mp_group)
        input_tensor = flat_tensor.view(list(size.numpy()))
        item.data = input_tensor.data

        inputs.append(item)

    return tuple(inputs)


def extract_tensors(all_objects):
    """
    Separate objects in list/tuple into tensors and non-tensors and create a mapping to enable re-aggregation.
    The order of tensors and non-tensors is preserved in their respective output groups.

    Parameters:
        all_objects (list/tuple): Objects containing tensors and non-tensors to be split.

    Returns:
        tuple: Containing tensors, non-tensors, and bools of whether each position in original list/tuple was a tensor.
    """
    tensor_objects = [v for v in all_objects if torch.is_tensor(v)]
    non_tensor_objects = [v for v in all_objects if not torch.is_tensor(v)]
    tensor_flags = [torch.is_tensor(v) for v in all_objects]
    if type(all_objects) is tuple:
        return tuple(tensor_objects), tuple(non_tensor_objects), tuple(tensor_flags)
    return tensor_objects, non_tensor_objects, tensor_flags


def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags):
    """
    Merge two lists (or tuples) of tensors and non-tensors using a mapping of positions in merged list (or tuple).

    Parameters:
        tensor_objects (list/tuple): Tensors to merge.
        non_tensor_objects (list/tuple): Non-tensors to merge.
        tensor_flags (list/tuple): Indicates whether each position in output is a tensor.

    Returns:
        tuple: Merge of tensors and non-tensors
    """
    merged_objects = []
    tensor_idx = 0
    non_tensor_idx = 0

    real_tensor_flags = None

    # remove the flags that are assigned to the size of the flattened tensors
    if PARTITION_ACTIVATIONS:
        real_tensor_flags = []
        previous_flag = False
        for flag in tensor_flags:
            if previous_flag:
                previous_flag = False
                continue
            previous_flag = flag
            real_tensor_flags.append(flag)
    else:
        real_tensor_flags = tensor_flags

    for is_tensor in real_tensor_flags:
        if is_tensor:
            merged_objects.append(tensor_objects[tensor_idx])
            tensor_idx += 1
        else:
            merged_objects.append(non_tensor_objects[non_tensor_idx])
            non_tensor_idx += 1

    return tuple(merged_objects)
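

# Usage sketch: extract_tensors/merge_tensors round-trip mixed argument lists so that only
# the tensor part is stashed for backward. Hypothetical values for illustration only
# (round-trip shown with PARTITION_ACTIVATIONS disabled, so no interleaved size tensors).
#
#     args = (torch.ones(2), "mask_mode", 3)
#     tensors, non_tensors, flags = extract_tensors(all_objects=args)
#     # tensors == (tensor,), non_tensors == ("mask_mode", 3), flags == (True, False, False)
#     restored = merge_tensors(tensors, non_tensors, flags)  # original order restored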


def is_activation_to_checkpoint(item):
    """
    Is an activation to be checkpointed
    """
    global mp_size
    return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size


def partition_activations(args, cpu_checkpoint, contiguous_checkpoint):
    global contiguous_data_buffers, data_offsets

    inputs = []
    num_non_fp_tensors = 0

    for arg_index, item in enumerate(args):
        if not is_activation_to_checkpoint(item):
            inputs.append(item)
            num_non_fp_tensors += 1
            continue

        i = arg_index - num_non_fp_tensors
        partition_size = get_partition_size(item)
        partition = item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), partition_size).clone()

        buffer_device = torch.device('cpu') if cpu_checkpoint else partition.device

        if contiguous_checkpoint:
            if i >= len(contiguous_data_buffers):
                tensor_list = [
                    torch.tensor(()).new_empty([partition_size], dtype=partition.dtype, device=buffer_device)
                    for _ in range(num_layers)
                ]
                contiguous_data_buffers.append(tensor_list)
                data_offsets.append(0)
            elif contiguous_data_buffers[i] is None:
                tensor_list = [
                    torch.tensor(()).new_empty([partition_size], dtype=partition.dtype, device=buffer_device)
                    for _ in range(num_layers)
                ]
                contiguous_data_buffers[i] = tensor_list
                data_offsets[i] = 0

            # Because the 'new_empty' returns uninitialized pages,
            # the pages need to be populated during the cudaMemcpy time
            # which increases the data copy time. To avoid this, we
            # pre-populate these pages by simply writing 0 ahead of
            # the actual cudaMemcpy operation time. Due to the
            # previously launched GPU kernels, there is a small
            # window of time here for CPUs to populate pages asynchronously.
            contiguous_data_buffers[i][data_offsets[i]].data[range(
                0, contiguous_data_buffers[i][data_offsets[i]].data.shape[0],
                int(mmap.PAGESIZE / contiguous_data_buffers[i][data_offsets[i]].data.element_size()))] = 0

            contiguous_partition = contiguous_data_buffers[i][data_offsets[i]].data.copy_(partition.data)
            data_offsets[i] = data_offsets[i] + 1
            inputs.append(contiguous_partition)
        else:
            partition = partition.cpu() if CPU_CHECKPOINT else partition
            inputs.append(partition)

    return inputs


def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint):
    global contiguous_size_buffers, size_offsets

    new_args = []
    num_non_fp_tensors = 0

    for arg_index, (arg, inp) in enumerate(zip(args, inputs)):
        size = torch.tensor(arg.size()) if torch.is_tensor(arg) else None
        if not is_activation_to_checkpoint(arg):
            new_args.append(arg)
            new_args.append(size)
            num_non_fp_tensors += 1
            continue

        arg.data = torch.empty([], device=arg.device).data
        arg.saved_data = inp.data
        new_args.append(arg)
        i = arg_index - num_non_fp_tensors

        if contiguous_checkpoint:
            numel = size.numel()
            if i >= len(contiguous_size_buffers):
                tmp = torch.tensor(())
                contiguous_size_buffers.append(
                    tmp.new_empty([numel * num_layers], dtype=size.dtype, device=size.device))
                size_offsets.append(0)
            elif contiguous_size_buffers[i] is None:
                tmp = torch.tensor(())
                contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers], dtype=size.dtype, device=size.device)
                size_offsets[i] = 0

            contiguous_size = contiguous_size_buffers[i].narrow(0, size_offsets[i], numel).data.copy_(size.data)
            contiguous_size = contiguous_size.view_as(size)
            size_offsets[i] = size_offsets[i] + numel
            new_args.append(contiguous_size)
        else:
            new_args.append(size)

    return new_args


def get_cpu_activations_for_backward(args, inputs):
    new_args = []
    for i, (arg, inp) in enumerate(zip(args, inputs)):
        if not is_activation_to_checkpoint(arg):
            new_args.append(arg)
            continue

        arg.data = torch.empty([], device=arg.device).data
        arg.saved_data = inp.data
        new_args.append(arg)

    return new_args


class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
       the following main changes:
           1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`  #ignore-cuda
           2) the states in the model parallel tracker are also properly
              tracked/set/reset.
           3) Performance activation partitioning, contiguous memory optimization
           4) CPU Checkpointing
           5) Profile forward and backward functions
    """

    @staticmethod
    def forward(ctx, run_function, all_outputs, *args):
        global mpu, timers, SYNCHRONIZE, PROFILE_TIME

        def save_args_for_backward(*all_args):
            tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
            ctx.deepspeed_saved_tensors = tensor_args
            ctx.non_tensor_args = non_tensor_args
            ctx.tensor_flags = tensor_flags

        if SYNCHRONIZE:
            get_accelerator().synchronize()

        if timers is None and PROFILE_TIME:
            timers = Timers()

        if PROFILE_TIME:
            timers(FORWARD_GLOBAL_TIMER).start()

        ctx.run_function = run_function
        global num_layers
        global mp_rank, mp_size, mp_group
        global contiguous_data_buffers, contiguous_size_buffers
        global data_offsets, size_offsets
        if mp_rank is None:
            if mpu is not None:
                if hasattr(mpu, 'get_tensor_model_parallel_rank'):
                    mp_rank = mpu.get_tensor_model_parallel_rank()
                    mp_size = mpu.get_tensor_model_parallel_world_size()
                    mp_group = mpu.get_tensor_model_parallel_group()
                else:
                    mp_rank = mpu.get_model_parallel_rank()
                    mp_size = mpu.get_model_parallel_world_size()
                    mp_group = mpu.get_model_parallel_group()
            else:
                mp_rank = 0
                mp_size = 1
                mp_group = None

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

        if cuda_device is None:
            see_memory_usage("First Forward Beginning", force=False)
            if dist.get_rank() == 0:
                logger.info(f"Activation Checkpointing Information")
                logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}")
                logger.info(
                    f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
                logger.info(f"----Synchronization {SYNCHRONIZE}")
                logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")

            cuda_device = get_accelerator().current_device_name()
            transport_stream = get_accelerator().Stream(device=cuda_device)

        if PARTITION_ACTIVATIONS:
            inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
        elif CPU_CHECKPOINT:
            inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint)

        # just in case something funky is happening such as reuse of inputs
        inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint)

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = get_accelerator().get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        see_memory_usage("Before running forward on the layer", force=False)
        # ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*inputs_cuda)

        see_memory_usage("After running forward on the layer", force=False)
        del inputs_cuda

        if PARTITION_ACTIVATIONS:
            new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
            assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
            save_args_for_backward(*new_args)
        elif CPU_CHECKPOINT:
            new_args = get_cpu_activations_for_backward(args, inputs)
            save_args_for_backward(*new_args)
        else:
            save_args_for_backward(*args)

        if PROFILE_TIME:
            timers(FORWARD_GLOBAL_TIMER).stop()
            timers.log([FORWARD_GLOBAL_TIMER])
        if SYNCHRONIZE:
            get_accelerator().synchronize()

        # Tensors returned from forward() may not be differentiable.
        if torch.is_tensor(outputs):
            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
        else:
            non_grad_outputs = [o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()]
        ctx.mark_non_differentiable(*non_grad_outputs)

        if torch.is_tensor(outputs):
            all_outputs += [outputs]
            return outputs
        else:
            all_outputs += outputs
            outputs, _, _ = extract_tensors(all_objects=outputs)
            return tuple(outputs)

    @staticmethod
    def backward(ctx, *grads):
        global timers
        see_memory_usage("In backward", force=False)
        # removing pointers to the contiguous buffer memory
        # so that they can be garbage collected once the checkpoints
        # have been used
        if SYNCHRONIZE:
            get_accelerator().synchronize()
        if PROFILE_TIME:
            timers('backward').start()

        if CONTIGUOUS_CHECKPOINTING:
            global data_offsets, size_offsets
            global contiguous_data_buffers, contiguous_size_buffers

            for buffers in contiguous_data_buffers:
                buffers = []

            # frees up all the pointers to the checkpoints except for the ones
            # stored by save for backward
            contiguous_data_buffers = []
            contiguous_size_buffers = []
            data_offsets = []
            size_offsets = []

        see_memory_usage("In backward checkpointing code", force=False)
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS

        # Rebuild deepspeed_saved_tensors
        for t in ctx.deepspeed_saved_tensors:
            if t is not None and hasattr(t, 'saved_data') and t.saved_data is not None:
                t.data = t.saved_data.to(t.device)
                t.saved_data = None

        if PARTITION_ACTIVATIONS:
            # with get_accelerator().stream(transport_stream):
            inputs = gather_partitioned_activations(ctx.deepspeed_saved_tensors,
                                                    device=cuda_device if CPU_CHECKPOINT else None)
            detached_inputs = detach_variable(inputs)
        elif CPU_CHECKPOINT:
            inputs = move_to_device(ctx.deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint)
            detached_inputs = detach_variable(inputs)
        else:
            inputs = ctx.deepspeed_saved_tensors
            detached_inputs = detach_variable(inputs)

        # Add non tensor input args
        detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                        non_tensor_objects=ctx.non_tensor_args,
                                        tensor_flags=ctx.tensor_flags)

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = get_accelerator().get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what it used to be before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # if PARTITION_ACTIVATIONS:
        #     current_stream=get_accelerator().current_stream()
        #     current_stream.wait_stream(transport_stream)

        see_memory_usage("In backward checkpointing code before forward", force=False)
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

        see_memory_usage("In backward checkpointing code after forward", force=False)
        # Set the states back to what it was at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )

        # Filter out non tensor outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)

        # Construct arguments to autograd.backward().
        # This is usually just outputs and grads, but forward() can return tensors that
        # are not differentiable.
        output_tensors = []
        grad_tensors = []
        for out, grad in zip(outputs, grads):
            if out.requires_grad:
                output_tensors.append(out)
                grad_tensors.append(grad)

        see_memory_usage("In backward checkpointing code before backward", force=False)

        torch.autograd.backward(output_tensors, grad_tensors)

        # Force clear our stashed tensors to prevent a memory leak in certain scenarios
        ctx.deepspeed_saved_tensors = None
        ctx.non_tensor_args = None
        ctx.tensor_flags = None

        see_memory_usage("After backward checkpointing code after backward", force=False)

        if PROFILE_TIME:
            timers('backward').stop()
            timers.log(['backward'])
        if SYNCHRONIZE:
            get_accelerator().synchronize()
        ret_list = [None, None]  # first None for ctx
        for inp in detached_inputs:
            if torch.is_tensor(inp):
                ret_list.append(inp.grad)
            else:
                ret_list.append(None)

        return tuple(ret_list)


def non_reentrant_checkpoint(function, *args):
    """This function is a union of `torch.utils.checkpoint._checkpoint_without_reentrant` and `CheckpointFunction` in this module.

    It aims to solve the back propagation error raised when all inputs require no grad.
    * the non-reentrant approach has been implemented in pytorch for a while and is stable most of the time, except for jit module mode.
    * it helps to avoid the workaround in `deepspeed.runtime.pipe.module.PipelineModule._is_checkpointable`

    Main modifications compared to the torch implementation:
    1. adapt to the signature of the `checkpoint` function in this module
    2. remove non-determinism via random state management consistent with the deepspeed `CheckpointFunction`
    3. when partitioned or cpu checkpointing is enabled, gather the activations in the unpack hook during back propagation
    4. run all post-backward bookkeeping in a hook that executes after all leaf nodes finish their backward execution
    5. point 4 is inspired by `torch.autograd.graph.register_multi_grad_hook`, which is only available after torch 2.0.0
    """
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    deepspeed_saved_tensors = None
    non_tensor_args = None
    tensor_flags = None

    def save_args_for_backward(*all_args):
        """keep this function to reduce the modification from original implementation"""
        nonlocal deepspeed_saved_tensors, non_tensor_args, tensor_flags
        tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
        deepspeed_saved_tensors = tensor_args

    if SYNCHRONIZE:
        get_accelerator().synchronize()

    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers(FORWARD_GLOBAL_TIMER).start()

    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets
    if mp_rank is None:
        if mpu is not None:
            if hasattr(mpu, 'get_tensor_model_parallel_rank'):
                mp_rank = mpu.get_tensor_model_parallel_rank()
                mp_size = mpu.get_tensor_model_parallel_world_size()
                mp_group = mpu.get_tensor_model_parallel_group()
            else:
                mp_rank = mpu.get_model_parallel_rank()
                mp_size = mpu.get_model_parallel_world_size()
                mp_group = mpu.get_model_parallel_group()
        else:
            mp_rank = 0
            mp_size = 1
            mp_group = None

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    if cuda_device is None:
        see_memory_usage("First Forward Beginning", force=False)
        if dist.get_rank() == 0:
            logger.info(f"Activation Checkpointing Information")
            logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}")
            logger.info(
                f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
            logger.info(f"----Synchronization {SYNCHRONIZE}")
            logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")

        cuda_device = get_accelerator().current_device_name()
        transport_stream = get_accelerator().Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
    elif CPU_CHECKPOINT:
        inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint)

    # just in case something funky is happening such as reuse of inputs
    inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint)

    # Copy the rng states.
    fwd_cpu_rng_state = torch.get_rng_state()
    fwd_cuda_rng_state = get_accelerator().get_rng_state()
    fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    if PARTITION_ACTIVATIONS:
        new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
        assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
        save_args_for_backward(*new_args)
    elif CPU_CHECKPOINT:
        new_args = get_cpu_activations_for_backward(args, inputs)
        save_args_for_backward(*new_args)
    else:
        save_args_for_backward(*args)

    class Holder():
        """the place holder object used as activations to save memory"""
        pass

    # weakref seems utilized to discover the tensor deletion before a whole
    # forward backward pair loop finished
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []
    leaf_tensors = []
    backward_visited_leaf_nodes = 0

    def checkpoint_pack(tensor_from_forward):
        """used to record the activation order in the `weak_holder_list`

        the activation order in the holder list is consistent between the first forward and the recomputing forward.
        * the jit compiled forward will break the order consistency *
        """
        res = Holder()
        weak_holder_list.append(weakref.ref(res))

        # if this is a leaf tensor, save it for backward progression tracing;
        # leaf tensors are usually inputs or parameters, which are not activations and
        # have no memory overhead
        if tensor_from_forward.requires_grad and tensor_from_forward.is_leaf:
            leaf_tensors.append(tensor_from_forward)
        return res

    def checkpoint_unpack(holder_from_backward):
        """retrieve the activations from recompute"""
        nonlocal deepspeed_saved_tensors, non_tensor_args, tensor_flags

        # if this is the first step of back propagation, recompute the graph and save
        # all the activations in the same order as `checkpoint_pack` does
        if len(storage) == 0:
            unpack_counter = 0

            def replay_pack(tensor_from_replay):
                """save recomputed activations"""
                nonlocal unpack_counter
                unpack_counter += 1

                if weak_holder_list[unpack_counter - 1]() is None:
                    return

                detached_activations = tensor_from_replay.detach()
                storage[weak_holder_list[unpack_counter - 1]()] = detached_activations

                return

            def replay_unpack(none_value):
                """the recomputation graph should never be backpropagated through"""
                raise RuntimeError("You are calling backwards on a tensor that is never exposed.")

            global timers
            see_memory_usage("In backward", force=False)
            # removing pointers to the contiguous buffer memory
            # so that they can be garbage collected once the checkpoints
            # have been used
            if SYNCHRONIZE:
                get_accelerator().synchronize()
            if PROFILE_TIME:
                timers('backward').start()

            if CONTIGUOUS_CHECKPOINTING:
                global data_offsets, size_offsets
                global contiguous_data_buffers, contiguous_size_buffers

                for buffers in contiguous_data_buffers:
                    buffers = []

                # frees up all the pointers to the checkpoints except for the ones
                # stored by save for backward
                contiguous_data_buffers = []
                contiguous_size_buffers = []
                data_offsets = []
                size_offsets = []

            see_memory_usage("In backward checkpointing code", force=False)
            if not torch.autograd._is_checkpoint_valid():
                raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                   "please use .backward() if possible")

            global cuda_device, transport_stream, PARTITION_ACTIVATIONS

            # gather inputs that were partitioned or moved to cpu before the first forward
            if PARTITION_ACTIVATIONS:
                # with get_accelerator().stream(transport_stream):
                inputs = gather_partitioned_activations(deepspeed_saved_tensors,
                                                        device=cuda_device if CPU_CHECKPOINT else None)
                detached_inputs = detach_variable(inputs)
            elif CPU_CHECKPOINT:
                inputs = move_to_device(deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint)
                detached_inputs = detach_variable(inputs)
            else:
                inputs = deepspeed_saved_tensors
                detached_inputs = detach_variable(inputs)

            # Add non tensor input args
            detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                            non_tensor_objects=non_tensor_args,
                                            tensor_flags=tensor_flags)

            # Store the current states.
            bwd_cpu_rng_state = torch.get_rng_state()
            bwd_cuda_rng_state = get_accelerator().get_rng_state()
            bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

            # Set the states to what it used to be before the forward pass.
            torch.set_rng_state(fwd_cpu_rng_state)
            _set_cuda_rng_state(fwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(fwd_cuda_rng_state_tracker)

            see_memory_usage("In backward checkpointing code before forward", force=False)
            with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(replay_pack, replay_unpack):
                _unused = function(*detached_inputs)

            see_memory_usage("In backward checkpointing code after forward", force=False)
            # Set the states back to what it was at the start of this function.
            torch.set_rng_state(bwd_cpu_rng_state)
            _set_cuda_rng_state(bwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

            deepspeed_saved_tensors = None
            non_tensor_args = None
            tensor_flags = None

        if holder_from_backward not in storage:
            raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                               " recomputation being triggered in between, this is not currently supported.")

        return storage[holder_from_backward]

    def after_backward_hook(_nonuse_grads):
        """the hook registered to all leaf tensors"""
        nonlocal leaf_tensors, backward_visited_leaf_nodes
        backward_visited_leaf_nodes += 1

        if backward_visited_leaf_nodes == len(leaf_tensors):
            see_memory_usage("After backward checkpointing code after backward", force=False)

            if PROFILE_TIME:
                timers('backward').stop()
                timers.log(['backward'])
            if SYNCHRONIZE:
                get_accelerator().synchronize()

    with torch.autograd.graph.saved_tensors_hooks(checkpoint_pack, checkpoint_unpack):
        outputs = function(*inputs_cuda)
        for leaf_tensor in leaf_tensors:
            leaf_tensor.register_hook(after_backward_hook)

    see_memory_usage("After running forward on the layer", force=False)

    if PROFILE_TIME:
        timers(FORWARD_GLOBAL_TIMER).stop()
        timers.log([FORWARD_GLOBAL_TIMER])
    if SYNCHRONIZE:
        get_accelerator().synchronize()

    all_outputs = []
    if torch.is_tensor(outputs):
        all_outputs += [outputs]
    else:
        all_outputs += outputs

    if len(all_outputs) == 1:
        return all_outputs[0]
    else:
        return tuple(all_outputs)
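

# Usage sketch: non_reentrant_checkpoint is a drop-in alternative to checkpoint() below
# for blocks whose inputs may all have requires_grad=False. Hypothetical call for
# illustration only; `layer_block`, `hidden_states`, and `attention_mask` are not names
# defined in this module.
#
#     hidden_states = non_reentrant_checkpoint(layer_block, hidden_states, attention_mask)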


@compiler.disable  # WA from Pytorch repo for compile + zero 3 accuracy issue
def checkpoint(function, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint."""

    all_outputs = []
    CheckpointFunction.apply(function, all_outputs, *args)
    if len(all_outputs) == 1:
        return all_outputs[0]
    else:
        return tuple(all_outputs)
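

# Usage sketch: the typical pattern checkpoints each transformer layer (or a chunk of
# layers) inside the model's forward pass, after configure() has been called. `layers`,
# `hidden`, and `mask` are hypothetical names for illustration only.
#
#     for layer in layers:
#         hidden = checkpoint(layer, hidden, mask)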


def partition_activations_in_checkpoint(partition_activation):
    global PARTITION_ACTIVATIONS
    PARTITION_ACTIVATIONS = partition_activation
    if dist.get_rank() == 0:
        logger.info(f"**************Partition Activations {PARTITION_ACTIVATIONS}************")


def set_num_layers(nlayers):
    global num_layers
    num_layers = nlayers


def reset():
    """Resets memory buffers related to contiguous memory optimizations.
    Should be called during eval when multiple forward propagations are
    computed without any backward propagation that usually clears these
    buffers.

    Arguments:
        None

    Return:
        None
    """
    if CONTIGUOUS_CHECKPOINTING:
        global data_offsets, size_offsets
        global contiguous_data_buffers, contiguous_size_buffers

        for buffers in contiguous_data_buffers:
            buffers = []

        # frees up all the pointers to the checkpoints except for the ones
        # stored by save for backward
        contiguous_data_buffers = []
        contiguous_size_buffers = []
        data_offsets = []
        size_offsets = []


def _configure_using_config_file(config, mpu=None):
    global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
            CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME

    config = DeepSpeedConfig(config, mpu=mpu).activation_checkpointing_config
    if dist.get_rank() == 0:
        logger.info(config.repr())
    PARTITION_ACTIVATIONS = config.partition_activations
    CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
    num_layers = config.number_checkpoints
    CPU_CHECKPOINT = config.cpu_checkpointing
    SYNCHRONIZE = config.synchronize_checkpoint_boundary
    PROFILE_TIME = config.profile


def _configure_defaults():
    global mpu, num_layers, deepspeed_checkpointing_enabled

    global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
            CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME

    PARTITION_ACTIVATIONS = False
    CONTIGUOUS_CHECKPOINTING = False
    num_layers = False
    CPU_CHECKPOINT = False
    SYNCHRONIZE = False
    PROFILE_TIME = False
    deepspeed_checkpointing_enabled = True


def configure(
    mpu_,
    deepspeed_config=None,
    partition_activations=None,
    contiguous_checkpointing=None,
    num_checkpoints=None,
    checkpoint_in_cpu=None,
    synchronize=None,
    profile=None,
):
    """Configure DeepSpeed Activation Checkpointing.

    Arguments:
        mpu_: Optional: An object that implements the following methods
            get_model_parallel_rank/group/world_size, and get_data_parallel_rank/group/world_size

        deepspeed_config: Optional: DeepSpeed Config json file when provided will be used to
            configure DeepSpeed Activation Checkpointing

        partition_activations: Optional: Partitions activation checkpoint across model parallel
            GPUs when enabled. By default False. Will overwrite deepspeed_config if provided

        contiguous_checkpointing: Optional: Copies activation checkpoints to a contiguous memory
            buffer. Works only with homogeneous checkpoints when partition_activations is enabled.
            Must provide num_checkpoints. By default False. Will overwrite deepspeed_config if
            provided

        num_checkpoints: Optional: Number of activation checkpoints stored during the forward
            propagation of the model. Used to calculate the buffer size for contiguous_checkpointing
            Will overwrite deepspeed_config if provided

        checkpoint_in_cpu: Optional: Moves the activation checkpoint to CPU. Only works with
            partition_activation. Default is false. Will overwrite deepspeed_config if provided

        synchronize: Optional: Performs get_accelerator().synchronize() at the beginning and end of
            each call to deepspeed.checkpointing.checkpoint for both forward and backward pass.
            By default false. Will overwrite deepspeed_config if provided

        profile: Optional: Logs the forward and backward time for each
            deepspeed.checkpointing.checkpoint invocation. Will overwrite deepspeed_config
            if provided

    Returns:
        None
    """
    global mpu, num_layers, deepspeed_checkpointing_enabled

    global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
            CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME

    _configure_defaults()

    if mpu_ is not None:
        mpu = mpu_

    if deepspeed_config is not None:
        _configure_using_config_file(deepspeed_config, mpu=mpu)

    if partition_activations is not None:
        PARTITION_ACTIVATIONS = partition_activations

    if contiguous_checkpointing is not None:
        CONTIGUOUS_CHECKPOINTING = contiguous_checkpointing

    if num_checkpoints is not None:
        num_layers = num_checkpoints

    if checkpoint_in_cpu is not None:
        CPU_CHECKPOINT = checkpoint_in_cpu

    if synchronize is not None:
        SYNCHRONIZE = synchronize

    if profile is not None:
        PROFILE_TIME = profile

    if CONTIGUOUS_CHECKPOINTING:
        assert PARTITION_ACTIVATIONS, "Contiguous Checkpointing is only available with partitioned activations. Set partitioned activations to true in deepspeed config"
    if CONTIGUOUS_CHECKPOINTING:
        assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing"
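

# Usage sketch: configure() is normally called once from the training script before the
# first checkpointed forward pass, either from a DeepSpeed config or with explicit flags.
# Hypothetical call for illustration only; `mpu` stands for the caller's model-parallel
# utilities object and `num_transformer_layers` is a placeholder.
#
#     import deepspeed
#     deepspeed.checkpointing.configure(mpu,
#                                       partition_activations=True,
#                                       contiguous_checkpointing=True,
#                                       num_checkpoints=num_transformer_layers,
#                                       checkpoint_in_cpu=True)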


def is_configured():
    """True if deepspeed activation checkpointing has been configured
        by calling deepspeed.checkpointing.configure, else returns false

    Arguments:
        None

    Return:
        True if configured, else False
    """
    return deepspeed_checkpointing_enabled