# bf16_optimizer.py

  1. """
  2. Copyright 2022 The Microsoft DeepSpeed Team
  3. """
  4. from collections import OrderedDict
  5. import torch
  6. import sys
  7. import os
  8. from deepspeed import comm as dist
  9. from deepspeed.runtime.constants import PIPE_REPLICATED
  10. from deepspeed.ops.op_builder import UtilsBuilder
  11. from deepspeed.runtime import ZeROOptimizer
  12. from packaging import version as pkg_version
  13. from deepspeed.git_version_info import version
  14. from deepspeed.runtime.utils import (get_global_norm_of_tensors,
  15. clip_tensors_by_global_norm,
  16. DummyOptim,
  17. align_dense_tensors,
  18. all_gather_dp_groups,
  19. bwc_tensor_model_parallel_rank,
  20. is_model_parallel_parameter,
  21. see_memory_usage)
  22. from deepspeed.utils import link_hp_params, fragment_address
  23. from deepspeed.checkpoint import enable_universal_checkpoint
  24. from deepspeed.checkpoint.constants import (DS_VERSION,
  25. PARTITION_COUNT,
  26. BASE_OPTIMIZER_STATE,
  27. SINGLE_PARTITION_OF_FP32_GROUPS,
  28. CLIP_GRAD,
  29. GROUP_PADDINGS,
  30. PARAM_SLICE_MAPPINGS)
  31. setattr(sys.modules[__name__], 'fragment_address', fragment_address)
  32. class BF16_Optimizer(ZeROOptimizer):
  33. def __init__(self,
  34. init_optimizer,
  35. param_names,
  36. mpu=None,
  37. clip_grad=0.0,
  38. norm_type=2,
  39. allgather_bucket_size=5000000000,
  40. dp_process_group=None,
  41. timers=None):
  42. super().__init__()
  43. see_memory_usage('begin bf16_optimizer', force=True)
  44. self.timers = timers
  45. self.optimizer = init_optimizer
  46. self.param_names = param_names
  47. self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)
  48. self.clip_grad = clip_grad
  49. self.norm_type = norm_type
  50. self.mpu = mpu
  51. self.allgather_bucket_size = int(allgather_bucket_size)
  52. self.dp_process_group = dp_process_group
  53. self.dp_rank = dist.get_rank(group=self.dp_process_group)
  54. self.real_dp_process_group = [
  55. dp_process_group for i in range(len(self.optimizer.param_groups))
  56. ]
  57. # Load pre-built or JIT compile (un)flatten ops
  58. util_ops = UtilsBuilder().load()
  59. self.flatten = util_ops.flatten
  60. self.unflatten = util_ops.unflatten
  61. #align nccl all-gather send buffers to 4-bye boundary
  62. self.nccl_start_alignment_factor = 2 # 4-byte alignment/sizeof(fp16) = 2
  63. # Build BF16/FP32 groups
  64. self.bf16_groups = []
  65. self.bf16_groups_flat = []
  66. self.bf16_partitioned_groups = []
  67. self.fp32_groups_flat_partition = []
  68. # Maintain different fp32 gradients views for convenience
  69. self.fp32_groups_gradients = []
  70. self.fp32_groups_gradient_dict = {}
  71. self.fp32_groups_gradients_flat = []
  72. self.fp32_groups_actual_gradients_flat = []
  73. self.fp32_groups_gradient_flat_partition = []
  74. self.fp32_groups_has_gradients = []
  75. self.step_count = 0
  76. self.group_paddings = []
  77. if self.using_real_optimizer:
  78. self._setup_for_real_optimizer()
  79. see_memory_usage('end bf16_optimizer', force=True)
  80. def _setup_for_real_optimizer(self):
  81. dp_world_size = dist.get_world_size(group=self.dp_process_group)
  82. self.partition_count = [
  83. dp_world_size for i in range(len(self.optimizer.param_groups))
  84. ]
  85. for i, param_group in enumerate(self.optimizer.param_groups):
  86. see_memory_usage(f'before initializing group {i}', force=True)
  87. partition_id = dist.get_rank(group=self.real_dp_process_group[i])
  88. # grab the original list
  89. self.bf16_groups.append(param_group['params'])
  90. # create flat bf16 params
  91. self.bf16_groups_flat.append(
  92. self._flatten_dense_tensors_aligned(
  93. self.bf16_groups[i],
  94. self.nccl_start_alignment_factor * dp_world_size))
  95. # Make bf16 params point to flat tensor storage
  96. self._update_storage_to_flattened_tensor(
  97. tensor_list=self.bf16_groups[i],
  98. flat_tensor=self.bf16_groups_flat[i])
  99. # divide flat weights into equal sized partitions
  100. partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
  101. bf16_dp_partitions = [
  102. self.bf16_groups_flat[i].narrow(0,
  103. dp_index * partition_size,
  104. partition_size)
  105. for dp_index in range(dp_world_size)
  106. ]
  107. self.bf16_partitioned_groups.append(bf16_dp_partitions)
  108. # create fp32 params partition
  109. self.fp32_groups_flat_partition.append(
  110. bf16_dp_partitions[partition_id].clone().float().detach())
  111. self.fp32_groups_flat_partition[i].requires_grad = True
  112. num_elem_list = [t.numel() for t in self.bf16_groups[i]]
  113. # create fp32 gradients
  114. self.fp32_groups_gradients_flat.append(
  115. torch.zeros_like(self.bf16_groups_flat[i],
  116. dtype=torch.float32))
  117. # track individual fp32 gradients for entire model
  118. fp32_gradients = self._split_flat_tensor(
  119. flat_tensor=self.fp32_groups_gradients_flat[i],
  120. num_elem_list=num_elem_list)
  121. self.fp32_groups_gradients.append(fp32_gradients)
  122. self.fp32_groups_gradient_dict[i] = fp32_gradients
  123. # flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding)
  124. length_without_padding = sum(num_elem_list)
  125. self.fp32_groups_actual_gradients_flat.append(
  126. torch.narrow(self.fp32_groups_gradients_flat[i],
  127. 0,
  128. 0,
  129. length_without_padding))
  130. # flat tensor corresponding to gradient partition
  131. self.fp32_groups_gradient_flat_partition.append(
  132. torch.narrow(self.fp32_groups_gradients_flat[i],
  133. 0,
  134. partition_id * partition_size,
  135. partition_size))
  136. # track fp32 gradient updates
  137. self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i]))
  138. # Record padding required for alignment
  139. if partition_id == dist.get_world_size(
  140. group=self.real_dp_process_group[i]) - 1:
  141. padding = self.bf16_groups_flat[i].numel() - length_without_padding
  142. else:
  143. padding = 0
  144. self.group_paddings.append(padding)
  145. # update optimizer param groups to reference fp32 params partition
  146. param_group['params'] = [self.fp32_groups_flat_partition[i]]
  147. see_memory_usage(f'after initializing group {i}', force=True)
  148. see_memory_usage('before initialize_optimizer', force=True)
  149. self.initialize_optimizer_states()
  150. see_memory_usage('end initialize_optimizer', force=True)
  151. # Need optimizer states initialized before linking lp to optimizer state
  152. self._link_all_hp_params()
  153. self._enable_universal_checkpoint()
  154. self._param_slice_mappings = self._create_param_mapping()
  155. def _enable_universal_checkpoint(self):
  156. for lp_param_group in self.bf16_groups:
  157. enable_universal_checkpoint(param_list=lp_param_group)
  158. def _create_param_mapping(self):
  159. param_mapping = []
  160. for i, _ in enumerate(self.optimizer.param_groups):
  161. param_mapping_per_group = OrderedDict()
  162. for lp in self.bf16_groups[i]:
  163. if lp._hp_mapping is not None:
  164. lp_name = self.param_names[lp]
  165. param_mapping_per_group[
  166. lp_name] = lp._hp_mapping.get_hp_fragment_address()
  167. param_mapping.append(param_mapping_per_group)
  168. return param_mapping
  169. def _link_all_hp_params(self):
  170. dp_world_size = dist.get_world_size(group=self.dp_process_group)
  171. for i, _ in enumerate(self.optimizer.param_groups):
  172. # Link bf16 and fp32 params in partition
  173. partition_id = dist.get_rank(group=self.real_dp_process_group[i])
  174. partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
  175. flat_hp_partition = self.fp32_groups_flat_partition[i]
  176. link_hp_params(
  177. lp_param_list=self.bf16_groups[i],
  178. flat_hp_partition=flat_hp_partition,
  179. gradient_dict=self.fp32_groups_gradient_dict,
  180. offload_gradient_dict=None,
  181. use_offload=False,
  182. param_group_index=i,
  183. partition_start=partition_id * partition_size,
  184. partition_size=partition_size,
  185. partition_optimizer_state=self.optimizer.state[flat_hp_partition],
  186. dp_group=self.real_dp_process_group[i])
  187. def initialize_optimizer_states(self):
  188. """Take an optimizer step with zero-valued gradients to allocate internal
  189. optimizer state.
  190. This helps prevent memory fragmentation by allocating optimizer state at the
  191. beginning of training instead of after activations have been allocated.
  192. """
  193. for param_partition, grad_partition in zip(self.fp32_groups_flat_partition, self.fp32_groups_gradient_flat_partition):
  194. param_partition.grad = grad_partition
  195. self.optimizer.step()
  196. self.clear_hp_grads()
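    # Editor's note: the commented snippet below is an illustrative sketch (not part of
    # the original file) of the same pre-allocation trick using a plain torch.optim.Adam.
    # The `params` list is a hypothetical stand-in. With zero-valued grads and the
    # default weight_decay=0, the step leaves the weights unchanged but forces Adam to
    # allocate its exp_avg/exp_avg_sq state up front.
    #
    #     import torch
    #
    #     params = [torch.nn.Parameter(torch.zeros(10)) for _ in range(2)]
    #     opt = torch.optim.Adam(params, lr=1e-3)
    #     for p in params:
    #         p.grad = torch.zeros_like(p)   # zero-valued gradients
    #     opt.step()                         # allocates optimizer state now
    #     opt.zero_grad(set_to_none=True)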
    def _split_flat_tensor(self, flat_tensor, num_elem_list):
        assert sum(num_elem_list) <= flat_tensor.numel()
        tensor_list = []
        offset = 0
        for num_elem in num_elem_list:
            dense_tensor = torch.narrow(flat_tensor, 0, offset, num_elem)
            tensor_list.append(dense_tensor)
            offset += num_elem

        return tensor_list

    def _update_storage_to_flattened_tensor(self, tensor_list, flat_tensor):
        updated_params = self.unflatten(flat_tensor, tensor_list)
        for p, q in zip(tensor_list, updated_params):
            p.data = q.data

    def _flatten_dense_tensors_aligned(self, tensor_list, alignment):
        return self.flatten(align_dense_tensors(tensor_list, alignment))
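    # Editor's note: the commented snippet below is an illustrative sketch (not part of
    # the original file) of what aligned flattening accomplishes, approximated with plain
    # torch ops instead of the fused utils op. Here `align` plays the role of
    # nccl_start_alignment_factor * dp_world_size, so the flat buffer divides evenly
    # across data-parallel ranks; `flatten_aligned` is a hypothetical helper.
    #
    #     import torch
    #
    #     def flatten_aligned(tensors, align):
    #         numel = sum(t.numel() for t in tensors)
    #         remainder = numel % align
    #         pad = 0 if remainder == 0 else align - remainder
    #         pieces = [t.reshape(-1) for t in tensors]
    #         if pad > 0:
    #             pieces.append(torch.zeros(pad, dtype=tensors[0].dtype))
    #         return torch.cat(pieces)
    #
    #     flat = flatten_aligned([torch.randn(5), torch.randn(7)], align=8)
    #     assert flat.numel() % 8 == 0   # 12 elements padded up to 16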
    @torch.no_grad()
    def step(self, closure=None):
        if closure is not None:
            raise NotImplementedError(f'{self.__class__} does not support closure.')

        all_groups_norm = get_global_norm_of_tensors(
            input_tensors=self.get_grads_for_norm(),
            mpu=self.mpu,
            norm_type=self.norm_type)
        self._global_grad_norm = all_groups_norm

        assert all_groups_norm > 0.
        if self.clip_grad > 0.:
            clip_tensors_by_global_norm(
                input_tensors=self.get_grads_for_norm(for_clipping=True),
                max_norm=self.clip_grad,
                global_norm=all_groups_norm,
                mpu=self.mpu)

        self.optimizer.step()

        self.update_lp_params()

        self.clear_hp_grads()
        self.step_count += 1

    def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
        """Perform a backward pass and copy the low-precision gradients to the
        high-precision copy.

        We copy/accumulate to the high-precision grads now to prevent accumulating in the
        bf16 grads after successive backward() calls (i.e., grad accumulation steps > 1).

        The low-precision grads are deallocated during this procedure.
        """
        self.clear_lp_grads()
        loss.backward(**bwd_kwargs)

        if update_hp_grads:
            self.update_hp_grads(clear_lp_grads=clear_lp_grads)
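    # Editor's note: the commented snippet below is an illustrative sketch (not part of
    # the original file) of how backward()/step() are typically driven across
    # gradient-accumulation micro-batches. `model`, `micro_batches`, and `loss_fn` are
    # hypothetical; in practice the DeepSpeed engine drives this loop rather than user code.
    #
    #     for micro_batch in micro_batches:
    #         loss = loss_fn(model(micro_batch))
    #         optimizer.backward(loss)   # fp32 grads accumulate; bf16 grads are freed
    #     optimizer.step()               # optional clipping, fp32 step, bf16 params refreshed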
    @torch.no_grad()
    def update_hp_grads(self, clear_lp_grads=False):
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                if lp.grad is None:
                    continue

                hp_grad = self.fp32_groups_gradients[i][j]
                assert hp_grad is not None, \
                    f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]'

                hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))
                lp._hp_grad = hp_grad
                self.fp32_groups_has_gradients[i][j] = True

                # clear gradients
                if clear_lp_grads:
                    lp.grad = None

    @torch.no_grad()
    def get_grads_for_reduction(self):
        return self.fp32_groups_gradients_flat

    @torch.no_grad()
    def get_grads_for_norm(self, for_clipping=False):
        grads = []
        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                if not for_clipping:
                    if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated:
                        continue

                    if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp)):
                        continue

                if not self.fp32_groups_has_gradients[i][j]:
                    continue

                grads.append(self.fp32_groups_gradients[i][j])

        return grads

    @torch.no_grad()
    def update_lp_params(self):
        for i, (bf16_partitions,
                fp32_partition) in enumerate(
                    zip(self.bf16_partitioned_groups,
                        self.fp32_groups_flat_partition)):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            bf16_partitions[partition_id].data.copy_(fp32_partition.data)
            # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
            # if i == 0:
            #     print_rank_0(f'{fp32_partition[:10]=}', force=True)

        all_gather_dp_groups(partitioned_param_groups=self.bf16_partitioned_groups,
                             dp_process_group=self.real_dp_process_group,
                             start_alignment_factor=self.nccl_start_alignment_factor,
                             allgather_bucket_size=self.allgather_bucket_size)
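    # Editor's note: the commented snippet below is an illustrative sketch (not part of
    # the original file) of the partition-then-gather pattern above, reduced to raw
    # torch.distributed calls. Each rank owns one equal slice of the flat bf16 buffer,
    # copies its updated fp32 slice into it (with a dtype cast), and an all-gather
    # rebuilds the full buffer on every rank. `flat_bf16`, `my_fp32_slice`, and `group`
    # are hypothetical stand-ins.
    #
    #     import torch.distributed as torch_dist
    #
    #     world_size = torch_dist.get_world_size(group)
    #     rank = torch_dist.get_rank(group)
    #     part = flat_bf16.numel() // world_size
    #     slices = [flat_bf16.narrow(0, r * part, part) for r in range(world_size)]
    #     slices[rank].copy_(my_fp32_slice)  # fp32 -> bf16 cast happens in copy_
    #     torch_dist.all_gather(slices, slices[rank].clone(), group=group)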
    def clear_hp_grads(self):
        for flat_gradients in self.fp32_groups_gradients_flat:
            flat_gradients.zero_()

        for i, group in enumerate(self.fp32_groups_gradients):
            self.fp32_groups_has_gradients[i] = [False] * len(group)

    def clear_lp_grads(self):
        for group in self.bf16_groups:
            for param in group:
                param.grad = None

    def state_dict(self):
        state_dict = {}
        state_dict[CLIP_GRAD] = self.clip_grad
        state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict()
        state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = self.fp32_groups_flat_partition
        state_dict[GROUP_PADDINGS] = self.group_paddings
        state_dict[PARTITION_COUNT] = self.partition_count
        state_dict[DS_VERSION] = version
        state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings

        return state_dict

    # Restore base optimizer fp32 weights from bfloat16 weights
    def _restore_from_bit16_weights(self):
        for i, group in enumerate(self.bf16_groups):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            for bf16_partitions, fp32_partition in zip(self.bf16_partitioned_groups,
                                                       self.fp32_groups_flat_partition):
                fp32_partition.data.copy_(bf16_partitions[partition_id].data)

    def refresh_fp32_params(self):
        self._restore_from_bit16_weights()

    def load_state_dict(self,
                        state_dict_list,
                        checkpoint_folder,
                        load_optimizer_states=True,
                        load_from_fp32_weights=False):
        if checkpoint_folder:
            self._load_universal_checkpoint(checkpoint_folder,
                                            load_optimizer_states,
                                            load_from_fp32_weights)
        else:
            self._load_legacy_checkpoint(state_dict_list,
                                         load_optimizer_states,
                                         load_from_fp32_weights)

    def _load_legacy_checkpoint(self,
                                state_dict_list,
                                load_optimizer_states=True,
                                load_from_fp32_weights=False):
        dp_rank = dist.get_rank(group=self.dp_process_group)
        current_rank_sd = state_dict_list[dp_rank]

        ckpt_version = current_rank_sd.get(DS_VERSION, False)
        assert ckpt_version, "Empty ds_version in checkpoint, not clear how to proceed"
        ckpt_version = pkg_version.parse(ckpt_version)

        self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)

        if load_optimizer_states:
            self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])

        if load_from_fp32_weights:
            for current, saved in zip(self.fp32_groups_flat_partition,
                                      current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):
                src_tensor = _get_padded_tensor(saved, current.numel())
                current.data.copy_(src_tensor.data)

        if load_optimizer_states:
            self._link_all_hp_params()

    def _load_universal_checkpoint(self,
                                   checkpoint_folder,
                                   load_optimizer_states,
                                   load_from_fp32_weights):
        self._load_hp_checkpoint_state(checkpoint_folder)

    @property
    def param_groups(self):
        """Forward the wrapped optimizer's parameter groups."""
        return self.optimizer.param_groups

    def _load_hp_checkpoint_state(self, checkpoint_dir):
        checkpoint_dir = os.path.join(checkpoint_dir, "zero")
        tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        tp_world_size = self.mpu.get_slice_parallel_world_size()

        for i, _ in enumerate(self.optimizer.param_groups):
            for lp in self.bf16_groups[i]:
                if lp._hp_mapping is not None:
                    # print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                    lp.load_hp_checkpoint_state(
                        os.path.join(checkpoint_dir,
                                     self.param_names[lp]),
                        tp_rank,
                        tp_world_size)


def _get_padded_tensor(src_tensor, size):
    if src_tensor.numel() >= size:
        return src_tensor
    padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)
    slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel())
    slice_tensor.data.copy_(src_tensor.data)
    return padded_tensor
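
# Editor's note: the commented snippet below is an illustrative sketch (not part of the
# original file) of wrapping a base optimizer with BF16_Optimizer by hand. It assumes
# the distributed backend is already initialized and the model's parameters are bf16;
# `model` is a hypothetical module, and dp_process_group=None selects the default
# (world) process group. In practice this wrapper is normally constructed by the
# DeepSpeed engine when bf16 training is enabled in the config, not by user code.
#
#     base_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#     param_names = {p: n for n, p in model.named_parameters()}
#     optimizer = BF16_Optimizer(base_optimizer,
#                                param_names,
#                                mpu=None,
#                                clip_grad=1.0,
#                                dp_process_group=None)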