bf16_optimizer.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from collections import OrderedDict
import torch
import sys
import os
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed import comm as dist
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.runtime import ZeROOptimizer
from packaging import version as pkg_version

from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
                                     align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
                                     is_model_parallel_parameter, see_memory_usage)

from deepspeed.utils import link_hp_params, fragment_address
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
                                            SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
                                            PARAM_SLICE_MAPPINGS)

setattr(sys.modules[__name__], 'fragment_address', fragment_address)


class BF16_Optimizer(ZeROOptimizer):

    def __init__(self,
                 init_optimizer,
                 param_names,
                 mpu=None,
                 clip_grad=0.0,
                 norm_type=2,
                 allgather_bucket_size=5000000000,
                 dp_process_group=None,
                 timers=None):
        super().__init__()
        see_memory_usage('begin bf16_optimizer', force=True)
        self.timers = timers
        self.optimizer = init_optimizer
        self.param_names = param_names
        self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)

        self.clip_grad = clip_grad
        self.norm_type = norm_type
        self.mpu = mpu
        self.allgather_bucket_size = int(allgather_bucket_size)
        self.dp_process_group = dp_process_group
        self.dp_rank = dist.get_rank(group=self.dp_process_group)
        self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]

        # Use torch (un)flatten ops
        self.flatten = _flatten_dense_tensors
        self.unflatten = _unflatten_dense_tensors

        # align nccl all-gather send buffers to 4-byte boundary
        self.nccl_start_alignment_factor = 2  # 4-byte alignment / sizeof(fp16) = 2
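
        # Illustrative arithmetic (added comment, not in the original source): with
        # dp_world_size=8, flat buffers are padded to a multiple of
        # nccl_start_alignment_factor * dp_world_size = 2 * 8 = 16 elements, so each
        # rank's equal-sized partition starts on a 4-byte boundary.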

        # Build BF16/FP32 groups
        self.bf16_groups = []
        self.bf16_groups_flat = []
        self.bf16_partitioned_groups = []

        self.fp32_groups_flat_partition = []

        # Maintain different fp32 gradients views for convenience
        self.fp32_groups_gradients = []
        self.fp32_groups_gradient_dict = {}
        self.fp32_groups_gradients_flat = []
        self.fp32_groups_actual_gradients_flat = []
        self.fp32_groups_gradient_flat_partition = []
        self.fp32_groups_has_gradients = []

        self.step_count = 0
        self.group_paddings = []

        if self.using_real_optimizer:
            self._setup_for_real_optimizer()

        see_memory_usage('end bf16_optimizer', force=True)
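
    # Illustrative construction sketch (added comment, not part of the original file).
    # In practice the DeepSpeed engine builds this wrapper; `model` and `base_optimizer`
    # below are placeholders, the model is assumed to already be in bf16, and
    # torch.distributed / deepspeed.comm is assumed to be initialized:
    #
    #   base_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    #   param_names = {p: n for n, p in model.named_parameters()}
    #   bf16_opt = BF16_Optimizer(base_optimizer,
    #                             param_names,
    #                             clip_grad=1.0,
    #                             dp_process_group=None)  # None -> default process group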

    def _setup_for_real_optimizer(self):
        dp_world_size = dist.get_world_size(group=self.dp_process_group)
        self.partition_count = [dp_world_size for i in range(len(self.optimizer.param_groups))]

        for i, param_group in enumerate(self.optimizer.param_groups):
            see_memory_usage(f'before initializing group {i}', force=True)

            partition_id = dist.get_rank(group=self.real_dp_process_group[i])

            # grab the original list
            trainable_parameters = [param for param in param_group['params'] if param.requires_grad]
            self.bf16_groups.append(trainable_parameters)

            # create flat bf16 params
            self.bf16_groups_flat.append(
                self._flatten_dense_tensors_aligned(self.bf16_groups[i],
                                                    self.nccl_start_alignment_factor * dp_world_size))

            # Make bf16 params point to flat tensor storage
            self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i],
                                                     flat_tensor=self.bf16_groups_flat[i])

            # divide flat weights into equal sized partitions
            partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
            bf16_dp_partitions = [
                self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)
                for dp_index in range(dp_world_size)
            ]
            self.bf16_partitioned_groups.append(bf16_dp_partitions)
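
            # Worked example (added comment, not in the original source): with
            # dp_world_size=4 and a padded flat buffer of 12 elements,
            # partition_size = 12 // 4 = 3 and the partitions are the views
            # flat[0:3], flat[3:6], flat[6:9], flat[9:12]; rank r owns partition r.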

            # create fp32 params partition
            self.fp32_groups_flat_partition.append(bf16_dp_partitions[partition_id].clone().float().detach())
            self.fp32_groups_flat_partition[i].requires_grad = True

            num_elem_list = [t.numel() for t in self.bf16_groups[i]]

            # create fp32 gradients
            self.fp32_groups_gradients_flat.append(torch.zeros_like(self.bf16_groups_flat[i], dtype=torch.float32))

            # track individual fp32 gradients for entire model
            fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i],
                                                     num_elem_list=num_elem_list)
            self.fp32_groups_gradients.append(fp32_gradients)
            self.fp32_groups_gradient_dict[i] = fp32_gradients

            # flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding)
            length_without_padding = sum(num_elem_list)
            self.fp32_groups_actual_gradients_flat.append(
                torch.narrow(self.fp32_groups_gradients_flat[i], 0, 0, length_without_padding))

            # flat tensor corresponding to gradient partition
            self.fp32_groups_gradient_flat_partition.append(
                torch.narrow(self.fp32_groups_gradients_flat[i], 0, partition_id * partition_size, partition_size))

            # track fp32 gradient updates
            self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i]))

            # Record padding required for alignment
            if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
                padding = self.bf16_groups_flat[i].numel() - length_without_padding
            else:
                padding = 0

            self.group_paddings.append(padding)

            # update optimizer param groups to reference fp32 params partition
            param_group['params'] = [self.fp32_groups_flat_partition[i]]

            see_memory_usage(f'after initializing group {i}', force=True)

        see_memory_usage('before initialize_optimizer', force=True)
        self.initialize_optimizer_states()
        see_memory_usage('end initialize_optimizer', force=True)

        # Need optimizer states initialized before linking lp to optimizer state
        self._link_all_hp_params()
        self._enable_universal_checkpoint()
        self._param_slice_mappings = self._create_param_mapping()

    def _enable_universal_checkpoint(self):
        for lp_param_group in self.bf16_groups:
            enable_universal_checkpoint(param_list=lp_param_group)

    def _create_param_mapping(self):
        param_mapping = []
        for i, _ in enumerate(self.optimizer.param_groups):
            param_mapping_per_group = OrderedDict()
            for lp in self.bf16_groups[i]:
                if lp._hp_mapping is not None:
                    lp_name = self.param_names[lp]
                    param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address()
            param_mapping.append(param_mapping_per_group)

        return param_mapping

    def _link_all_hp_params(self):
        dp_world_size = dist.get_world_size(group=self.dp_process_group)
        for i, _ in enumerate(self.optimizer.param_groups):
            # Link bf16 and fp32 params in partition
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
            flat_hp_partition = self.fp32_groups_flat_partition[i]
            link_hp_params(lp_param_list=self.bf16_groups[i],
                           flat_hp_partition=flat_hp_partition,
                           gradient_dict=self.fp32_groups_gradient_dict,
                           offload_gradient_dict=None,
                           use_offload=False,
                           param_group_index=i,
                           partition_start=partition_id * partition_size,
                           partition_size=partition_size,
                           partition_optimizer_state=self.optimizer.state[flat_hp_partition],
                           dp_group=self.real_dp_process_group[i])

    def initialize_optimizer_states(self):
        """Take an optimizer step with zero-valued gradients to allocate internal
        optimizer state.

        This helps prevent memory fragmentation by allocating optimizer state at the
        beginning of training instead of after activations have been allocated.
        """
        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
                                                   self.fp32_groups_gradient_flat_partition):
            param_partition.grad = grad_partition

        self.optimizer.step()

        self.clear_hp_grads()
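
    # The same pre-allocation trick in isolation (added sketch, not part of the
    # original file), using a plain torch optimizer:
    #
    #   p = torch.nn.Parameter(torch.zeros(10))
    #   opt = torch.optim.Adam([p])
    #   p.grad = torch.zeros_like(p)  # zero-valued gradient
    #   opt.step()                    # allocates exp_avg / exp_avg_sq, leaves p unchanged
    #   p.grad = None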

    def _split_flat_tensor(self, flat_tensor, num_elem_list):
        assert sum(num_elem_list) <= flat_tensor.numel()
        tensor_list = []
        offset = 0
        for num_elem in num_elem_list:
            dense_tensor = torch.narrow(flat_tensor, 0, offset, num_elem)
            tensor_list.append(dense_tensor)
            offset += num_elem

        return tensor_list
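
    # Example of the splitting behaviour (added comment, not in the original source):
    # with flat_tensor of 6 elements and num_elem_list=[2, 3], the returned list holds
    # the views flat_tensor[0:2] and flat_tensor[2:5]; the trailing element is
    # alignment padding and is deliberately left out of any view.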

    def _update_storage_to_flattened_tensor(self, tensor_list, flat_tensor):
        updated_params = self.unflatten(flat_tensor, tensor_list)
        for p, q in zip(tensor_list, updated_params):
            p.data = q.data

    def _flatten_dense_tensors_aligned(self, tensor_list, alignment):
        return self.flatten(align_dense_tensors(tensor_list, alignment))
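
    # Round-trip sketch (added comment, not part of the original file): flattening
    # tensors of 3 and 6 elements with alignment=4 produces a 12-element flat buffer
    # (3 + 6 = 9, padded up to the next multiple of 4); self.unflatten() then yields
    # views of shapes (3,) and (6,) into that buffer, which is what
    # _update_storage_to_flattened_tensor uses to re-point each param's .data.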

    @torch.no_grad()
    def step(self, closure=None):
        if closure is not None:
            raise NotImplementedError(f'{self.__class__} does not support closure.')

        all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(),
                                                     mpu=self.mpu,
                                                     norm_type=self.norm_type)
        self._global_grad_norm = all_groups_norm

        assert all_groups_norm > 0.
        if self.clip_grad > 0.:
            clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True),
                                        max_norm=self.clip_grad,
                                        global_norm=all_groups_norm,
                                        mpu=self.mpu)

        self.optimizer.step()

        self.update_lp_params()

        self.clear_hp_grads()
        self.step_count += 1
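
    # For reference, standard global-norm clipping written out with plain torch ops
    # (an added sketch of the math; not necessarily the exact implementation of
    # clip_tensors_by_global_norm):
    #
    #   clip_coef = max_norm / (global_norm + 1e-6)
    #   if clip_coef < 1.0:
    #       for g in grads:
    #           g.detach().mul_(clip_coef)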

    def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
        """Perform a backward pass and copy the low-precision gradients to the
        high-precision copy.

        We copy/accumulate into the high-precision grads now to prevent accumulating in the
        bf16 grads after successive backward() calls (i.e., grad accumulation steps > 1).

        The low-precision grads are deallocated during this procedure.
        """
        self.clear_lp_grads()
        loss.backward(**bwd_kwargs)

        if update_hp_grads:
            self.update_hp_grads(clear_lp_grads=clear_lp_grads)
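
    # Typical call pattern under gradient accumulation (added sketch, not part of the
    # original file; `micro_batches`, `model`, and `bf16_opt` are placeholders and the
    # DeepSpeed engine normally drives this loop):
    #
    #   for micro_batch in micro_batches:
    #       loss = model(micro_batch)
    #       bf16_opt.backward(loss)   # accumulates into the fp32 gradient buffers
    #   bf16_opt.step()               # clip, update fp32 masters, copy back to bf16, all-gather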

    @torch.no_grad()
    def update_hp_grads(self, clear_lp_grads=False):
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                if lp.grad is None:
                    continue

                hp_grad = self.fp32_groups_gradients[i][j]
                assert hp_grad is not None, \
                    f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]'

                hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))
                lp._hp_grad = hp_grad
                self.fp32_groups_has_gradients[i][j] = True

                # clear gradients
                if clear_lp_grads:
                    lp.grad = None
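
    # Why accumulation happens in fp32 (added note, not in the original source): bf16
    # carries only 8 bits of significand, so the spacing between representable values
    # near 1024 is 8 and torch.tensor(1024., dtype=torch.bfloat16) + 2.0 still equals
    # 1024. Accumulating into the fp32 buffers keeps small per-micro-batch
    # contributions from being rounded away.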

    @torch.no_grad()
    def get_grads_for_reduction(self):
        return self.fp32_groups_gradients_flat

    @torch.no_grad()
    def get_grads_for_norm(self, for_clipping=False):
        grads = []
        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                if not for_clipping:
                    if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated:
                        continue

                    if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp)):
                        continue

                if not self.fp32_groups_has_gradients[i][j]:
                    continue

                grads.append(self.fp32_groups_gradients[i][j])

        return grads

    @torch.no_grad()
    def update_lp_params(self):
        for i, (bf16_partitions,
                fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            bf16_partitions[partition_id].data.copy_(fp32_partition.data)
            # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
            # if i == 0:
            #     print_rank_0(f'{fp32_partition[:10]=}', force=True)

        all_gather_dp_groups(partitioned_param_groups=self.bf16_partitioned_groups,
                             dp_process_group=self.real_dp_process_group,
                             start_alignment_factor=self.nccl_start_alignment_factor,
                             allgather_bucket_size=self.allgather_bucket_size)
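
    # Flow note (added comment, not in the original source): each data-parallel rank
    # writes only its own fp32 partition back into its slice of the flat bf16 buffer,
    # then all_gather_dp_groups reassembles the full bf16 weights on every rank,
    # mirroring the partition-update-then-gather pattern of ZeRO stage 1.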

    def clear_hp_grads(self):
        for flat_gradients in self.fp32_groups_gradients_flat:
            flat_gradients.zero_()

        for i, group in enumerate(self.fp32_groups_gradients):
            self.fp32_groups_has_gradients[i] = [False] * len(group)

    def clear_lp_grads(self):
        for group in self.bf16_groups:
            for param in group:
                param.grad = None

    def state_dict(self):
        state_dict = {}
        state_dict[CLIP_GRAD] = self.clip_grad
        state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict()
        state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = self.fp32_groups_flat_partition
        state_dict[GROUP_PADDINGS] = self.group_paddings
        state_dict[PARTITION_COUNT] = self.partition_count

        state_dict[DS_VERSION] = version
        state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings

        return state_dict
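
    # Checkpointing sketch (added comment, not part of the original file; the DeepSpeed
    # engine normally handles this). Each data-parallel rank holds a different fp32
    # partition, so every rank must persist its own state dict, e.g.:
    #
    #   torch.save(bf16_opt.state_dict(), f"bf16_opt_rank{bf16_opt.dp_rank}.pt")
    #   # later, to restore on the same rank:
    #   sd_list = [None] * dp_world_size
    #   sd_list[bf16_opt.dp_rank] = torch.load(f"bf16_opt_rank{bf16_opt.dp_rank}.pt")
    #   bf16_opt.load_state_dict(sd_list, checkpoint_folder=None)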

    # Restore base optimizer fp32 weights from bfloat16 weights
    def _restore_from_bit16_weights(self):
        for i, group in enumerate(self.bf16_groups):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            for bf16_partitions, fp32_partition in zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition):
                fp32_partition.data.copy_(bf16_partitions[partition_id].data)

    def refresh_fp32_params(self):
        self._restore_from_bit16_weights()

    def load_state_dict(self,
                        state_dict_list,
                        checkpoint_folder,
                        load_optimizer_states=True,
                        load_from_fp32_weights=False):
        if checkpoint_folder:
            self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
        else:
            self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)

    def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False):
        dp_rank = dist.get_rank(group=self.dp_process_group)
        current_rank_sd = state_dict_list[dp_rank]
        ckpt_version = current_rank_sd.get(DS_VERSION, False)
        assert ckpt_version, "Empty ds_version in checkpoint, not clear how to proceed"
        ckpt_version = pkg_version.parse(ckpt_version)

        self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)

        if load_optimizer_states:
            self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])

        if load_from_fp32_weights:
            for current, saved in zip(self.fp32_groups_flat_partition,
                                      current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):
                src_tensor = _get_padded_tensor(saved, current.numel())
                current.data.copy_(src_tensor.data)

        if load_optimizer_states:
            self._link_all_hp_params()

    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
        self._load_hp_checkpoint_state(checkpoint_folder)

    @property
    def param_groups(self):
        """Forward the wrapped optimizer's param_groups."""
        return self.optimizer.param_groups

    def _load_hp_checkpoint_state(self, checkpoint_dir):
        checkpoint_dir = os.path.join(checkpoint_dir, "zero")
        tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        tp_world_size = self.mpu.get_slice_parallel_world_size()

        for i, _ in enumerate(self.optimizer.param_groups):
            for lp in self.bf16_groups[i]:
                if lp._hp_mapping is not None:
                    # print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                    lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                tp_world_size)


def _get_padded_tensor(src_tensor, size):
    if src_tensor.numel() >= size:
        return src_tensor
    padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)
    slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel())
    slice_tensor.data.copy_(src_tensor.data)
    return padded_tensor
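

# Quick illustrative check of the padding helper (added sketch, not part of the
# original file); it runs only when this module is executed directly, never on import.
if __name__ == "__main__":
    _t = torch.arange(5, dtype=torch.float32)
    _padded = _get_padded_tensor(_t, 8)
    assert _padded.numel() == 8
    assert torch.equal(_padded[:5], _t) and torch.all(_padded[5:] == 0)
    # A tensor that is already large enough is returned as-is, not copied.
    assert _get_padded_tensor(_t, 5) is _t
    print("bf16_optimizer padding helper sketch OK")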