bf16_optimizer.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from collections import OrderedDict
  5. import torch
  6. import sys
  7. from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
  8. from deepspeed import comm as dist
  9. from deepspeed.runtime.constants import PIPE_REPLICATED
  10. from deepspeed.runtime.base_optimizer import ZeROOptimizer
  11. from packaging import version as pkg_version
  12. from deepspeed.git_version_info import version
  13. from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
  14. align_dense_tensors, all_gather_dp_groups, is_model_parallel_parameter,
  15. see_memory_usage, graph_process, get_norm_with_moe_layers)
  16. from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups
  17. from deepspeed.moe.utils import is_moe_param, is_moe_param_group
  18. from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
  19. from deepspeed.checkpoint import enable_universal_checkpoint
  20. from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
  21. SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
  22. PARAM_SLICE_MAPPINGS)
# Explicitly re-bind the imported `fragment_address` onto this module object.
# NOTE(review): presumably kept so that checkpoints which pickled
# `fragment_address` objects against this module path still resolve — confirm
# before removing; the plain import above already binds the same name.
setattr(sys.modules[__name__], 'fragment_address', fragment_address)
  24. def print_rank_0(message, debug=False, force=False):
  25. if dist.get_rank() == 0 and (debug or force):
  26. print(message)
class BF16_Optimizer(ZeROOptimizer):
    """ZeRO-style wrapper for bf16 training: keeps the model's bf16 params and a
    dp-partitioned fp32 master copy of weights and gradients; the wrapped
    optimizer steps on the fp32 partition and results are copied back to bf16.
    """

    def __init__(self,
                 init_optimizer,
                 param_names,
                 mpu=None,
                 clip_grad=0.0,
                 norm_type=2,
                 allgather_bucket_size=5000000000,
                 dp_process_group=None,
                 timers=None,
                 grad_acc_dtype=None,
                 graph_harvesting=False,
                 immediate_grad_update=False,
                 has_moe_layers=False):
        """
        Args:
            init_optimizer: wrapped optimizer (or DummyOptim when the client steps itself).
            param_names: dict mapping lp param -> name (used for slice mappings).
            mpu: model-parallel unit, forwarded to norm/clip helpers.
            clip_grad: max global grad norm; 0 disables clipping.
            norm_type: p-norm degree used for the global grad norm.
            allgather_bucket_size: bucket size for the bf16 weight all-gather.
            dp_process_group: default data-parallel group for all param groups.
            timers: optional timer object (stored, not used in this file).
            grad_acc_dtype: torch.float32 or torch.bfloat16 accumulation dtype.
            graph_harvesting: use CUDA-graph-friendly code paths.
            immediate_grad_update: push lp grads to fp32 via autograd hooks.
            has_moe_layers: enable expert data-parallel group handling.
        """
        super().__init__()
        see_memory_usage('begin bf16_optimizer', force=True)
        self.timers = timers
        self.optimizer = init_optimizer
        self.param_names = param_names
        # DummyOptim means the client application performs the update itself;
        # in that case we skip building the fp32 state entirely.
        self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)

        assert grad_acc_dtype in [torch.float32, torch.bfloat16
                                  ], f"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}"
        self.grad_acc_dtype = grad_acc_dtype
        self.immediate_grad_update = immediate_grad_update

        self.clip_grad = clip_grad
        self.norm_type = norm_type
        self.mpu = mpu
        self.allgather_bucket_size = int(allgather_bucket_size)
        self.dp_process_group = dp_process_group
        self.dp_rank = dist.get_rank(group=self.dp_process_group)
        self.has_moe_layers = has_moe_layers
        self.non_expert_gradients = []
        # One dp group per param group; MoE groups are swapped to their expert
        # data-parallel group in _configure_moe_settings().
        self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]
        if self.has_moe_layers:
            self._configure_moe_settings()

        # Use torch (un)flatten ops
        self.flatten = _flatten_dense_tensors
        self.unflatten = _unflatten_dense_tensors

        # align nccl all-gather send buffers to 4-byte boundary
        self.nccl_start_alignment_factor = 2  # 4-byte alignment/sizeof(fp16) = 2

        # Build BF16/FP32 groups
        self.bf16_groups = []
        self.bf16_groups_flat = []
        self.bf16_partitioned_groups = []

        self.fp32_groups_flat_partition = []

        # Maintain different fp32 gradients views for convenience
        self.fp32_groups_gradients = []
        self.fp32_groups_gradient_dict = {}
        self.fp32_groups_gradients_flat = []
        self.fp32_groups_actual_gradients_flat = []
        self.fp32_groups_gradient_flat_partition = []
        self.fp32_groups_has_gradients = []

        self.group_paddings = []
        self.graph_harvesting = graph_harvesting

        if self.using_real_optimizer:
            self._setup_for_real_optimizer()

        see_memory_usage('end bf16_ optimizer', force=True)
  84. def destroy(self):
  85. for i, _ in enumerate(self.optimizer.param_groups):
  86. for p in self.bf16_groups[i]:
  87. if getattr(p, '_hp_mapping', None):
  88. p._hp_mapping = None
  89. for hook in self._grad_acc_hooks:
  90. hook.remove()
  91. print_rank_0("Removed grad acc hooks")
  92. def _configure_moe_settings(self):
  93. assert any(
  94. [is_moe_param_group(group) for group in self.optimizer.param_groups]
  95. ), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"
  96. for i, group in enumerate(self.optimizer.param_groups):
  97. if is_moe_param_group(group):
  98. assert all([is_moe_param(param)
  99. for param in group['params']]), "All params in MoE group must be MoE params"
  100. self.real_dp_process_group[i] = groups._get_expert_data_parallel_group(group['name'])
  101. self.expert_gradients = {}
  102. if self.has_moe_layers:
  103. for key in groups._get_expert_data_parallel_group_dict().keys():
  104. self.expert_gradients[key] = []
    def _setup_for_real_optimizer(self):
        """Build flat bf16 groups, dp-partitioned fp32 master weights/gradients,
        repoint the wrapped optimizer at the fp32 partitions, and initialize its state.

        Invariant built here: for group i, `bf16_groups_flat[i]` is padded so it
        divides evenly into `real_dp_world_size` partitions, and this rank owns
        partition `partition_id` of both weights and gradients.
        """
        self.partition_count = [dist.get_world_size(group=pg) for pg in self.real_dp_process_group]

        for i, param_group in enumerate(self.optimizer.param_groups):
            real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])
            see_memory_usage(f'before initializing group {i}', force=True)

            partition_id = dist.get_rank(group=self.real_dp_process_group[i])

            # grab the original list (frozen params are excluded from flattening)
            trainable_parameters = [param for param in param_group['params'] if param.requires_grad]
            self.bf16_groups.append(trainable_parameters)

            # create flat bf16 params, aligned so every dp partition starts on an
            # nccl-friendly boundary
            self.bf16_groups_flat.append(
                self._flatten_dense_tensors_aligned(self.bf16_groups[i],
                                                    self.nccl_start_alignment_factor * real_dp_world_size))

            # Make bf16 params point to flat tensor storage
            self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i],
                                                     flat_tensor=self.bf16_groups_flat[i])

            # divide flat weights into equal sized partitions (one view per dp rank)
            partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
            bf16_dp_partitions = [
                self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)
                for dp_index in range(real_dp_world_size)
            ]
            self.bf16_partitioned_groups.append(bf16_dp_partitions)

            # create fp32 params partition: this rank's master copy of its own slice
            self.fp32_groups_flat_partition.append(bf16_dp_partitions[partition_id].clone().float().detach())
            self.fp32_groups_flat_partition[i].requires_grad = True

            num_elem_list = [t.numel() for t in self.bf16_groups[i]]

            # create fp32 gradients (full-size buffer in the configured accumulation dtype)
            fp32_flat_buffer = torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype)
            self.fp32_groups_gradients_flat.append(fp32_flat_buffer)
            if self.has_moe_layers and is_moe_param_group(param_group):
                self.expert_gradients[param_group['name']].append(fp32_flat_buffer)
            else:
                self.non_expert_gradients.append(fp32_flat_buffer)

            # track individual fp32 gradients for entire model (views per lp param)
            fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i],
                                                     num_elem_list=num_elem_list)
            self.fp32_groups_gradients.append(fp32_gradients)
            self.fp32_groups_gradient_dict[i] = fp32_gradients

            # flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding)
            length_without_padding = sum(num_elem_list)
            self.fp32_groups_actual_gradients_flat.append(
                torch.narrow(self.fp32_groups_gradients_flat[i], 0, 0, length_without_padding))

            # flat tensor corresponding to gradient partition owned by this rank
            self.fp32_groups_gradient_flat_partition.append(
                torch.narrow(self.fp32_groups_gradients_flat[i], 0, partition_id * partition_size, partition_size))

            # track fp32 gradient updates (per lp param "has a grad arrived" flags)
            self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i]))

            # Record padding required for alignment; only the last partition can
            # contain padding, so only the last rank records a non-zero value
            if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
                padding = self.bf16_groups_flat[i].numel() - length_without_padding
            else:
                padding = 0
            self.group_paddings.append(padding)

            # update optimizer param groups to reference fp32 params partition
            param_group['params'] = [self.fp32_groups_flat_partition[i]]

            see_memory_usage(f'after initializing group {i}', force=True)

        see_memory_usage('before initialize_optimizer', force=True)
        self.initialize_optimizer_states()
        see_memory_usage('end initialize_optimizer', force=True)

        self._grad_acc_hooks = []
        if self.immediate_grad_update:
            self.create_grad_acc_hooks()

        # Need optimizer states initialized before linking lp to optimizer state
        self._link_all_hp_params()
        self._hp_optimizer_states_linked = False
        self._enable_universal_checkpoint()
        self._param_slice_mappings = self._create_param_mapping()
  173. def _enable_universal_checkpoint(self):
  174. for lp_param_group in self.bf16_groups:
  175. enable_universal_checkpoint(param_list=lp_param_group)
  176. def _create_param_mapping(self):
  177. param_mapping = []
  178. for i, _ in enumerate(self.optimizer.param_groups):
  179. param_mapping_per_group = OrderedDict()
  180. for lp in self.bf16_groups[i]:
  181. if lp._hp_mapping is not None:
  182. lp_name = self.param_names[lp]
  183. param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address()
  184. param_mapping.append(param_mapping_per_group)
  185. return param_mapping
  186. def _link_all_hp_params(self):
  187. for i, _ in enumerate(self.optimizer.param_groups):
  188. real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])
  189. # Link bf16 and fp32 params in partition
  190. partition_id = dist.get_rank(group=self.real_dp_process_group[i])
  191. partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
  192. flat_hp_partition = self.fp32_groups_flat_partition[i]
  193. link_hp_params(lp_param_list=self.bf16_groups[i],
  194. flat_hp_partition=flat_hp_partition,
  195. gradient_dict=self.fp32_groups_gradient_dict,
  196. offload_gradient_dict=None,
  197. use_offload=False,
  198. param_group_index=i,
  199. partition_start=partition_id * partition_size,
  200. partition_size=partition_size,
  201. dp_group=self.real_dp_process_group[i])
  202. def _lazy_init_hp_params_optimizer_state(self):
  203. if not self._hp_optimizer_states_linked:
  204. for i, _ in enumerate(self.optimizer.param_groups):
  205. lazy_init_hp_params_optimizer_state(self.bf16_groups[i], self.fp32_groups_flat_partition[i],
  206. self.optimizer.state)
  207. self._hp_optimizer_states_linked = True
  208. def initialize_optimizer_states(self):
  209. """Take an optimizer step with zero-valued gradients to allocate internal
  210. optimizer state.
  211. This helps prevent memory fragmentation by allocating optimizer state at the
  212. beginning of training instead of after activations have been allocated.
  213. """
  214. for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
  215. self.fp32_groups_gradient_flat_partition):
  216. # In case of grad acc dtype different than FP32, need to cast to high precision.
  217. param_partition.grad = grad_partition.to(
  218. param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition
  219. if self.grad_acc_dtype is not torch.float32:
  220. for param_partition in self.fp32_groups_flat_partition:
  221. param_partition.grad = None
  222. self.clear_hp_grads()
  223. def _split_flat_tensor(self, flat_tensor, num_elem_list):
  224. assert sum(num_elem_list) <= flat_tensor.numel()
  225. tensor_list = []
  226. offset = 0
  227. for num_elem in num_elem_list:
  228. dense_tensor = torch.narrow(flat_tensor, 0, offset, num_elem)
  229. tensor_list.append(dense_tensor)
  230. offset += num_elem
  231. return tensor_list
  232. def _update_storage_to_flattened_tensor(self, tensor_list, flat_tensor):
  233. updated_params = self.unflatten(flat_tensor, tensor_list)
  234. for p, q in zip(tensor_list, updated_params):
  235. p.data = q.data
  236. def _flatten_dense_tensors_aligned(self, tensor_list, alignment):
  237. return self.flatten(align_dense_tensors(tensor_list, alignment))
    @torch.no_grad()
    def step(self, closure=None):
        """Compute the global grad norm, optionally clip, step the wrapped optimizer
        on the fp32 partitions, then propagate updated weights back to bf16.

        Raises:
            NotImplementedError: if a closure is supplied.
        """
        if closure is not None:
            raise NotImplementedError(f'{self.__class__} does not support closure.')

        non_expert_grads_for_norm, expert_grads_for_norm = self.get_grads_for_norm()
        non_expert_groups_norm = get_global_norm_of_tensors(input_tensors=non_expert_grads_for_norm,
                                                            mpu=self.mpu,
                                                            norm_type=self.norm_type,
                                                            use_graph=self.graph_harvesting)
        all_groups_norm = non_expert_groups_norm
        if self.has_moe_layers:
            # Fold the per-expert-group norms into the global norm.
            all_groups_norm = get_norm_with_moe_layers(non_expert_groups_norm,
                                                       mpu=self.mpu,
                                                       expert_tensors=expert_grads_for_norm,
                                                       norm_type=self.norm_type)

        self._global_grad_norm = all_groups_norm

        assert all_groups_norm > 0.
        if self.clip_grad > 0.:
            # Clipping operates on ALL grads (for_clipping=True bypasses the
            # tp-rank-0 de-duplication used for norm computation).
            clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True),
                                        max_norm=self.clip_grad,
                                        global_norm=all_groups_norm,
                                        mpu=self.mpu,
                                        use_graph=self.graph_harvesting)

        self.optimizer.step()

        # We need to link optimizer state after the first step() call
        self._lazy_init_hp_params_optimizer_state()

        self.update_lp_params()

        self.clear_hp_grads()
    def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
        """Perform a backward pass and copy the low-precision gradients to the
        high-precision copy.

        We copy/accumulate to the high-precision grads now to prevent accumulating in the
        bf16 grads after successive backward() calls (i.e., grad accumulation steps > 1).

        The low-precision grads are deallocated during this procedure.

        Args:
            loss: scalar tensor to backprop from.
            update_hp_grads: fold lp grads into the fp32 buffers after backward.
            clear_lp_grads: also clear the lp grads after folding them in.
            **bwd_kwargs: forwarded to `loss.backward()`.
        """
        self.clear_lp_grads()
        loss.backward(**bwd_kwargs)

        if update_hp_grads:
            self.update_hp_grads(clear_lp_grads=clear_lp_grads)
  277. @torch.no_grad()
  278. def _update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads):
  279. if lp.grad is None:
  280. return
  281. hp_grad = self.fp32_groups_gradients[group_idx][param_idx]
  282. assert hp_grad is not None, \
  283. f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{group_idx}][{param_idx}]'
  284. hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))
  285. lp._hp_grad = hp_grad
  286. self.fp32_groups_has_gradients[group_idx][param_idx] = True
  287. # clear gradients
  288. if clear_lp_grads:
  289. lp.grad.zero_()
  290. @torch.no_grad()
  291. def _update_hp_grads_func(self, clear_lp_grads=False):
  292. for i, group in enumerate(self.bf16_groups):
  293. for j, lp in enumerate(group):
  294. self._update_hp_grad(lp, i, j, clear_lp_grads)
  295. @torch.no_grad()
  296. def update_hp_grads(self, clear_lp_grads=False):
  297. if self.immediate_grad_update:
  298. return
  299. if self.graph_harvesting:
  300. graph_process(False, self._update_hp_grads_func, clear_lp_grads)
  301. else:
  302. self._update_hp_grads_func(clear_lp_grads)
  303. #cpu op
  304. for i, group in enumerate(self.bf16_groups):
  305. for j, lp in enumerate(group):
  306. if lp.grad is None:
  307. continue
  308. self.fp32_groups_has_gradients[i][j] = True
  309. @torch.no_grad()
  310. def get_grads_for_reduction(self):
  311. if self.has_moe_layers:
  312. return self.non_expert_gradients, self.expert_gradients
  313. return self.non_expert_gradients, {}
    @torch.no_grad()
    def get_grads_for_norm(self, for_clipping=False):
        """Collect the fp32 gradient views that participate in the global norm.

        Returns:
            tuple[list[Tensor], dict[ep_name, list[Tensor]]] | list:
            If for_clipping, return all gradients (no de-duplication filtering).
            Otherwise, separate and return dict of expert grads and list of non-expert grads.
        """
        # {expert_group_name: grads}
        expert_grads_for_norm = {}

        # grads
        non_expert_grads_for_norm = []
        all_grads_for_clip = []

        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        assert len(self.bf16_groups) == len(self.optimizer.param_groups)
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                if not for_clipping:
                    # pipeline-replicated params would be double-counted in the norm
                    if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated:
                        continue

                    # skip duplicated parameters. perform norm only on cards with tp_rank=0.
                    # non-duplicated parameters include:
                    # - Parameters with tp: Use allreducesum of mp_group.
                    # - Moe Parameters with ep: Use allreducesum of ep_group.
                    if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp) or is_moe_param(lp)):
                        continue

                # params that never received a grad contribute nothing
                if not self.fp32_groups_has_gradients[i][j]:
                    continue
                if not for_clipping:
                    param_group = self.optimizer.param_groups[i]
                    if self.has_moe_layers and is_moe_param_group(param_group):
                        if param_group['name'] not in expert_grads_for_norm:
                            expert_grads_for_norm[param_group['name']] = []
                        expert_grads_for_norm[param_group['name']].append(self.fp32_groups_gradients[i][j])
                    else:
                        non_expert_grads_for_norm.append(self.fp32_groups_gradients[i][j])
                else:
                    all_grads_for_clip.append(self.fp32_groups_gradients[i][j])

        if not for_clipping:
            return non_expert_grads_for_norm, expert_grads_for_norm
        return all_grads_for_clip
  355. @torch.no_grad()
  356. def update_lp_params(self):
  357. for i, (bf16_partitions,
  358. fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
  359. partition_id = dist.get_rank(group=self.real_dp_process_group[i])
  360. bf16_partitions[partition_id].data.copy_(fp32_partition.data)
  361. # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
  362. # if i == 0:
  363. # print_rank_0(f'{fp32_partition[:10]=}', force=True)
  364. all_gather_dp_groups(groups_flat=self.bf16_groups_flat,
  365. partitioned_param_groups=self.bf16_partitioned_groups,
  366. dp_process_group=self.real_dp_process_group,
  367. start_alignment_factor=self.nccl_start_alignment_factor,
  368. allgather_bucket_size=self.allgather_bucket_size)
  369. def clear_hp_grads(self):
  370. for flat_gradients in self.fp32_groups_gradients_flat:
  371. flat_gradients.zero_()
  372. for i, group in enumerate(self.fp32_groups_gradients):
  373. self.fp32_groups_has_gradients[i] = [False] * len(group)
  374. def clear_lp_grads(self):
  375. # using zero_() fixed memory address for graph replay
  376. set_to_none = False if self.graph_harvesting else True
  377. zero_grads_list = []
  378. for group in self.bf16_groups:
  379. for param in group:
  380. if set_to_none:
  381. param.grad = None
  382. elif param.grad is not None:
  383. if param.grad.grad_fn is not None:
  384. param.grad.detach_()
  385. zero_grads_list.append(param.grad)
  386. if not set_to_none and len(zero_grads_list) > 0:
  387. torch._foreach_zero_(zero_grads_list)
  388. def state_dict(self):
  389. state_dict = {}
  390. state_dict[CLIP_GRAD] = self.clip_grad
  391. state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict()
  392. state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = self.fp32_groups_flat_partition
  393. state_dict[GROUP_PADDINGS] = self.group_paddings
  394. state_dict[PARTITION_COUNT] = self.partition_count
  395. state_dict[DS_VERSION] = version
  396. state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings
  397. return state_dict
  398. # Restore base optimizer fp32 weights bfloat16 weights
  399. def _restore_from_bit16_weights(self):
  400. for i, group in enumerate(self.bf16_groups):
  401. partition_id = dist.get_rank(group=self.real_dp_process_group[i])
  402. for bf16_partitions, fp32_partition in zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition):
  403. fp32_partition.data.copy_(bf16_partitions[partition_id].data)
  404. def refresh_fp32_params(self):
  405. self._restore_from_bit16_weights()
    def load_state_dict(self,
                        state_dict_list,
                        checkpoint_folder,
                        load_optimizer_states=True,
                        load_from_fp32_weights=False,
                        load_serial=None,
                        param_shapes=None):
        """Load optimizer state, dispatching on checkpoint format.

        Args:
            state_dict_list: per-dp-rank state dicts (legacy format).
            checkpoint_folder: if truthy, load a universal checkpoint from this
                directory instead of the legacy per-rank dicts.
            load_optimizer_states: restore the wrapped optimizer's state.
            load_from_fp32_weights: restore fp32 master weights from the checkpoint.
            load_serial: unused here; kept for interface compatibility.
            param_shapes: unused here; kept for interface compatibility.
        """
        if checkpoint_folder:
            self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
        else:
            self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)
    def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False):
        """Restore state from this dp rank's entry in a list of per-rank state dicts."""
        dp_rank = dist.get_rank(group=self.dp_process_group)
        current_rank_sd = state_dict_list[dp_rank]

        ckpt_version = current_rank_sd.get(DS_VERSION, False)
        assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed"
        # parsed but currently unused; presumably kept for future version-gated
        # migrations — confirm before removing
        ckpt_version = pkg_version.parse(ckpt_version)

        self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)

        if load_optimizer_states:
            print(f"_load_legacy_checkpoint current_rank_sd[BASE_OPTIMIZER_STATE]")
            self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])

        if load_from_fp32_weights:
            for current, saved in zip(self.fp32_groups_flat_partition,
                                      current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):
                # pad the saved partition up to the current partition size
                src_tensor = _get_padded_tensor(saved, current.numel())
                current.data.copy_(src_tensor.data)

        if load_optimizer_states:
            # re-link lp params to the freshly loaded hp partitions/state
            self._link_all_hp_params()
    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
        # Delegates to the inherited directory-based loader (presumably from
        # ZeROOptimizer); "bf16_groups" names the lp param-group attribute on self.
        self.load_hp_checkpoint_state_from_checkpoint_dir("bf16_groups", checkpoint_folder)
    def _load_global_state(self, sd):
        # No global (non-parameter) state to restore for this optimizer.
        pass
    @property
    def param_groups(self):
        """Forward the wrapped optimizer's parameters.

        Note: after _setup_for_real_optimizer these groups reference the fp32
        flat partitions, not the original bf16 params.
        """
        return self.optimizer.param_groups
    def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
        """Hook target: immediately fold `lp_param`'s grad into fp32 and clear the lp grad.

        Only valid in immediate_grad_update mode (enforced by the assert).
        """
        assert self.immediate_grad_update
        self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=True)
    def create_grad_acc_hooks(self):
        """Register per-param autograd hooks that push lp grads into the fp32
        buffers as soon as they are produced (immediate_grad_update mode)."""
        self.grad_accs = []
        for i, param_group in enumerate(self.bf16_groups):
            for j, param in enumerate(param_group):
                if param.requires_grad:

                    def wrapper(param, i, j):
                        # expand_as creates a view whose grad_fn leads back to the
                        # param's gradient-accumulation node; hook that node so we
                        # fire once the grad for this param is ready.
                        param_tmp = param.expand_as(param)
                        grad_acc = param_tmp.grad_fn.next_functions[0][0]

                        def accumulate_hp_grads_and_remove_lp(*notneeded):
                            self.accumulate_hp_grads_and_remove_lp(param, i, j)

                        self._grad_acc_hooks.append(grad_acc.register_hook(accumulate_hp_grads_and_remove_lp))
                        # keep a reference so the accumulation node is not collected
                        self.grad_accs.append(grad_acc)

                    # wrapper() binds param/i/j by value, avoiding the
                    # late-binding-closure pitfall inside the loops
                    wrapper(param, i, j)
  458. def _get_padded_tensor(src_tensor, size):
  459. if src_tensor.numel() >= size:
  460. return src_tensor
  461. padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)
  462. slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel())
  463. slice_tensor.data.copy_(src_tensor.data)
  464. return padded_tensor