checkpointing.py

'''
Copyright (c) Microsoft Corporation
Licensed under the MIT license.

Used to partition the activations stored for backward propagation,
thereby reducing memory consumption.
Also implements CPU checkpointing and contiguous memory checkpointing,
which reduce memory consumption and memory fragmentation.

Code for rng checkpointing taken from NVIDIA Megatron-LM mpu/random.py
b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
'''
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import copy
import torch
import contextlib
import torch.distributed as dist
import mmap

from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager

from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers

# DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False

# MP parameters
mpu = None
mp_rank = None
mp_size = None
mp_group = None

# Model Parameters
num_layers = None

# Checkpointing buffers
contiguous_data_buffers = []
data_offsets = []
contiguous_size_buffers = []
size_offsets = []

timers = None

# Optimization flags
PARTITION_ACTIVATIONS = False
CPU_CHECKPOINT = False
CONTIGUOUS_CHECKPOINTING = False
SYNCHRONIZE = False
PROFILE_TIME = False

# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'

transport_stream = None
cuda_device = None

def detach_variable(inputs, device=None):
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            requires_grad = inp.requires_grad

            if device is not None:
                x = inp.to(device=device)
            else:
                x = inp

            x = x.detach()
            x.requires_grad = requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: ",
            type(inputs).__name__)

def _set_cuda_rng_state(new_state, device=-1):
    """Sets the random number generator state of the current GPU.

    Arguments:
        new_state (torch.ByteTensor): The desired state

    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for 4+ GPU cases.
    """
    if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
        # older PyTorch
        def cb():
            with device_ctx_manager(device):
                _C._cuda_setRNGState(new_state)
    else:
        # newer PyTorch
        if device == -1:
            device = torch.device('cuda')
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device('cuda', device)

        def cb():
            idx = device.index
            if idx is None:
                idx = torch.cuda.current_device()
            default_generator = torch.cuda.default_generators[idx]
            default_generator.set_state(new_state)

    _lazy_call(cb)

class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.

    Using the `add` method, a cuda rng state is initialized based on
    the input `seed` and is assigned to `name`. Later, by forking the
    rng state, we can perform operations and return to our starting
    cuda state.
    """
    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for book keeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """Set to the initial state (no tracker)."""
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self):
        """Get rng states. Copy the dictionary so we have direct
        pointers to the states, not just a pointer to the dictionary."""
        return copy.copy(self.states_)

    def set_states(self, states):
        """Set the rng states. For efficiency purposes, we do not check
        the size of seed for compatibility."""
        self.states_ = states

    def add(self, name, seed):
        """Track the rng state."""
        # Check seed is not already used.
        if seed in self.seeds_:
            raise Exception('seed {} already exists'.format(seed))
        self.seeds_.add(seed)
        # Check that state is not already defined.
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))
        # Get the current rng state.
        orig_rng_state = torch.cuda.get_rng_state()
        # Set the new state and store it.
        torch.cuda.manual_seed(seed)
        self.states_[name] = torch.cuda.get_rng_state()
        # Reset rng state to what it was.
        _set_cuda_rng_state(orig_rng_state)

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
        """Fork the cuda rng state, perform operations, and exit with
        the original state."""
        # Check if we have added the state
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Store current rng state.
        orig_cuda_rng_state = torch.cuda.get_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Update the current rng state for later use.
            self.states_[name] = torch.cuda.get_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)
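
# A minimal usage sketch (illustrative only, not part of this module): a
# tracker registers a named state from a seed, then `fork` temporarily swaps
# that state in so ops inside the block draw from it without disturbing the
# default cuda RNG stream.
#
#     tracker = CudaRNGStatesTracker()
#     tracker.add('model-parallel-rng', seed=1234)
#     with tracker.fork('model-parallel-rng'):
#         noise = torch.rand(4, device='cuda')  # uses the tracked state
#     # outside the block, the original cuda RNG state is restored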

# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()


def get_cuda_rng_tracker():
    """Get cuda rng tracker."""
    return _CUDA_RNG_STATE_TRACKER


def model_parallel_cuda_manual_seed(seed):
    """Initialize model parallel cuda seed.

    This function should be called after model parallelism is
    initialized. Also, no torch.cuda.manual_seed should be called
    after this function. Basically, this is a replacement for that
    function.

    Two sets of RNG states are tracked:
        default state: This is for data parallelism and is the same among a
                       set of model parallel GPUs but different across
                       different model parallel groups. This is used, for
                       example, for dropout in the non-model-parallel regions.
        model-parallel state: This state is different among a set of model
                              parallel GPUs, but the same across data parallel
                              groups. This is used, for example, for dropout in
                              model parallel regions.
    """
    global mpu

    tp_rank = bwc_tensor_model_parallel_rank(mpu)

    # 2718 is just for fun and any POSITIVE value will work.
    offset = seed + 2718
    model_parallel_seed = offset + tp_rank
    # Data parallel gets the original seed.
    data_parallel_seed = seed

    if torch.distributed.get_rank() == 0:
        logger.info(
            '> initializing model parallel cuda seeds on global rank {}, '
            'model parallel rank {}, and data parallel rank {} with '
            'model parallel seed: {} and data parallel seed: {}'.format(
                torch.distributed.get_rank(),
                tp_rank,
                mpu.get_data_parallel_rank(),
                model_parallel_seed,
                data_parallel_seed),
        )
    _CUDA_RNG_STATE_TRACKER.reset()
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)
    # and model parallel state.
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)
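
# Worked example (assuming a base seed of 1234 and tensor-parallel ranks 0..3):
# every rank seeds its default cuda generator with 1234, while the tracked
# 'model-parallel-rng' state is seeded with 1234 + 2718 + tp_rank, i.e.
# 3952, 3953, 3954, 3955 -- the same for corresponding ranks across
# data-parallel replicas, but distinct across tensor-parallel ranks.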

def get_partition_start(item):
    global mp_rank, mp_size, mp_group
    size = item.numel()
    partition_size = size / mp_size
    start = partition_size * mp_rank
    return int(start)


def get_partition_size(item):
    global mp_rank, mp_size, mp_group
    size = item.numel()
    assert size % mp_size == 0, "Activation partitioning requires the tensor's number of elements to be divisible by the model parallel size"
    partition_size = size / mp_size
    return int(partition_size)
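
# Worked example (illustrative): with mp_size = 4 and an activation of
# 1024 elements, each rank stores a 256-element slice; for mp_rank = 2,
# get_partition_start returns 512 and get_partition_size returns 256,
# i.e. that rank keeps elements [512, 768).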

def gather_partitioned_activations(tensors, device=None):
    global mp_rank, mp_size, mp_group
    assert len(tensors) % 2 == 0, f'Expected even count of tensors, instead got {len(tensors)}'
    inputs = []
    num_args = int(len(tensors) / 2)
    for i in range(num_args):
        item = tensors[2 * i]
        size = tensors[2 * i + 1]

        if not is_activation_to_checkpoint(item):
            inputs.append(item)
            continue

        partition_size = item.numel()
        tensor_size = partition_size * mp_size
        if device is not None:
            flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device)
        else:
            flat_tensor = torch.zeros([tensor_size],
                                      dtype=item.dtype,
                                      device=item.device)

        partitions = []
        for rank in range(mp_size):
            part_i = flat_tensor.narrow(0, partition_size * rank, partition_size)
            if rank == mp_rank:
                part_i.copy_(item)
            partitions.append(part_i)

        if mp_group is not None:
            dist.all_gather(partitions, partitions[mp_rank], group=mp_group)

        input_tensor = flat_tensor.view(list(size.numpy()))
        item.data = input_tensor.data

        inputs.append(item)

    return tuple(inputs)

def extract_tensors(all_objects):
    """
    Separate objects in list/tuple into tensors and non-tensors and create a mapping to enable re-aggregation.
    The order of tensors and non-tensors is preserved in their respective output groups.

    Parameters:
        all_objects (list/tuple): Objects containing tensors and non-tensors to be split.

    Returns:
        tuple: Containing tensors, non-tensors, and bools of whether each position in original list/tuple was a tensor.
    """
    tensor_objects = [v for v in all_objects if torch.is_tensor(v)]
    non_tensor_objects = [v for v in all_objects if not torch.is_tensor(v)]
    tensor_flags = [torch.is_tensor(v) for v in all_objects]
    if type(all_objects) is tuple:
        return tuple(tensor_objects), tuple(non_tensor_objects), tuple(tensor_flags)
    return tensor_objects, non_tensor_objects, tensor_flags


def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags):
    """
    Merge two lists (or tuples) of tensors and non-tensors using a mapping of positions in merged list (or tuple).

    Parameters:
        tensor_objects (list/tuple): Tensors to merge.
        non_tensor_objects (list/tuple): Non-tensors to merge.
        tensor_flags (list/tuple): Indicates whether each position in output is a tensor.

    Returns:
        tuple: Merge of tensors and non-tensors
    """
    merged_objects = []
    tensor_idx = 0
    non_tensor_idx = 0

    real_tensor_flags = None

    # remove the flags that are assigned to the size of the flattened tensors
    if PARTITION_ACTIVATIONS:
        real_tensor_flags = []
        previous_flag = False
        for flag in tensor_flags:
            if previous_flag:
                previous_flag = False
                continue
            previous_flag = flag
            real_tensor_flags.append(flag)
    else:
        real_tensor_flags = tensor_flags

    for is_tensor in real_tensor_flags:
        if is_tensor:
            merged_objects.append(tensor_objects[tensor_idx])
            tensor_idx += 1
        else:
            merged_objects.append(non_tensor_objects[non_tensor_idx])
            non_tensor_idx += 1

    return tuple(merged_objects)
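
# Round-trip sketch (illustrative): extract_tensors splits a mixed argument
# tuple into its tensor and non-tensor parts plus a positional mask, and
# merge_tensors puts them back in the original order.
#
#     args = (torch.ones(2), 'labels', 3)
#     tensors, others, flags = extract_tensors(all_objects=args)
#     # tensors == (tensor,), others == ('labels', 3), flags == (True, False, False)
#     restored = merge_tensors(tensors, others, flags)  # == args (when PARTITION_ACTIVATIONS is False)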

def is_activation_to_checkpoint(item):
    """
    Is an activation to be checkpointed
    """
    global mp_size
    return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size

def partition_activations(args, cpu_checkpoint, contiguous_checkpoint):
    global contiguous_data_buffers, data_offsets

    inputs = []
    num_non_fp_tensors = 0

    for arg_index, item in enumerate(args):
        if not is_activation_to_checkpoint(item):
            inputs.append(item)
            num_non_fp_tensors += 1
            continue

        i = arg_index - num_non_fp_tensors
        partition_size = get_partition_size(item)
        partition = item.detach().contiguous().view(-1).narrow(
            0,
            get_partition_start(item),
            partition_size).clone()

        buffer_device = torch.device('cpu') if cpu_checkpoint else partition.device

        if contiguous_checkpoint:
            if i >= len(contiguous_data_buffers):
                tensor_list = [
                    torch.tensor(()).new_empty([partition_size],
                                               dtype=partition.dtype,
                                               device=buffer_device)
                    for _ in range(num_layers)
                ]
                contiguous_data_buffers.append(tensor_list)
                data_offsets.append(0)
            elif contiguous_data_buffers[i] is None:
                tensor_list = [
                    torch.tensor(()).new_empty([partition_size],
                                               dtype=partition.dtype,
                                               device=buffer_device)
                    for _ in range(num_layers)
                ]
                contiguous_data_buffers[i] = tensor_list
                data_offsets[i] = 0

            # Because 'new_empty' returns uninitialized pages,
            # the pages need to be populated during the cudaMemcpy time
            # which increases the data copy time. To avoid this, we
            # pre-populate these pages by simply writing 0 ahead of
            # the actual cudaMemcpy operation time. Due to the
            # previously launched GPU kernels, there is a small
            # window of time here for CPUs to populate pages asynchronously.
            contiguous_data_buffers[i][data_offsets[i]].data[range(
                0,
                contiguous_data_buffers[i][data_offsets[i]].data.shape[0],
                int(mmap.PAGESIZE /
                    contiguous_data_buffers[i][data_offsets[i]].data.element_size())
            )] = 0

            contiguous_partition = contiguous_data_buffers[i][
                data_offsets[i]].data.copy_(partition.data)
            data_offsets[i] = data_offsets[i] + 1
            inputs.append(contiguous_partition)
        else:
            partition = partition.cpu() if CPU_CHECKPOINT else partition
            inputs.append(partition)

    return inputs

def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint):
    global contiguous_size_buffers, size_offsets

    new_args = []
    num_non_fp_tensors = 0

    for arg_index, (arg, inp) in enumerate(zip(args, inputs)):
        size = torch.tensor(arg.size()) if torch.is_tensor(arg) else None
        if not is_activation_to_checkpoint(arg):
            new_args.append(arg)
            new_args.append(size)
            num_non_fp_tensors += 1
            continue

        arg.data = inp.data
        new_args.append(arg)
        i = arg_index - num_non_fp_tensors

        if contiguous_checkpoint:
            numel = size.numel()
            if i >= len(contiguous_size_buffers):
                tmp = torch.tensor(())
                contiguous_size_buffers.append(
                    tmp.new_empty([numel * num_layers],
                                  dtype=size.dtype,
                                  device=size.device))
                size_offsets.append(0)
            elif contiguous_size_buffers[i] is None:
                tmp = torch.tensor(())
                contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers],
                                                           dtype=size.dtype,
                                                           device=size.device)
                size_offsets[i] = 0

            contiguous_size = contiguous_size_buffers[i].narrow(
                0,
                size_offsets[i],
                numel).data.copy_(size.data)
            contiguous_size = contiguous_size.view_as(size)
            size_offsets[i] = size_offsets[i] + numel
            new_args.append(contiguous_size)
        else:
            new_args.append(size)

    return new_args

def get_cpu_activations_for_backward(args, inputs):
    new_args = []
    for i, (arg, inp) in enumerate(zip(args, inputs)):
        if not is_activation_to_checkpoint(arg):
            new_args.append(arg)
            continue

        arg.data = inp.data
        new_args.append(arg)

    return new_args

class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
    the following main changes:
        1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
        2) the states in the model parallel tracker are also properly
           tracked/set/reset
        3) performance: activation partitioning and contiguous memory optimization
        4) CPU checkpointing
        5) profiling of the forward and backward functions
    """
    @staticmethod
    def forward(ctx, run_function, all_outputs, *args):
        global mpu, timers, SYNCHRONIZE, PROFILE_TIME

        def save_args_for_backward(*all_args):
            tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
            ctx.deepspeed_saved_tensors = tensor_args
            ctx.non_tensor_args = non_tensor_args
            ctx.tensor_flags = tensor_flags

        if SYNCHRONIZE:
            torch.cuda.synchronize()

        if timers is None and PROFILE_TIME:
            timers = Timers()

        if PROFILE_TIME:
            timers('forward').start()

        ctx.run_function = run_function
        global num_layers
        global mp_rank, mp_size, mp_group
        global contiguous_data_buffers, contiguous_size_buffers
        global data_offsets, size_offsets
        if mp_rank is None:
            if mpu is not None:
                if hasattr(mpu, 'get_tensor_model_parallel_rank'):
                    mp_rank = mpu.get_tensor_model_parallel_rank()
                    mp_size = mpu.get_tensor_model_parallel_world_size()
                    mp_group = mpu.get_tensor_model_parallel_group()
                else:
                    mp_rank = mpu.get_model_parallel_rank()
                    mp_size = mpu.get_model_parallel_world_size()
                    mp_group = mpu.get_model_parallel_group()
            else:
                mp_rank = 0
                mp_size = 1
                mp_group = None

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

        if cuda_device is None:
            see_memory_usage("First Forward Beginning", force=False)
            if dist.get_rank() == 0:
                logger.info(f"Activation Checkpointing Information")
                logger.info(
                    f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}"
                )
                logger.info(
                    f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
                )
                logger.info(f"----Synchronization {SYNCHRONIZE}")
                logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")

            cuda_device = torch.cuda.current_device()
            transport_stream = torch.cuda.Stream(device=cuda_device)

        if PARTITION_ACTIVATIONS:
            inputs = partition_activations(args,
                                           CPU_CHECKPOINT,
                                           CONTIGUOUS_CHECKPOINTING)
        elif CPU_CHECKPOINT:
            inputs = copy_to_device(args,
                                    device=torch.device('cpu'),
                                    criterion_func=is_activation_to_checkpoint)

        # just in case something funky is happening such as reuse of inputs
        inputs_cuda = copy_to_device(args,
                                     device=cuda_device,
                                     criterion_func=is_activation_to_checkpoint)

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        see_memory_usage("Before running forward on the layer", force=False)
        # ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*inputs_cuda)

        see_memory_usage("After running forward on the layer", force=False)
        del inputs_cuda

        if PARTITION_ACTIVATIONS:
            new_args = get_partitioned_activations_for_backward(
                args,
                inputs,
                CONTIGUOUS_CHECKPOINTING)
            assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
            save_args_for_backward(*new_args)
        elif CPU_CHECKPOINT:
            new_args = get_cpu_activations_for_backward(args, inputs)
            save_args_for_backward(*new_args)
        else:
            save_args_for_backward(*args)

        if PROFILE_TIME:
            timers('forward').stop()
            timers.log(['forward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()

        # Tensors returned from forward() may not be differentiable.
        if torch.is_tensor(outputs):
            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
        else:
            non_grad_outputs = [
                o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()
            ]
        ctx.mark_non_differentiable(*non_grad_outputs)

        if torch.is_tensor(outputs):
            all_outputs += [outputs]
            return outputs
        else:
            all_outputs += outputs
            outputs, _, _ = extract_tensors(all_objects=outputs)
            return tuple(outputs)

    @staticmethod
    def backward(ctx, *grads):
        global timers
        see_memory_usage("In backward", force=False)
        # removing pointers to the contiguous buffer memory
        # so that they can be garbage collected once the checkpoints
        # have been used
        if SYNCHRONIZE:
            torch.cuda.synchronize()
        if PROFILE_TIME:
            timers('backward').start()

        if CONTIGUOUS_CHECKPOINTING:
            global data_offsets, size_offsets
            global contiguous_data_buffers, contiguous_size_buffers

            for buffers in contiguous_data_buffers:
                buffers = []

            # frees up all the pointers to the checkpoints except for the ones
            # stored by save for backward
            contiguous_data_buffers = []
            contiguous_size_buffers = []
            data_offsets = []
            size_offsets = []

        see_memory_usage("In backward checkpointing code", force=False)
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS

        if PARTITION_ACTIVATIONS:
            # with torch.cuda.stream(transport_stream):
            inputs = gather_partitioned_activations(
                ctx.deepspeed_saved_tensors,
                device=cuda_device if CPU_CHECKPOINT else None)
            detached_inputs = detach_variable(inputs)
        elif CPU_CHECKPOINT:
            inputs = move_to_device(ctx.deepspeed_saved_tensors,
                                    cuda_device,
                                    is_activation_to_checkpoint)
            detached_inputs = detach_variable(inputs)
        else:
            inputs = ctx.deepspeed_saved_tensors
            detached_inputs = detach_variable(inputs)

        # Add non tensor input args
        detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                        non_tensor_objects=ctx.non_tensor_args,
                                        tensor_flags=ctx.tensor_flags)

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what they were before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # if PARTITION_ACTIVATIONS:
        #     current_stream = torch.cuda.current_stream()
        #     current_stream.wait_stream(transport_stream)

        see_memory_usage("In backward checkpointing code before forward", force=False)
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

        see_memory_usage("In backward checkpointing code after forward", force=False)
        # Set the states back to what they were at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )

        # Filter out non tensor outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)

        # Construct arguments to autograd.backward().
        # This is usually just outputs and grads, but forward() can return tensors that
        # are not differentiable.
        output_tensors = []
        grad_tensors = []
        for out, grad in zip(outputs, grads):
            if out.requires_grad:
                output_tensors.append(out)
                grad_tensors.append(grad)

        see_memory_usage("In backward checkpointing code before backward", force=False)

        torch.autograd.backward(output_tensors, grad_tensors)

        see_memory_usage("After backward checkpointing code after backward", force=False)

        if PROFILE_TIME:
            timers('backward').stop()
            timers.log(['backward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()
        ret_list = [None, None]  # first None for ctx
        for inp in detached_inputs:
            if torch.is_tensor(inp):
                ret_list.append(inp.grad)
            else:
                ret_list.append(None)

        return tuple(ret_list)

def checkpoint(function, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint."""

    all_outputs = []
    CheckpointFunction.apply(function, all_outputs, *args)
    if len(all_outputs) == 1:
        return all_outputs[0]
    else:
        return tuple(all_outputs)
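
# Usage sketch (illustrative; `layers` and `hidden_states` stand in for the
# caller's module list and activation tensor): wrapping each layer with
# checkpoint() recomputes its activations during backward instead of storing them.
#
#     import deepspeed
#     for layer in layers:
#         hidden_states = deepspeed.checkpointing.checkpoint(layer, hidden_states)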

def partition_activations_in_checkpoint(partition_activation):
    global PARTITION_ACTIVATIONS
    PARTITION_ACTIVATIONS = partition_activation
    if dist.get_rank() == 0:
        logger.info(
            f"**************Partition Activations {PARTITION_ACTIVATIONS}************")


def set_num_layers(nlayers):
    global num_layers
    num_layers = nlayers


def reset():
    """Resets memory buffers related to contiguous memory optimizations.
    Should be called during eval when multiple forward propagations are
    computed without any backward propagation that usually clears these
    buffers.

    Arguments:
        None

    Return:
        None
    """
    if CONTIGUOUS_CHECKPOINTING:
        global data_offsets, size_offsets
        global contiguous_data_buffers, contiguous_size_buffers

        for buffers in contiguous_data_buffers:
            buffers = []

        # frees up all the pointers to the checkpoints except for the ones
        # stored by save for backward
        contiguous_data_buffers = []
        contiguous_size_buffers = []
        data_offsets = []
        size_offsets = []

def _configure_using_config_file(config, mpu=None):
    global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
        CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME

    config = DeepSpeedConfig(config, mpu=mpu).activation_checkpointing_config
    if dist.get_rank() == 0:
        logger.info(config.repr())
    PARTITION_ACTIVATIONS = config.partition_activations
    CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
    num_layers = config.number_checkpoints
    CPU_CHECKPOINT = config.cpu_checkpointing
    SYNCHRONIZE = config.synchronize_checkpoint_boundary
    PROFILE_TIME = config.profile


def _configure_defaults():
    global mpu, num_layers, deepspeed_checkpointing_enabled

    global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
        CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME

    PARTITION_ACTIVATIONS = False
    CONTIGUOUS_CHECKPOINTING = False
    # num_layers must be provided (via config or configure) before contiguous
    # checkpointing can be used; leave it unset by default.
    num_layers = None
    CPU_CHECKPOINT = False
    SYNCHRONIZE = False
    PROFILE_TIME = False
    deepspeed_checkpointing_enabled = True

def configure(
    mpu_,
    deepspeed_config=None,
    partition_activations=None,
    contiguous_checkpointing=None,
    num_checkpoints=None,
    checkpoint_in_cpu=None,
    synchronize=None,
    profile=None,
):
    """Configure DeepSpeed Activation Checkpointing.

    Arguments:
        mpu_: Optional: An object that implements the following methods
            get_model_parallel_rank/group/world_size, and get_data_parallel_rank/group/world_size

        deepspeed_config: Optional: DeepSpeed Config json file; when provided it will be used to
            configure DeepSpeed Activation Checkpointing

        partition_activations: Optional: Partitions activation checkpoints across model parallel
            GPUs when enabled. By default False. Will overwrite deepspeed_config if provided

        contiguous_checkpointing: Optional: Copies activation checkpoints to a contiguous memory
            buffer. Works only with homogeneous checkpoints when partition_activations is enabled.
            Must provide num_checkpoints. By default False. Will overwrite deepspeed_config if
            provided

        num_checkpoints: Optional: Number of activation checkpoints stored during the forward
            propagation of the model. Used to calculate the buffer size for contiguous_checkpointing.
            Will overwrite deepspeed_config if provided

        checkpoint_in_cpu: Optional: Moves the activation checkpoint to CPU. Only works with
            partition_activations. Default is False. Will overwrite deepspeed_config if provided

        synchronize: Optional: Performs torch.cuda.synchronize() at the beginning and end of
            each call to deepspeed.checkpointing.checkpoint for both forward and backward pass.
            By default False. Will overwrite deepspeed_config if provided

        profile: Optional: Logs the forward and backward time for each
            deepspeed.checkpointing.checkpoint invocation. Will overwrite deepspeed_config
            if provided

    Returns:
        None
    """
    global mpu, num_layers, deepspeed_checkpointing_enabled

    global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
        CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME

    _configure_defaults()

    if mpu_ is not None:
        mpu = mpu_

    if deepspeed_config is not None:
        _configure_using_config_file(deepspeed_config, mpu=mpu)

    if partition_activations is not None:
        PARTITION_ACTIVATIONS = partition_activations

    if contiguous_checkpointing is not None:
        CONTIGUOUS_CHECKPOINTING = contiguous_checkpointing

    if num_checkpoints is not None:
        num_layers = num_checkpoints

    if checkpoint_in_cpu is not None:
        CPU_CHECKPOINT = checkpoint_in_cpu

    if synchronize is not None:
        SYNCHRONIZE = synchronize

    if profile is not None:
        PROFILE_TIME = profile

    if CONTIGUOUS_CHECKPOINTING:
        assert PARTITION_ACTIVATIONS, "Contiguous Checkpointing is only available with partitioned activations. Set partitioned activations to true in deepspeed config"
    if CONTIGUOUS_CHECKPOINTING:
        assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing"

def is_configured():
    """True if deepspeed activation checkpointing has been configured
    by calling deepspeed.checkpointing.configure, else returns False.

    Arguments:
        None

    Return:
        True if configured, else False
    """
    return deepspeed_checkpointing_enabled
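
# Configuration sketch (illustrative; `mpu` stands in for the caller's model
# parallel unit object, and the checkpoint count is arbitrary): configure() is
# typically called once during setup before checkpoint() is used, and its
# keyword arguments override the corresponding deepspeed_config settings.
#
#     import deepspeed
#     deepspeed.checkpointing.configure(mpu,
#                                       partition_activations=True,
#                                       contiguous_checkpointing=True,
#                                       num_checkpoints=24,
#                                       checkpoint_in_cpu=True)
#     assert deepspeed.checkpointing.is_configured()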