# bf16_optimizer.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from collections import OrderedDict
  5. import torch
  6. import sys
  7. from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
  8. from deepspeed import comm as dist
  9. from deepspeed.runtime.constants import PIPE_REPLICATED
  10. from deepspeed.runtime.base_optimizer import ZeROOptimizer
  11. from packaging import version as pkg_version
  12. from deepspeed.git_version_info import version
  13. from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
  14. align_dense_tensors, all_gather_dp_groups, is_model_parallel_parameter,
  15. see_memory_usage, graph_process, get_norm_with_moe_layers)
  16. from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups
  17. from deepspeed.moe.utils import is_moe_param, is_moe_param_group
  18. from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
  19. from deepspeed.checkpoint import enable_universal_checkpoint
  20. from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
  21. SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
  22. PARAM_SLICE_MAPPINGS)
# NOTE(review): re-exports fragment_address under this module's namespace —
# presumably so that pickled objects/checkpoints that reference
# `bf16_optimizer.fragment_address` still resolve; confirm against checkpoint code.
setattr(sys.modules[__name__], 'fragment_address', fragment_address)
class BF16_Optimizer(ZeROOptimizer):
    """Optimizer wrapper for bf16 training with a ZeRO-1-style partitioned fp32 master copy.

    The low-precision (bf16) parameters of each param group are flattened into a
    single contiguous buffer, evenly partitioned across the data-parallel group,
    and this rank's partition is cloned to fp32 as the master weights the wrapped
    optimizer actually updates.  Gradients are accumulated into fp32 (or bf16,
    per ``grad_acc_dtype``) buffers that view the same flat layout.  After
    ``step()``, updated fp32 partitions are copied back to bf16 and all-gathered.
    """

    def __init__(self,
                 init_optimizer,
                 param_names,
                 mpu=None,
                 clip_grad=0.0,
                 norm_type=2,
                 allgather_bucket_size=5000000000,
                 dp_process_group=None,
                 timers=None,
                 grad_acc_dtype=None,
                 graph_harvesting=False,
                 immediate_grad_update=False,
                 has_moe_layers=False):
        """Wrap ``init_optimizer`` with bf16/fp32 mixed-precision bookkeeping.

        Args:
            init_optimizer: the underlying torch optimizer (may be a DummyOptim).
            param_names: mapping from lp parameter tensor -> fully qualified name.
            mpu: model-parallel utilities object (optional).
            clip_grad: global gradient-norm clipping threshold (0 disables).
            norm_type: p-norm degree used for the gradient norm.
            allgather_bucket_size: bucket size for the post-step all-gather.
            dp_process_group: default data-parallel process group.
            timers: optional DeepSpeed timers.
            grad_acc_dtype: dtype of the gradient accumulation buffers
                (must be torch.float32 or torch.bfloat16).
            graph_harvesting: if True, use CUDA-graph-friendly code paths.
            immediate_grad_update: if True, accumulate hp grads from backward
                hooks instead of in a post-backward pass.
            has_moe_layers: whether the model contains MoE param groups.
        """
        super().__init__()
        see_memory_usage('begin bf16_optimizer', force=True)
        self.timers = timers
        self.optimizer = init_optimizer
        self.param_names = param_names
        # DummyOptim means no real optimizer state will be created/stepped.
        self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)

        assert grad_acc_dtype in [torch.float32, torch.bfloat16
                                  ], f"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}"
        self.grad_acc_dtype = grad_acc_dtype
        self.immediate_grad_update = immediate_grad_update

        self.clip_grad = clip_grad
        self.norm_type = norm_type
        self.mpu = mpu
        self.allgather_bucket_size = int(allgather_bucket_size)
        self.dp_process_group = dp_process_group
        self.dp_rank = dist.get_rank(group=self.dp_process_group)
        self.has_moe_layers = has_moe_layers
        # Flat grad buffers for non-MoE param groups (used for reduction).
        self.non_expert_gradients = []
        # Per-group DP group; MoE groups get their expert-data-parallel group below.
        self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]
        if self.has_moe_layers:
            self._configure_moe_settings()

        # Use torch (un)flatten ops
        self.flatten = _flatten_dense_tensors
        self.unflatten = _unflatten_dense_tensors

        # align nccl all-gather send buffers to 4-byte boundary
        self.nccl_start_alignment_factor = 2  # 4-byte alignment/sizeof(fp16) = 2

        # Build BF16/FP32 groups
        self.bf16_groups = []                 # per-group list of lp params
        self.bf16_groups_flat = []            # per-group flattened lp buffer
        self.bf16_partitioned_groups = []     # per-group list of DP partitions of the flat buffer
        self.fp32_groups_flat_partition = []  # this rank's fp32 master partition per group

        # Maintain different fp32 gradients views for convenience
        self.fp32_groups_gradients = []               # per-param views into the flat grad buffer
        self.fp32_groups_gradient_dict = {}           # group index -> same per-param views
        self.fp32_groups_gradients_flat = []          # full flat grad buffer per group
        self.fp32_groups_actual_gradients_flat = []   # flat grads minus alignment padding
        self.fp32_groups_gradient_flat_partition = [] # this rank's grad partition per group
        self.fp32_groups_has_gradients = []           # per-param "grad seen" flags

        self.group_paddings = []
        self.graph_harvesting = graph_harvesting
        if self.using_real_optimizer:
            self._setup_for_real_optimizer()

        see_memory_usage('end bf16_optimizer', force=True)

    def _configure_moe_settings(self):
        """Validate MoE param groups and point them at their expert-DP process groups."""
        assert any(
            [is_moe_param_group(group) for group in self.optimizer.param_groups]
        ), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"

        for i, group in enumerate(self.optimizer.param_groups):
            if is_moe_param_group(group):
                assert all([is_moe_param(param)
                            for param in group['params']]), "All params in MoE group must be MoE params"
                self.real_dp_process_group[i] = groups._get_expert_data_parallel_group(group['name'])
        # One gradient-buffer list per expert-data-parallel group name.
        self.expert_gradients = {}
        if self.has_moe_layers:
            for key in groups._get_expert_data_parallel_group_dict().keys():
                self.expert_gradients[key] = []

    def _setup_for_real_optimizer(self):
        """Build flat bf16 buffers, fp32 master partitions, and gradient views,
        then initialize the wrapped optimizer's state against the fp32 partitions."""
        self.partition_count = [dist.get_world_size(group=pg) for pg in self.real_dp_process_group]

        for i, param_group in enumerate(self.optimizer.param_groups):
            real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])
            see_memory_usage(f'before initializing group {i}', force=True)

            partition_id = dist.get_rank(group=self.real_dp_process_group[i])

            # grab the original list
            trainable_parameters = [param for param in param_group['params'] if param.requires_grad]
            self.bf16_groups.append(trainable_parameters)

            # create flat bf16 params
            # (aligned so each DP partition starts on an nccl-friendly boundary)
            self.bf16_groups_flat.append(
                self._flatten_dense_tensors_aligned(self.bf16_groups[i],
                                                    self.nccl_start_alignment_factor * real_dp_world_size))
            # Make bf16 params point to flat tensor storage
            self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i],
                                                     flat_tensor=self.bf16_groups_flat[i])

            # divide flat weights into equal sized partitions
            partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
            bf16_dp_partitions = [
                self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)
                for dp_index in range(real_dp_world_size)
            ]
            self.bf16_partitioned_groups.append(bf16_dp_partitions)

            # create fp32 params partition (this rank's master weights)
            self.fp32_groups_flat_partition.append(bf16_dp_partitions[partition_id].clone().float().detach())
            self.fp32_groups_flat_partition[i].requires_grad = True

            num_elem_list = [t.numel() for t in self.bf16_groups[i]]

            # create fp32 gradients
            fp32_flat_buffer = torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype)
            self.fp32_groups_gradients_flat.append(fp32_flat_buffer)
            if self.has_moe_layers and is_moe_param_group(param_group):
                self.expert_gradients[param_group['name']].append(fp32_flat_buffer)
            else:
                self.non_expert_gradients.append(fp32_flat_buffer)

            # track individual fp32 gradients for entire model
            fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i],
                                                     num_elem_list=num_elem_list)
            self.fp32_groups_gradients.append(fp32_gradients)
            self.fp32_groups_gradient_dict[i] = fp32_gradients

            # flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding)
            length_without_padding = sum(num_elem_list)
            self.fp32_groups_actual_gradients_flat.append(
                torch.narrow(self.fp32_groups_gradients_flat[i], 0, 0, length_without_padding))

            # flat tensor corresponding to gradient partition
            self.fp32_groups_gradient_flat_partition.append(
                torch.narrow(self.fp32_groups_gradients_flat[i], 0, partition_id * partition_size, partition_size))

            # track fp32 gradient updates
            self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i]))

            # Record padding required for alignment (only the last partition holds it)
            if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
                padding = self.bf16_groups_flat[i].numel() - length_without_padding
            else:
                padding = 0
            self.group_paddings.append(padding)

            # update optimizer param groups to reference fp32 params partition
            param_group['params'] = [self.fp32_groups_flat_partition[i]]

            see_memory_usage(f'after initializing group {i}', force=True)

        see_memory_usage('before initialize_optimizer', force=True)
        self.initialize_optimizer_states()
        see_memory_usage('end initialize_optimizer', force=True)

        if self.immediate_grad_update:
            self.create_grad_acc_hooks()

        # Need optimizer states initialized before linking lp to optimizer state
        self._link_all_hp_params()
        self._hp_optimizer_states_linked = False
        self._enable_universal_checkpoint()
        self._param_slice_mappings = self._create_param_mapping()

    def _enable_universal_checkpoint(self):
        """Enable universal-checkpoint hooks on every lp parameter."""
        for lp_param_group in self.bf16_groups:
            enable_universal_checkpoint(param_list=lp_param_group)

    def _create_param_mapping(self):
        """Return, per param group, an OrderedDict of param name -> hp fragment address."""
        param_mapping = []
        for i, _ in enumerate(self.optimizer.param_groups):
            param_mapping_per_group = OrderedDict()
            for lp in self.bf16_groups[i]:
                # Only params whose hp fragment lives on this rank have a mapping.
                if lp._hp_mapping is not None:
                    lp_name = self.param_names[lp]
                    param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address()
            param_mapping.append(param_mapping_per_group)
        return param_mapping

    def _link_all_hp_params(self):
        """Attach hp-fragment metadata to every lp param in this rank's partition."""
        for i, _ in enumerate(self.optimizer.param_groups):
            real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])

            # Link bf16 and fp32 params in partition
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
            flat_hp_partition = self.fp32_groups_flat_partition[i]
            link_hp_params(lp_param_list=self.bf16_groups[i],
                           flat_hp_partition=flat_hp_partition,
                           gradient_dict=self.fp32_groups_gradient_dict,
                           offload_gradient_dict=None,
                           use_offload=False,
                           param_group_index=i,
                           partition_start=partition_id * partition_size,
                           partition_size=partition_size,
                           dp_group=self.real_dp_process_group[i])

    def _lazy_init_hp_params_optimizer_state(self):
        """Link lp params to optimizer state once it exists (i.e., after the first step)."""
        if not self._hp_optimizer_states_linked:
            for i, _ in enumerate(self.optimizer.param_groups):
                lazy_init_hp_params_optimizer_state(self.bf16_groups[i], self.fp32_groups_flat_partition[i],
                                                    self.optimizer.state)
            self._hp_optimizer_states_linked = True

    def initialize_optimizer_states(self):
        """Take an optimizer step with zero-valued gradients to allocate internal
        optimizer state.

        This helps prevent memory fragmentation by allocating optimizer state at the
        beginning of training instead of after activations have been allocated.
        """
        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
                                                   self.fp32_groups_gradient_flat_partition):
            # In case of grad acc dtype different than FP32, need to cast to high precision.
            param_partition.grad = grad_partition.to(
                param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition

        if self.grad_acc_dtype is not torch.float32:
            # The casted .grad tensors above were temporaries; drop them so the
            # accumulation buffers remain the single source of truth.
            for param_partition in self.fp32_groups_flat_partition:
                param_partition.grad = None

        self.clear_hp_grads()

    def _split_flat_tensor(self, flat_tensor, num_elem_list):
        """Return views of ``flat_tensor`` sized per ``num_elem_list`` (no copies)."""
        assert sum(num_elem_list) <= flat_tensor.numel()
        tensor_list = []
        offset = 0
        for num_elem in num_elem_list:
            dense_tensor = torch.narrow(flat_tensor, 0, offset, num_elem)
            tensor_list.append(dense_tensor)
            offset += num_elem

        return tensor_list

    def _update_storage_to_flattened_tensor(self, tensor_list, flat_tensor):
        """Re-point each tensor in ``tensor_list`` at its slice of ``flat_tensor``."""
        updated_params = self.unflatten(flat_tensor, tensor_list)
        for p, q in zip(tensor_list, updated_params):
            p.data = q.data

    def _flatten_dense_tensors_aligned(self, tensor_list, alignment):
        """Flatten ``tensor_list`` into one buffer padded to a multiple of ``alignment``."""
        return self.flatten(align_dense_tensors(tensor_list, alignment))

    @torch.no_grad()
    def step(self, closure=None):
        """Clip gradients (optionally), step the wrapped optimizer on the fp32
        partitions, copy updated weights back to bf16, and clear hp grads."""
        if closure is not None:
            raise NotImplementedError(f'{self.__class__} does not support closure.')

        non_expert_grads_for_norm, expert_grads_for_norm = self.get_grads_for_norm()
        non_expert_groups_norm = get_global_norm_of_tensors(input_tensors=non_expert_grads_for_norm,
                                                            mpu=self.mpu,
                                                            norm_type=self.norm_type,
                                                            use_graph=self.graph_harvesting)
        all_groups_norm = non_expert_groups_norm
        if self.has_moe_layers:
            # Combine with expert-parallel norms (reduced over ep groups).
            all_groups_norm = get_norm_with_moe_layers(non_expert_groups_norm,
                                                       mpu=self.mpu,
                                                       expert_tensors=expert_grads_for_norm,
                                                       norm_type=self.norm_type)

        self._global_grad_norm = all_groups_norm

        assert all_groups_norm > 0.
        if self.clip_grad > 0.:
            clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True),
                                        max_norm=self.clip_grad,
                                        global_norm=all_groups_norm,
                                        mpu=self.mpu,
                                        use_graph=self.graph_harvesting)

        self.optimizer.step()

        # We need to link optimizer state after the first step() call
        self._lazy_init_hp_params_optimizer_state()

        self.update_lp_params()

        self.clear_hp_grads()

    def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
        """Perform a backward pass and copy the low-precision gradients to the
        high-precision copy.

        We copy/accumulate to the high-precision grads now to prevent accumulating in the
        bf16 grads after successive backward() calls (i.e., grad accumulation steps > 1)

        The low-precision grads are deallocated during this procedure.
        """
        self.clear_lp_grads()
        loss.backward(**bwd_kwargs)

        if update_hp_grads:
            self.update_hp_grads(clear_lp_grads=clear_lp_grads)

    @torch.no_grad()
    def _update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads):
        """Accumulate one lp param's grad into its hp grad view; optionally zero the lp grad."""
        if lp.grad is None:
            return

        hp_grad = self.fp32_groups_gradients[group_idx][param_idx]
        assert hp_grad is not None, \
            f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{group_idx}][{param_idx}]'

        hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))
        lp._hp_grad = hp_grad
        self.fp32_groups_has_gradients[group_idx][param_idx] = True

        # clear gradients
        if clear_lp_grads:
            lp.grad.zero_()

    @torch.no_grad()
    def _update_hp_grads_func(self, clear_lp_grads=False):
        """Accumulate all lp grads into hp grad buffers."""
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                self._update_hp_grad(lp, i, j, clear_lp_grads)

    @torch.no_grad()
    def update_hp_grads(self, clear_lp_grads=False):
        """Post-backward hp-grad accumulation (no-op when immediate_grad_update hooks run)."""
        if self.immediate_grad_update:
            return

        if self.graph_harvesting:
            # Replayed as a CUDA graph; the False flag selects graph capture/replay mode.
            graph_process(False, self._update_hp_grads_func, clear_lp_grads)
        else:
            self._update_hp_grads_func(clear_lp_grads)

        # cpu op: flag updates must happen outside any graph capture
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                if lp.grad is None:
                    continue
                self.fp32_groups_has_gradients[i][j] = True

    @torch.no_grad()
    def get_grads_for_reduction(self):
        """Return (non-expert flat grad buffers, expert flat grad buffers by ep-group name)."""
        if self.has_moe_layers:
            return self.non_expert_gradients, self.expert_gradients
        return self.non_expert_gradients, {}

    @torch.no_grad()
    def get_grads_for_norm(self, for_clipping=False):
        """
        Returns:
            tuple[list[Tensor], dict[ep_name, List[Tensor]] | list:
                If for_clipping, return all gradients.
                Otherwise, separate and return dict of expert_grad and list of non_expert_grad
        """
        # (grads, expert_group_name)
        expert_grads_for_norm = {}

        # grads
        non_expert_grads_for_norm = []
        all_grads_for_clip = []

        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        assert len(self.bf16_groups) == len(self.optimizer.param_groups)
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                if not for_clipping:
                    if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated:
                        continue

                    # skip duplicated parameters. perform norm only on cards with tp_rank=0.
                    # non-duplicated parameters include:
                    # - Parameters with tp: Use allreducesum of mp_group.
                    # - Moe Parameters with ep: Use allreducesum of ep_group.
                    if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp) or is_moe_param(lp)):
                        continue

                if not self.fp32_groups_has_gradients[i][j]:
                    continue
                if not for_clipping:
                    param_group = self.optimizer.param_groups[i]
                    if self.has_moe_layers and is_moe_param_group(param_group):
                        if param_group['name'] not in expert_grads_for_norm:
                            expert_grads_for_norm[param_group['name']] = []
                        expert_grads_for_norm[param_group['name']].append(self.fp32_groups_gradients[i][j])
                    else:
                        non_expert_grads_for_norm.append(self.fp32_groups_gradients[i][j])
                else:
                    all_grads_for_clip.append(self.fp32_groups_gradients[i][j])

        if not for_clipping:
            return non_expert_grads_for_norm, expert_grads_for_norm
        return all_grads_for_clip

    @torch.no_grad()
    def update_lp_params(self):
        """Copy this rank's updated fp32 partition back into its bf16 slice,
        then all-gather the full bf16 buffers across the DP group."""
        for i, (bf16_partitions,
                fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            bf16_partitions[partition_id].data.copy_(fp32_partition.data)
            # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
            # if i == 0:
            #     print_rank_0(f'{fp32_partition[:10]=}', force=True)

        all_gather_dp_groups(groups_flat=self.bf16_groups_flat,
                             partitioned_param_groups=self.bf16_partitioned_groups,
                             dp_process_group=self.real_dp_process_group,
                             start_alignment_factor=self.nccl_start_alignment_factor,
                             allgather_bucket_size=self.allgather_bucket_size)

    def clear_hp_grads(self):
        """Zero all hp grad buffers and reset the per-param "grad seen" flags."""
        for flat_gradients in self.fp32_groups_gradients_flat:
            flat_gradients.zero_()

        for i, group in enumerate(self.fp32_groups_gradients):
            self.fp32_groups_has_gradients[i] = [False] * len(group)

    def clear_lp_grads(self):
        """Clear lp grads; zero in place (fixed addresses) when graph harvesting is on."""
        # using zero_() fixed memory address for graph replay
        set_to_none = False if self.graph_harvesting else True
        zero_grads_list = []
        for group in self.bf16_groups:
            for param in group:
                if set_to_none:
                    param.grad = None
                elif param.grad is not None:
                    if param.grad.grad_fn is not None:
                        param.grad.detach_()
                    zero_grads_list.append(param.grad)
        if not set_to_none and len(zero_grads_list) > 0:
            # batched zeroing of all collected grads
            torch._foreach_zero_(zero_grads_list)

    def state_dict(self):
        """Return a checkpointable dict: wrapped-optimizer state, fp32 partitions,
        paddings, partition counts, version, and param slice mappings."""
        state_dict = {}
        state_dict[CLIP_GRAD] = self.clip_grad
        state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict()
        state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = self.fp32_groups_flat_partition
        state_dict[GROUP_PADDINGS] = self.group_paddings
        state_dict[PARTITION_COUNT] = self.partition_count
        state_dict[DS_VERSION] = version
        state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings

        return state_dict

    # Restore base optimizer fp32 weights from bfloat16 weights
    def _restore_from_bit16_weights(self):
        for i, group in enumerate(self.bf16_groups):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            # NOTE(review): the inner loop re-copies every group's partition on each
            # outer iteration; redundant but harmless when partition_id is the same
            # across groups — confirm whether the outer loop is needed at all.
            for bf16_partitions, fp32_partition in zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition):
                fp32_partition.data.copy_(bf16_partitions[partition_id].data)

    def refresh_fp32_params(self):
        """Re-sync fp32 master partitions from the current bf16 weights."""
        self._restore_from_bit16_weights()

    def load_state_dict(self,
                        state_dict_list,
                        checkpoint_folder,
                        load_optimizer_states=True,
                        load_from_fp32_weights=False,
                        load_serial=None):
        """Load from a universal checkpoint folder if given, else from legacy per-rank state dicts."""
        if checkpoint_folder:
            self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
        else:
            self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)

    def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False):
        """Restore state from this DP rank's entry in a legacy (per-rank) checkpoint."""
        dp_rank = dist.get_rank(group=self.dp_process_group)
        current_rank_sd = state_dict_list[dp_rank]

        ckpt_version = current_rank_sd.get(DS_VERSION, False)
        assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed"
        ckpt_version = pkg_version.parse(ckpt_version)

        self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)

        if load_optimizer_states:
            print(f"_load_legacy_checkpoint current_rank_sd[BASE_OPTIMIZER_STATE]")
            self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])

        if load_from_fp32_weights:
            for current, saved in zip(self.fp32_groups_flat_partition,
                                      current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):
                # Saved partition may be shorter than the aligned local buffer.
                src_tensor = _get_padded_tensor(saved, current.numel())
                current.data.copy_(src_tensor.data)

        if load_optimizer_states:
            self._link_all_hp_params()

    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
        """Restore hp state from a universal checkpoint directory."""
        self.load_hp_checkpoint_state_from_checkpoint_dir("bf16_groups", checkpoint_folder)

    def _load_global_state(self, sd):
        # No global (non-per-param) state to restore for this optimizer.
        pass

    @property
    def param_groups(self):
        """Forward the wrapped optimizer's parameters."""
        return self.optimizer.param_groups

    def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
        """Hook target: immediately fold an lp grad into hp storage and zero the lp grad."""
        assert self.immediate_grad_update
        self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=True)

    def create_grad_acc_hooks(self):
        """Register autograd post-accumulation hooks so hp grads update as soon as
        each lp grad is produced (used with immediate_grad_update)."""
        self.grad_accs = []
        for i, param_group in enumerate(self.bf16_groups):
            for j, param in enumerate(param_group):
                if param.requires_grad:

                    def wrapper(param, i, j):
                        # expand_as creates a graph node whose grad_fn lets us
                        # reach the AccumulateGrad function for this param.
                        param_tmp = param.expand_as(param)
                        grad_acc = param_tmp.grad_fn.next_functions[0][0]

                        def accumulate_hp_grads_and_remove_lp(*notneeded):
                            self.accumulate_hp_grads_and_remove_lp(param, i, j)

                        grad_acc.register_hook(accumulate_hp_grads_and_remove_lp)
                        # keep a reference so the AccumulateGrad node isn't collected
                        self.grad_accs.append(grad_acc)

                    wrapper(param, i, j)
  445. def _get_padded_tensor(src_tensor, size):
  446. if src_tensor.numel() >= size:
  447. return src_tensor
  448. padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)
  449. slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel())
  450. slice_tensor.data.copy_(src_tensor.data)
  451. return padded_tensor