tensor_fragment.py

  1. """
  2. Copyright 2022 The Microsoft DeepSpeed Team
  3. """
  4. import torch
  5. from dataclasses import dataclass
  6. from deepspeed import comm as dist
  7. @dataclass
  8. class fragment_address:
  9. numel: int
  10. start: int
  11. @dataclass
  12. class tensor_fragment:
  13. lp_fragment: torch.Tensor
  14. lp_fragment_address: fragment_address
  15. hp_fragment: torch.Tensor
  16. hp_fragment_address: fragment_address
  17. optim_fragment: {}
  18. gradient_dict: {}
  19. offload_gradient_dict: {}
  20. use_offload: bool
  21. param_group_index: int
  22. def update_hp(self):
  23. self.hp_fragment.data.copy_(self.lp_fragment.data)
  24. def update_lp(self):
  25. self.lp_fragment.data.copy_(self.hp_fragment.data)
  26. def get_optim_state_fragment(self, key):
  27. if key in self.optim_fragment:
  28. return self.optim_fragment[key]
  29. else:
  30. raise ValueError(f'{key} not found in optimizer state fragment')
  31. def get_hp_fragment_address(self):
  32. return self.hp_fragment_address
  33. def get_optim_state_keys(self):
  34. return list(self.optim_fragment.keys())
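
# Illustrative sketch (not part of the original file): a tensor_fragment pairs a
# low-precision (lp) slice of a parameter with the matching high-precision (hp)
# slice of the optimizer's flat fp32 partition; update_hp()/update_lp() copy
# data between the two views. The tensors below are made up for the example.
#
#   lp = torch.ones(4, dtype=torch.bfloat16)
#   hp = torch.zeros(4, dtype=torch.float32)
#   addr = fragment_address(numel=4, start=0)
#   frag = tensor_fragment(lp_fragment=lp, lp_fragment_address=addr,
#                          hp_fragment=hp, hp_fragment_address=addr,
#                          optim_fragment={}, gradient_dict={},
#                          offload_gradient_dict={}, use_offload=False,
#                          param_group_index=0)
#   frag.update_hp()  # hp now holds an fp32 copy of lp's values
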

# Note: the two functions below are not methods of tensor_fragment. DeepSpeed
# attaches them to individual parameter objects (see safe_get_full_fp32_param
# below, which calls param.get_full_hp_param()), so `self` here is the
# low-precision parameter tensor itself.
def get_full_hp_param(self, optim_state_key=None):
    reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    if self._hp_mapping is not None:
        lp_frag_address = self._hp_mapping.lp_fragment_address
        reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)
        if optim_state_key is None:
            hp_fragment = self._hp_mapping.hp_fragment
        else:
            hp_fragment = self._hp_mapping.get_optim_state_fragment(optim_state_key)

        reduce_fragment.data.copy_(hp_fragment.data)
    dist.all_reduce(reduce_buffer, group=self._dp_group)
    return reduce_buffer.reshape_as(self)


def get_full_hp_grad(self):
    reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    if self._hp_mapping is not None:
        hp_mapping = self._hp_mapping

        if hp_mapping.use_offload:
            gradient_dict = hp_mapping.offload_gradient_dict
        else:
            gradient_dict = hp_mapping.gradient_dict

        if hp_mapping.param_group_index not in gradient_dict or gradient_dict[hp_mapping.param_group_index] is None:
            raise ValueError("Gradients are only available immediately after backward and before engine step")

        lp_grad_fragment = gradient_dict[hp_mapping.param_group_index][self._index_in_param_group]
        hp_grad_fragment = lp_grad_fragment.to(torch.float32).flatten()

        lp_frag_address = self._hp_mapping.lp_fragment_address
        reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)

        if self.view(-1).shape == hp_grad_fragment.shape:
            reduce_buffer.data.copy_(hp_grad_fragment.data)
        else:
            reduce_fragment.data.copy_(hp_grad_fragment.data)

    dist.all_reduce(reduce_buffer, group=self._dp_group)
    return reduce_buffer.reshape_as(self)


def safe_get_full_fp32_param(param):
    """Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_param()
    return None

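
# Illustrative usage sketch (not part of the original file): inside a training
# script, the assembled fp32 weights of any parameter can be inspected on every
# rank. `model` is assumed to be the module wrapped by the DeepSpeed engine.
#
#   for name, param in model.named_parameters():
#       fp32_value = safe_get_full_fp32_param(param)
#       if fp32_value is not None:
#           print(f'{name}: fp32 norm = {fp32_value.norm().item():.4f}')
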

def safe_get_full_optimizer_state(param, optim_state_key):
    """Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        optim_state_key (``string``): Key of the optimizer state to fetch (e.g., ``exp_avg`` for Adam)
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param, optim_state_key)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_param(optim_state_key)
    return None

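
# Illustrative usage sketch (not part of the original file): reading the fp32
# first-moment state for every parameter. The valid key names depend on the
# optimizer in use; 'exp_avg'/'exp_avg_sq' apply to Adam-style optimizers.
#
#   for name, param in model.named_parameters():
#       exp_avg = safe_get_full_optimizer_state(param, 'exp_avg')
#       if exp_avg is not None:
#           print(f'{name}: exp_avg mean = {exp_avg.mean().item():.6f}')
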

# TODO: Figure out the correct return dtype
def safe_get_full_grad(param):
    """Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
    """
    if param.grad is not None:
        return param.grad

    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_fp32_grad_for_param(param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_grad()
    return None

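
# Illustrative usage sketch (not part of the original file): gradients are only
# retrievable in the window between backward and the engine step, as enforced
# by get_full_hp_grad() above. `engine` and `batch` are assumed to come from
# the surrounding training script.
#
#   loss = engine(batch)
#   engine.backward(loss)
#   grad = safe_get_full_grad(some_param)   # valid here
#   engine.step()                           # gradient buffers released after this
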

def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
                            param_group_index, partition_start, partition_size, optimizer_state_dict):
    lp_end = lp_param.numel() + lp_start
    hp_start = partition_start
    hp_end = partition_start + partition_size

    fragment_start = max(lp_start, hp_start)
    fragment_end = min(lp_end, hp_end)
    assert fragment_start < fragment_end, \
        f'fragment start {fragment_start} should be < fragment_end {fragment_end}'

    fragment_numel = fragment_end - fragment_start
    hp_frag_address = fragment_address(start=fragment_start - hp_start, numel=fragment_numel)
    hp_fragment_tensor = flat_hp_partition.narrow(0, hp_frag_address.start, hp_frag_address.numel)
    optim_fragment = {
        key: value.narrow(0, hp_frag_address.start, hp_frag_address.numel)
        for key, value in optimizer_state_dict.items()
        if torch.is_tensor(value) and value.shape == flat_hp_partition.shape
    }

    lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel)
    lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel)

    return tensor_fragment(lp_fragment=lp_fragment_tensor,
                           lp_fragment_address=lp_frag_address,
                           hp_fragment=hp_fragment_tensor,
                           hp_fragment_address=hp_frag_address,
                           optim_fragment=optim_fragment,
                           gradient_dict=gradient_dict,
                           offload_gradient_dict=offload_gradient_dict,
                           use_offload=use_offload,
                           param_group_index=param_group_index)


'''
Logic for lp_param to hp_param mapping

lp       lp0 lp1 lp2 lp3 lp4          <------- indices/names
lp      [  ][  ][  ][  ][  ]          <------- tensors
flat_lp [                  ]          <------- flat lp params
flat_hp         [        ]            <------- flat hp partition on current rank
full_hp [                        ]    <------- full flat hp params

lp2
    full numel = 16
    lp_frag
        numel = 12
        frag_start = 3
        frag_end = 15
    hp_frag
        numel = 12
        frag_start = 0
        frag_end = 11

    hp_frag.copy_(lp_frag)

lp3:
    full numel = 4
    lp_frag
        numel = 4
        start = 0
        end = 3
    hp_frag
        numel = 4
        start = 12
        end = 15

lp4:
    full numel = 12
    lp_frag
        numel = 4
        start = 0
        end = 3
    hp_frag
        numel = 4
        start = 16
        end = 19

Visual depiction of above

Case 1: lp starts before the hp partition (fragment numel = ly - hx)
    lp           {         }
    flat_lp  [                       ]
    flat_hp          (           )

    flat_lp  [   {   (     }     )   ]
                 lx  hx    ly    hy
                     ly-hx

Case 2: lp extends past the end of the hp partition (fragment numel = hy - lx)
    lp                  {         }
    flat_lp  [                       ]
    flat_hp       (          )

    flat_lp  [   (   {     )     }   ]
                 hx  lx    hy    ly
                     hy-lx

Case 3: lp lies entirely inside the hp partition (fragment numel = ly - lx)
    lp               {     }
    flat_lp  [                       ]
    flat_hp       (             )

    flat_lp  [   (   {     }     )   ]
                 hx  lx    ly    hy
                     ly-lx

lp      -> (lx, ly)
flat_hp -> (hx, hy)
'''
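
# Illustrative sketch (not part of the original file): exercising
# get_hp_fragment_mapping with synthetic sizes to show the overlap arithmetic
# described in the docstring above. A 16-element lp param starts at offset 10
# of the flat lp buffer and the local hp partition covers offsets [13, 53), so
# the shared fragment is [13, 26): 13 elements, at offset 3 within the lp param
# and offset 0 within the hp partition.
#
#   lp_param = torch.zeros(16, dtype=torch.bfloat16)
#   flat_hp_partition = torch.zeros(40, dtype=torch.float32)
#   frag = get_hp_fragment_mapping(lp_param,
#                                  lp_start=10,
#                                  flat_hp_partition=flat_hp_partition,
#                                  gradient_dict={},
#                                  offload_gradient_dict={},
#                                  use_offload=False,
#                                  param_group_index=0,
#                                  partition_start=13,
#                                  partition_size=40,
#                                  optimizer_state_dict={})
#   assert frag.lp_fragment_address.start == 3 and frag.lp_fragment_address.numel == 13
#   assert frag.hp_fragment_address.start == 0 and frag.hp_fragment_address.numel == 13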