tensor_fragment.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from dataclasses import dataclass
from deepspeed import comm as dist
from typing import Dict


@dataclass
class fragment_address:
    numel: int
    start: int

@dataclass
class tensor_fragment:
    lp_fragment: torch.Tensor
    lp_fragment_address: fragment_address
    hp_fragment: torch.Tensor
    hp_fragment_address: fragment_address
    optim_fragment: Dict
    gradient_dict: Dict
    offload_gradient_dict: Dict
    use_offload: bool
    param_group_index: int

    def update_hp(self):
        self.hp_fragment.data.copy_(self.lp_fragment.data)

    def update_lp(self):
        self.lp_fragment.data.copy_(self.hp_fragment.data)

    def get_optim_state_fragment(self, key):
        if key in self.optim_fragment:
            return self.optim_fragment[key]
        else:
            raise ValueError(f'{key} not found in optimizer state fragment')

    def get_hp_fragment_address(self):
        return self.hp_fragment_address

    def get_optim_state_keys(self):
        return list(self.optim_fragment.keys())
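
# Note: in get_full_hp_param/get_full_hp_grad below, ``self`` is a low-precision
# (e.g., fp16/bf16) ``torch.nn.Parameter``, not a ``tensor_fragment``. The owning
# optimizer is expected to bind these functions onto each parameter and to attach the
# ``_hp_mapping``, ``_dp_group`` and ``_index_in_param_group`` attributes they rely on.
# A minimal binding sketch (illustrative only, not the exact DeepSpeed call site):
#
#   import types
#   param._hp_mapping = frag                 # a tensor_fragment, or None if this rank owns no fragment
#   param._dp_group = dp_group               # data-parallel process group
#   param._index_in_param_group = index      # position of param within its param group
#   param.get_full_hp_param = types.MethodType(get_full_hp_param, param)
#   param.get_full_hp_grad = types.MethodType(get_full_hp_grad, param)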

def get_full_hp_param(self, optim_state_key=None):
    reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    if self._hp_mapping is not None:
        lp_frag_address = self._hp_mapping.lp_fragment_address
        reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)
        if optim_state_key is None:
            hp_fragment = self._hp_mapping.hp_fragment
        else:
            hp_fragment = self._hp_mapping.get_optim_state_fragment(optim_state_key)

        reduce_fragment.data.copy_(hp_fragment.data)

    # Ranks that own no fragment of this param contribute zeros, so summing across the
    # data-parallel group assembles the full fp32 value (or optimizer state).
    dist.all_reduce(reduce_buffer, group=self._dp_group)
    return reduce_buffer.reshape_as(self)

def get_full_hp_grad(self):
    reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    if self._hp_mapping is not None:
        hp_mapping = self._hp_mapping
        if hp_mapping.use_offload:
            gradient_dict = hp_mapping.offload_gradient_dict
        else:
            gradient_dict = hp_mapping.gradient_dict

        if hp_mapping.param_group_index not in gradient_dict or gradient_dict[hp_mapping.param_group_index] is None:
            raise ValueError("Gradients are only available immediately after backward and before engine step")

        # Look up this param's low-precision gradient within its param group and upcast to fp32.
        lp_grad_fragment = gradient_dict[hp_mapping.param_group_index][self._index_in_param_group]
        hp_grad_fragment = lp_grad_fragment.to(torch.float32).flatten()

        lp_frag_address = self._hp_mapping.lp_fragment_address
        reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)

        if self.view(-1).shape == hp_grad_fragment.shape:
            reduce_buffer.data.copy_(hp_grad_fragment.data)
        else:
            reduce_fragment.data.copy_(hp_grad_fragment.data)

    dist.all_reduce(reduce_buffer, group=self._dp_group)
    return reduce_buffer.reshape_as(self)

def safe_get_full_fp32_param(param):
    """Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_param()

    return None
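
# Usage sketch (illustrative; ``model`` is a placeholder for a module wrapped by
# deepspeed.initialize). Since the assembly may perform an all-reduce over the
# data-parallel group, call it on every rank:
#
#   for name, param in model.named_parameters():
#       fp32_value = safe_get_full_fp32_param(param)
#       if fp32_value is not None:
#           print(name, fp32_value.norm().item())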

def safe_get_full_optimizer_state(param, optim_state_key):
    """Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        optim_state_key (``string``): Key of the optimizer state (e.g., ``exp_avg`` in Adam)
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param, optim_state_key)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_param(optim_state_key)

    return None
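
# Usage sketch (illustrative; assumes an Adam-style optimizer whose per-parameter state
# contains ``exp_avg`` and ``exp_avg_sq`` tensors):
#
#   exp_avg = safe_get_full_optimizer_state(param, 'exp_avg')
#   exp_avg_sq = safe_get_full_optimizer_state(param, 'exp_avg_sq')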

# TODO: Figure out the correct return dtype
def safe_get_full_grad(param):
    """Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
    """
    if param.grad is not None:
        return param.grad

    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_fp32_grad_for_param(param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_grad()

    return None
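
# Usage sketch (illustrative; ``engine`` and ``batch`` are placeholders). Per the
# ValueError raised in get_full_hp_grad above, the assembled gradient is only available
# in the window between backward and the engine step:
#
#   loss = engine(batch)
#   engine.backward(loss)
#   grad = safe_get_full_grad(param)    # inspect here, before engine.step()
#   engine.step()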

def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
                            param_group_index, partition_start, partition_size, optimizer_state_dict):
    lp_end = lp_param.numel() + lp_start
    hp_start = partition_start
    hp_end = partition_start + partition_size

    fragment_start = max(lp_start, hp_start)
    fragment_end = min(lp_end, hp_end)
    assert fragment_start < fragment_end, \
        f'fragment start {fragment_start} should be < fragment_end {fragment_end}'

    fragment_numel = fragment_end - fragment_start
    hp_frag_address = fragment_address(start=fragment_start - hp_start, numel=fragment_numel)
    hp_fragment_tensor = flat_hp_partition.narrow(0, hp_frag_address.start, hp_frag_address.numel)
    optim_fragment = {
        key: value.narrow(0, hp_frag_address.start, hp_frag_address.numel)
        for key, value in optimizer_state_dict.items()
        if torch.is_tensor(value) and value.shape == flat_hp_partition.shape
    }

    lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel)
    lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel)

    return tensor_fragment(lp_fragment=lp_fragment_tensor,
                           lp_fragment_address=lp_frag_address,
                           hp_fragment=hp_fragment_tensor,
                           hp_fragment_address=hp_frag_address,
                           optim_fragment=optim_fragment,
                           gradient_dict=gradient_dict,
                           offload_gradient_dict=offload_gradient_dict,
                           use_offload=use_offload,
                           param_group_index=param_group_index)
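
# Worked example for get_hp_fragment_mapping (numbers are illustrative): suppose lp_param
# has 8 elements starting at lp_start=4 in the flat lp space, and this rank owns the hp
# partition [6, 16), i.e. partition_start=6, partition_size=10. Then:
#
#   lp_end          = 4 + 8       = 12
#   fragment_start  = max(4, 6)   = 6
#   fragment_end    = min(12, 16) = 12
#   fragment_numel  = 6
#   hp_frag_address = fragment_address(start=0, numel=6)   # 6 - partition_start
#   lp_frag_address = fragment_address(start=2, numel=6)   # 6 - lp_start
#
# i.e. elements 2..7 of the flattened lp_param map onto elements 0..5 of flat_hp_partition.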

'''
Logic for lp_param to hp_param mapping

lp      lp0  lp1    lp2      lp3   lp4       <------- indices/names
lp      [  ][    ][        ][    ][      ]   <------- tensors
flat_lp [                                ]   <------- flat lp params
flat_hp            [                 ]       <------- flat hp partition on current rank
full_hp [                                ]   <------- full flat hp params

lp2
    full numel = 16
    lp_frag
        numel = 12
        frag_start = 3
        frag_end = 15
    hp_frag
        numel = 12
        frag_start = 0
        frag_end = 11

    hp_frag.copy_(lp_frag)

lp3:
    full numel = 4
    lp_frag
        numel = 4
        start = 0
        end = 3
    hp_frag
        numel = 4
        start = 12
        end = 15

lp4:
    full numel = 12
    lp_frag
        numel = 4
        start = 0
        end = 3
    hp_frag
        numel = 4
        start = 16
        end = 19

Visual depiction of above

Case 1: lp starts before the hp partition and ends inside it (overlap = ly - hx)

    lp           {          }
    flat_lp [                            ]
    flat_hp           (            )

    flat_lp [    {    (     }      )    ]
                 lx   hx    ly     hy
                        ly-hx

Case 2: lp starts inside the hp partition and ends after it (overlap = hy - lx)

    lp                     {            }
    flat_lp [                            ]
    flat_hp           (            )

    flat_lp [         (    {       )    }  ]
                      hx   lx      hy   ly
                           hy-lx

Case 3: lp lies entirely inside the hp partition (overlap = ly - lx)

    lp                 {        }
    flat_lp [                            ]
    flat_hp           (            )

    flat_lp [         (  {        }  )  ]
                      hx lx       ly hy
                          ly-lx

lp      -> (lx, ly)
flat_hp -> (hx, hy)
'''
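
# All three overlap cases above reduce to the computation in get_hp_fragment_mapping:
# with (lx, ly) = (lp_start, lp_end) and (hx, hy) = (hp_start, hp_end),
#
#   fragment_start = max(lx, hx)
#   fragment_end   = min(ly, hy)
#   fragment_numel = fragment_end - fragment_start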