bf16_optimizer.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from collections import OrderedDict
import torch
import sys
import os
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed import comm as dist
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.runtime import ZeROOptimizer
from packaging import version as pkg_version

from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
                                     align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
                                     is_model_parallel_parameter, see_memory_usage, graph_process)
from deepspeed.utils import link_hp_params, fragment_address
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
                                            SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
                                            PARAM_SLICE_MAPPINGS)

setattr(sys.modules[__name__], 'fragment_address', fragment_address)


class BF16_Optimizer(ZeROOptimizer):

    def __init__(self,
                 init_optimizer,
                 param_names,
                 mpu=None,
                 clip_grad=0.0,
                 norm_type=2,
                 allgather_bucket_size=5000000000,
                 dp_process_group=None,
                 timers=None,
                 grad_acc_dtype=None,
                 graph_harvesting=False):
        super().__init__()
        see_memory_usage('begin bf16_optimizer', force=True)
        self.timers = timers
        self.optimizer = init_optimizer
        self.param_names = param_names
        self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)

        assert grad_acc_dtype in [torch.float32, torch.bfloat16], \
            f"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}"
        self.grad_acc_dtype = grad_acc_dtype

        self.clip_grad = clip_grad
        self.norm_type = norm_type
        self.mpu = mpu
        self.allgather_bucket_size = int(allgather_bucket_size)
        self.dp_process_group = dp_process_group
        self.dp_rank = dist.get_rank(group=self.dp_process_group)
        self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]

        # Use torch (un)flatten ops
        self.flatten = _flatten_dense_tensors
        self.unflatten = _unflatten_dense_tensors

        # Align nccl all-gather send buffers to 4-byte boundary
        self.nccl_start_alignment_factor = 2  # 4-byte alignment / sizeof(fp16) = 2
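        # Worked example with hypothetical sizes: if dp_world_size == 4, the flattening
        # alignment below is 2 * 4 = 8 elements, so a param group totaling 10 elements is
        # padded to 16 and each of the 4 ranks owns an equal 4-element partition.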

        # Build BF16/FP32 groups
        self.bf16_groups = []
        self.bf16_groups_flat = []
        self.bf16_partitioned_groups = []

        self.fp32_groups_flat_partition = []

        # Maintain different fp32 gradients views for convenience
        self.fp32_groups_gradients = []
        self.fp32_groups_gradient_dict = {}
        self.fp32_groups_gradients_flat = []
        self.fp32_groups_actual_gradients_flat = []
        self.fp32_groups_gradient_flat_partition = []
        self.fp32_groups_has_gradients = []

        self.group_paddings = []
        self.graph_harvesting = graph_harvesting

        if self.using_real_optimizer:
            self._setup_for_real_optimizer()

        see_memory_usage('end bf16_optimizer', force=True)

    def _setup_for_real_optimizer(self):
        dp_world_size = dist.get_world_size(group=self.dp_process_group)
        self.partition_count = [dp_world_size for i in range(len(self.optimizer.param_groups))]

        for i, param_group in enumerate(self.optimizer.param_groups):
            see_memory_usage(f'before initializing group {i}', force=True)

            partition_id = dist.get_rank(group=self.real_dp_process_group[i])

            # grab the original list
            trainable_parameters = [param for param in param_group['params'] if param.requires_grad]
            self.bf16_groups.append(trainable_parameters)

            # create flat bf16 params
            self.bf16_groups_flat.append(
                self._flatten_dense_tensors_aligned(self.bf16_groups[i],
                                                    self.nccl_start_alignment_factor * dp_world_size))

            # Make bf16 params point to flat tensor storage
            self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i],
                                                     flat_tensor=self.bf16_groups_flat[i])

            # divide flat weights into equal sized partitions
            partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
            bf16_dp_partitions = [
                self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)
                for dp_index in range(dp_world_size)
            ]
            self.bf16_partitioned_groups.append(bf16_dp_partitions)

            # create fp32 params partition
            self.fp32_groups_flat_partition.append(bf16_dp_partitions[partition_id].clone().float().detach())
            self.fp32_groups_flat_partition[i].requires_grad = True

            num_elem_list = [t.numel() for t in self.bf16_groups[i]]

            # create fp32 gradients
            self.fp32_groups_gradients_flat.append(
                torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype))

            # track individual fp32 gradients for entire model
            fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i],
                                                     num_elem_list=num_elem_list)
            self.fp32_groups_gradients.append(fp32_gradients)
            self.fp32_groups_gradient_dict[i] = fp32_gradients

            # flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding)
            length_without_padding = sum(num_elem_list)
            self.fp32_groups_actual_gradients_flat.append(
                torch.narrow(self.fp32_groups_gradients_flat[i], 0, 0, length_without_padding))

            # flat tensor corresponding to gradient partition
            self.fp32_groups_gradient_flat_partition.append(
                torch.narrow(self.fp32_groups_gradients_flat[i], 0, partition_id * partition_size, partition_size))

            # track fp32 gradient updates
            self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i]))

            # Record padding required for alignment
            if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
                padding = self.bf16_groups_flat[i].numel() - length_without_padding
            else:
                padding = 0
            self.group_paddings.append(padding)

            # update optimizer param groups to reference fp32 params partition
            param_group['params'] = [self.fp32_groups_flat_partition[i]]

            see_memory_usage(f'after initializing group {i}', force=True)

        see_memory_usage('before initialize_optimizer', force=True)
        self.initialize_optimizer_states()
        see_memory_usage('end initialize_optimizer', force=True)

        # Need optimizer states initialized before linking lp to optimizer state
        self._link_all_hp_params()
        self._enable_universal_checkpoint()
        self._param_slice_mappings = self._create_param_mapping()

    def _enable_universal_checkpoint(self):
        for lp_param_group in self.bf16_groups:
            enable_universal_checkpoint(param_list=lp_param_group)

    def _create_param_mapping(self):
        param_mapping = []
        for i, _ in enumerate(self.optimizer.param_groups):
            param_mapping_per_group = OrderedDict()
            for lp in self.bf16_groups[i]:
                if lp._hp_mapping is not None:
                    lp_name = self.param_names[lp]
                    param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address()
            param_mapping.append(param_mapping_per_group)

        return param_mapping

    def _link_all_hp_params(self):
        dp_world_size = dist.get_world_size(group=self.dp_process_group)
        for i, _ in enumerate(self.optimizer.param_groups):
            # Link bf16 and fp32 params in partition
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
            flat_hp_partition = self.fp32_groups_flat_partition[i]
            link_hp_params(lp_param_list=self.bf16_groups[i],
                           flat_hp_partition=flat_hp_partition,
                           gradient_dict=self.fp32_groups_gradient_dict,
                           offload_gradient_dict=None,
                           use_offload=False,
                           param_group_index=i,
                           partition_start=partition_id * partition_size,
                           partition_size=partition_size,
                           partition_optimizer_state=self.optimizer.state[flat_hp_partition],
                           dp_group=self.real_dp_process_group[i])

    def initialize_optimizer_states(self):
        """Take an optimizer step with zero-valued gradients to allocate internal
        optimizer state.

        This helps prevent memory fragmentation by allocating optimizer state at the
        beginning of training instead of after activations have been allocated.
        """
        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
                                                   self.fp32_groups_gradient_flat_partition):
            # If the gradient accumulation dtype differs from FP32, cast the gradient
            # to the parameter's high-precision dtype before the dummy step.
            param_partition.grad = grad_partition.to(
                param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition

        self.optimizer.step()

        if self.grad_acc_dtype is not torch.float32:
            for param_partition in self.fp32_groups_flat_partition:
                param_partition.grad = None

        self.clear_hp_grads()

    def _split_flat_tensor(self, flat_tensor, num_elem_list):
        assert sum(num_elem_list) <= flat_tensor.numel()
        tensor_list = []
        offset = 0
        for num_elem in num_elem_list:
            dense_tensor = torch.narrow(flat_tensor, 0, offset, num_elem)
            tensor_list.append(dense_tensor)
            offset += num_elem

        return tensor_list

    def _update_storage_to_flattened_tensor(self, tensor_list, flat_tensor):
        updated_params = self.unflatten(flat_tensor, tensor_list)
        for p, q in zip(tensor_list, updated_params):
            p.data = q.data

    def _flatten_dense_tensors_aligned(self, tensor_list, alignment):
        return self.flatten(align_dense_tensors(tensor_list, alignment))

    @torch.no_grad()
    def step(self, closure=None):
        if closure is not None:
            raise NotImplementedError(f'{self.__class__} does not support closure.')

        all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(),
                                                     mpu=self.mpu,
                                                     norm_type=self.norm_type,
                                                     use_graph=self.graph_harvesting)
        self._global_grad_norm = all_groups_norm

        assert all_groups_norm > 0.
        if self.clip_grad > 0.:
            clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True),
                                        max_norm=self.clip_grad,
                                        global_norm=all_groups_norm,
                                        mpu=self.mpu,
                                        use_graph=self.graph_harvesting)

        self.optimizer.step()

        self.update_lp_params()

        self.clear_hp_grads()

    def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
        """Perform a backward pass and copy the low-precision gradients to the
        high-precision copy.

        We copy/accumulate to the high-precision grads now to prevent accumulating in the
        bf16 grads after successive backward() calls (i.e., grad accumulation steps > 1).

        The low-precision grads are deallocated during this procedure.
        """
        self.clear_lp_grads()
        loss.backward(**bwd_kwargs)

        if update_hp_grads:
            self.update_hp_grads(clear_lp_grads=clear_lp_grads)
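
    # Typical grad-accumulation flow (illustrative sketch; the real call sites live in the
    # DeepSpeed engine, and `model` / `micro_batches` below are placeholder names):
    #
    #   for micro_batch in micro_batches:
    #       loss = model(micro_batch)
    #       bf16_optimizer.backward(loss)   # folds bf16 grads into the fp32 copies
    #   bf16_optimizer.step()               # updates fp32 partitions, then refreshes bf16 params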

    @torch.no_grad()
    def update_hp_grads(self, clear_lp_grads=False):

        def _update_hp_grads_func(clear_lp_grads=False):
            for i, group in enumerate(self.bf16_groups):
                for j, lp in enumerate(group):
                    if lp.grad is None:
                        continue
                    hp_grad = self.fp32_groups_gradients[i][j]
                    assert hp_grad is not None, \
                        f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]'
                    hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))
                    lp._hp_grad = hp_grad
                    self.fp32_groups_has_gradients[i][j] = True
                    # clear gradients
                    if clear_lp_grads:
                        lp.grad.zero_()

        if self.graph_harvesting:
            graph_process(False, _update_hp_grads_func, clear_lp_grads)
        else:
            _update_hp_grads_func(clear_lp_grads)

        # cpu op
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                if lp.grad is None:
                    continue
                self.fp32_groups_has_gradients[i][j] = True

    @torch.no_grad()
    def get_grads_for_reduction(self):
        return self.fp32_groups_gradients_flat

    @torch.no_grad()
    def get_grads_for_norm(self, for_clipping=False):
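        # For the global norm (for_clipping=False), skip params flagged as pipeline-replicated
        # and count tensor-parallel-replicated params only on TP rank 0, so each gradient
        # contributes exactly once; for clipping, every locally held grad is returned.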
        grads = []
        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                if not for_clipping:
                    if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated:
                        continue

                    if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp)):
                        continue

                if not self.fp32_groups_has_gradients[i][j]:
                    continue

                grads.append(self.fp32_groups_gradients[i][j])

        return grads

    @torch.no_grad()
    def update_lp_params(self):
        for i, (bf16_partitions,
                fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            bf16_partitions[partition_id].data.copy_(fp32_partition.data)
            # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
            # if i == 0:
            #     print_rank_0(f'{fp32_partition[:10]=}', force=True)

        all_gather_dp_groups(groups_flat=self.bf16_groups_flat,
                             partitioned_param_groups=self.bf16_partitioned_groups,
                             dp_process_group=self.real_dp_process_group,
                             start_alignment_factor=self.nccl_start_alignment_factor,
                             allgather_bucket_size=self.allgather_bucket_size)

    def clear_hp_grads(self):
        for flat_gradients in self.fp32_groups_gradients_flat:
            flat_gradients.zero_()

        for i, group in enumerate(self.fp32_groups_gradients):
            self.fp32_groups_has_gradients[i] = [False] * len(group)

    def clear_lp_grads(self):
        for group in self.bf16_groups:
            for param in group:
                if param.grad is not None:
                    # zero_() keeps the gradient's memory address fixed for graph replay
                    param.grad.zero_()

    def state_dict(self):
        state_dict = {}
        state_dict[CLIP_GRAD] = self.clip_grad
        state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict()
        state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = self.fp32_groups_flat_partition
        state_dict[GROUP_PADDINGS] = self.group_paddings
        state_dict[PARTITION_COUNT] = self.partition_count

        state_dict[DS_VERSION] = version
        state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings

        return state_dict

    # Restore base optimizer fp32 weights from bfloat16 weights
    def _restore_from_bit16_weights(self):
        for i, group in enumerate(self.bf16_groups):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            for bf16_partitions, fp32_partition in zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition):
                fp32_partition.data.copy_(bf16_partitions[partition_id].data)

    def refresh_fp32_params(self):
        self._restore_from_bit16_weights()

    def load_state_dict(self,
                        state_dict_list,
                        checkpoint_folder,
                        load_optimizer_states=True,
                        load_from_fp32_weights=False,
                        load_serial=None):
        if checkpoint_folder:
            self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
        else:
            self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)

    def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False):
        dp_rank = dist.get_rank(group=self.dp_process_group)
        current_rank_sd = state_dict_list[dp_rank]

        ckpt_version = current_rank_sd.get(DS_VERSION, False)
        assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed"
        ckpt_version = pkg_version.parse(ckpt_version)

        self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)

        if load_optimizer_states:
            self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])

        if load_from_fp32_weights:
            for current, saved in zip(self.fp32_groups_flat_partition,
                                      current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):
                src_tensor = _get_padded_tensor(saved, current.numel())
                current.data.copy_(src_tensor.data)

        if load_optimizer_states:
            self._link_all_hp_params()

    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
        self._load_hp_checkpoint_state(checkpoint_folder)

    @property
    def param_groups(self):
        """Forward the wrapped optimizer's param groups."""
        return self.optimizer.param_groups

    def _load_hp_checkpoint_state(self, checkpoint_dir):
        checkpoint_dir = os.path.join(checkpoint_dir, "zero")
        tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        tp_world_size = self.mpu.get_slice_parallel_world_size()

        for i, _ in enumerate(self.optimizer.param_groups):
            for lp in self.bf16_groups[i]:
                if lp._hp_mapping is not None:
                    # print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                    lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                tp_world_size)


def _get_padded_tensor(src_tensor, size):
    if src_tensor.numel() >= size:
        return src_tensor
    padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)
    slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel())
    slice_tensor.data.copy_(src_tensor.data)
    return padded_tensor
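

# Minimal construction sketch (illustrative only; DeepSpeed normally builds this optimizer
# internally, e.g. via deepspeed.initialize(), when bf16 training is enabled). Assumes
# torch.distributed is already initialized and the model's parameters are bf16;
# `model` and `base_optimizer` are placeholder names, not part of this module.
#
#   base_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   param_names = {p: n for n, p in model.named_parameters()}
#   bf16_opt = BF16_Optimizer(base_optimizer,
#                             param_names,
#                             mpu=None,
#                             clip_grad=1.0,
#                             dp_process_group=None,
#                             grad_acc_dtype=torch.float32)
#   loss = model(inputs)
#   bf16_opt.backward(loss)
#   bf16_opt.step()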