# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from dataclasses import dataclass
from deepspeed import comm as dist
from typing import Dict


@dataclass
class fragment_address:
    numel: int
    start: int


@dataclass
class tensor_fragment:
    lp_fragment: torch.Tensor
    lp_fragment_address: fragment_address
    hp_fragment: torch.Tensor
    hp_fragment_address: fragment_address
    optim_fragment: Dict
    gradient_dict: Dict
    offload_gradient_dict: Dict
    use_offload: bool
    param_group_index: int

    def update_hp(self):
        self.hp_fragment.data.copy_(self.lp_fragment.data)

    def update_lp(self):
        self.lp_fragment.data.copy_(self.hp_fragment.data)

    def get_optim_state_fragment(self, key):
        if key in self.optim_fragment:
            return self.optim_fragment[key]
        else:
            raise ValueError(f'{key} not found in optimizer state fragment')

    def get_hp_fragment_address(self):
        return self.hp_fragment_address

    def get_optim_state_keys(self):
        return list(self.optim_fragment.keys())

    def get_hp_fragment(self, optim_state_key=None):
        if optim_state_key is None:
            return self.hp_fragment
        return self.get_optim_state_fragment(optim_state_key)
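
# Illustrative sketch (not part of the original module): how a tensor_fragment
# keeps a low-precision shard and its fp32 master fragment in sync. All tensors
# here are toy data constructed inline.
#
#   lp = torch.randn(4).to(torch.bfloat16)
#   hp = lp.to(torch.float32).clone()
#   frag = tensor_fragment(lp_fragment=lp,
#                          lp_fragment_address=fragment_address(numel=4, start=0),
#                          hp_fragment=hp,
#                          hp_fragment_address=fragment_address(numel=4, start=0),
#                          optim_fragment={},
#                          gradient_dict={},
#                          offload_gradient_dict={},
#                          use_offload=False,
#                          param_group_index=0)
#   hp.mul_(0.5)       # e.g., an optimizer step updates the master weights
#   frag.update_lp()   # push the fp32 master values back into the bf16 copy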


def get_full_hp_param(self, optim_state_key=None):
    reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    if self._hp_mapping is not None:
        lp_frag_address = self._hp_mapping.lp_fragment_address
        reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)
        hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key)
        reduce_fragment.data.copy_(hp_fragment.data)
    # All ranks must join the collective, even those holding no fragment
    dist.all_reduce(reduce_buffer, group=self._dp_group)
    return reduce_buffer.reshape_as(self)


def set_full_hp_param(self, value, optim_state_key=None):
    if self._hp_mapping is not None:
        lp_frag_address = self._hp_mapping.lp_fragment_address
        value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
        hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key)
        hp_fragment.data.copy_(value_fragment.data)


def get_full_hp_grad(self):
    reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    if self._hp_mapping is not None:
        hp_mapping = self._hp_mapping

        if hp_mapping.use_offload:
            gradient_dict = hp_mapping.offload_gradient_dict
        else:
            gradient_dict = hp_mapping.gradient_dict

        if hp_mapping.param_group_index not in gradient_dict or gradient_dict[hp_mapping.param_group_index] is None:
            raise ValueError("Gradients are only available immediately after backward and before engine step")

        lp_grad_fragment = gradient_dict[hp_mapping.param_group_index][self._index_in_param_group]
        hp_grad_fragment = lp_grad_fragment.to(torch.float32).flatten()

        lp_frag_address = self._hp_mapping.lp_fragment_address
        reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)

        if self.view(-1).shape == hp_grad_fragment.shape:
            reduce_buffer.data.copy_(hp_grad_fragment.data)
        else:
            reduce_fragment.data.copy_(hp_grad_fragment.data)

    dist.all_reduce(reduce_buffer, group=self._dp_group)
    return reduce_buffer.reshape_as(self)
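
# Note: the three functions above take `self` yet are defined at module level.
# DeepSpeed binds them onto individual parameter tensors elsewhere, alongside
# the `_hp_mapping`, `_dp_group`, and `_index_in_param_group` attributes they
# read. A minimal sketch of that linkage, with `param` and `dp_group` assumed
# to come from the surrounding optimizer setup:
#
#   import types
#   param._dp_group = dp_group  # data-parallel process group (assumed)
#   param.get_full_hp_param = types.MethodType(get_full_hp_param, param)
#   param.set_full_hp_param = types.MethodType(set_full_hp_param, param)
#   param.get_full_hp_grad = types.MethodType(get_full_hp_grad, param)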


def safe_get_full_fp32_param(param):
    """Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_param()

    return None
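
# Usage sketch (illustrative; assumes a model already wrapped by
# deepspeed.initialize and these helpers re-exported from deepspeed.utils):
#
#   from deepspeed.utils import safe_get_full_fp32_param
#   for name, lp_param in model.named_parameters():
#       hp_param = safe_get_full_fp32_param(lp_param)  # None if no fp32 copy is linked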


def safe_set_full_fp32_param(param, value):
    """Update the partitioned fp32 parameter of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        value (``torch.Tensor``): New value
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        param._z3_optimizer.set_full_hp_param(value, param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        param.set_full_hp_param(value)
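
# Usage sketch (illustrative): overwrite a parameter's fp32 master copy.
# Every rank passes the same full-sized `value`; each rank copies in only
# the slice its hp partition owns.
#
#   from deepspeed.utils import safe_set_full_fp32_param
#   new_value = torch.zeros_like(lp_param, dtype=torch.float32)
#   safe_set_full_fp32_param(lp_param, new_value)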


def safe_get_full_optimizer_state(param, optim_state_key):
    """Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        optim_state_key (``string``): Key of the optimizer state (e.g., ``exp_avg`` in the Adam optimizer)
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param, optim_state_key)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_param(optim_state_key)

    return None
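
# Usage sketch (illustrative): read the Adam moments tracked for a parameter.
#
#   from deepspeed.utils import safe_get_full_optimizer_state
#   exp_avg = safe_get_full_optimizer_state(lp_param, 'exp_avg')
#   exp_avg_sq = safe_get_full_optimizer_state(lp_param, 'exp_avg_sq')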


def safe_set_full_optimizer_state(param, value, optim_state_key):
    """Update the partitioned fp32 optimizer state of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        value (``torch.Tensor``): New value
        optim_state_key (``string``): Key of the optimizer state (e.g., ``exp_avg`` in the Adam optimizer)
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        param._z3_optimizer.set_full_hp_param(value, param, optim_state_key)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        param.set_full_hp_param(value, optim_state_key)
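
# Usage sketch (illustrative): reset a parameter's first Adam moment.
#
#   from deepspeed.utils import safe_set_full_optimizer_state
#   zeros = torch.zeros_like(lp_param, dtype=torch.float32)
#   safe_set_full_optimizer_state(lp_param, zeros, 'exp_avg')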


# TODO: Figure out the correct return dtype
def safe_get_full_grad(param):
    """Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
    """
    if param.grad is not None:
        return param.grad

    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_fp32_grad_for_param(param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_grad()

    return None
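
# Usage sketch (illustrative): per the ValueError raised in get_full_hp_grad
# above, partitioned gradients can only be read between backward and the
# engine step. `engine` and `batch` are assumed from surrounding training code.
#
#   from deepspeed.utils import safe_get_full_grad
#   loss = engine(batch)
#   engine.backward(loss)
#   grad = safe_get_full_grad(lp_param)  # read here, before engine.step()
#   engine.step()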


# TODO: Implement API for setting ZeRO partitioned gradients


def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
                            param_group_index, partition_start, partition_size, optimizer_state_dict):
    lp_end = lp_param.numel() + lp_start
    hp_start = partition_start
    hp_end = partition_start + partition_size

    fragment_start = max(lp_start, hp_start)
    fragment_end = min(lp_end, hp_end)
    assert fragment_start < fragment_end, \
        f'fragment start {fragment_start} should be < fragment_end {fragment_end}'

    fragment_numel = fragment_end - fragment_start
    hp_frag_address = fragment_address(start=fragment_start - hp_start, numel=fragment_numel)
    hp_fragment_tensor = flat_hp_partition.narrow(0, hp_frag_address.start, hp_frag_address.numel)
    optim_fragment = {
        key: value.narrow(0, hp_frag_address.start, hp_frag_address.numel)
        for key, value in optimizer_state_dict.items()
        if torch.is_tensor(value) and value.shape == flat_hp_partition.shape
    }

    lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel)
    lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel)

    return tensor_fragment(lp_fragment=lp_fragment_tensor,
                           lp_fragment_address=lp_frag_address,
                           hp_fragment=hp_fragment_tensor,
                           hp_fragment_address=hp_frag_address,
                           optim_fragment=optim_fragment,
                           gradient_dict=gradient_dict,
                           offload_gradient_dict=offload_gradient_dict,
                           use_offload=use_offload,
                           param_group_index=param_group_index)

'''
Logic for lp_param to hp_param mapping

lp      lp0  lp1  lp2  lp3  lp4                <-------- indices/names
lp      [  ][   ][   ][   ][   ]               <-------- tensors
flat_lp [                      ]               <-------- flat lp params
flat_hp          [           ]                 <-------- flat hp partition on current rank
full_hp [                                    ] <-------- full flat hp params

lp2:
    full numel = 16
    lp_frag
        numel = 12
        frag_start = 3
        frag_end = 15
    hp_frag
        numel = 12
        frag_start = 0
        frag_end = 11

    hp_frag.copy_(lp_frag)

lp3:
    full numel = 4
    lp_frag
        numel = 4
        start = 0
        end = 3
    hp_frag
        numel = 4
        start = 12
        end = 15

lp4:
    full numel = 12
    lp_frag
        numel = 4
        start = 0
        end = 3
    hp_frag
        numel = 4
        start = 16
        end = 19

Visual depiction of the overlap cases above (lx/ly = lp start/end,
hx/hy = hp partition start/end):

Case 1: lp begins before the hp partition and ends inside it (overlap = ly - hx)
lp               {         }
flat_lp  [                                ]
flat_hp              (                 )

flat_lp  [       {   (     }           )  ]
                 lx  hx    ly          hy

Case 2: lp begins inside the hp partition and ends after it (overlap = hy - lx)
lp                           {        }
flat_lp  [                                ]
flat_hp              (                 )

flat_lp  [           (   {             )  } ]
                     hx  lx            hy ly

Case 3: lp lies entirely inside the hp partition (overlap = ly - lx)
lp                    {   }
flat_lp  [                                ]
flat_hp      (                    )

flat_lp  [   (        {   }       )       ]
             hx       lx  ly      hy

lp      -> (lx, ly)
flat_hp -> (hx, hy)
'''
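
# Worked example (illustrative; numbers chosen here, not the ones in the
# diagram above): an lp param of 16 elements at global offset 3, on a rank
# whose hp partition covers global range [8, 20).
#
#   lp = torch.zeros(16, dtype=torch.bfloat16)
#   flat_hp = torch.zeros(12, dtype=torch.float32)
#   frag = get_hp_fragment_mapping(lp_param=lp,
#                                  lp_start=3,
#                                  flat_hp_partition=flat_hp,
#                                  gradient_dict={},
#                                  offload_gradient_dict={},
#                                  use_offload=False,
#                                  param_group_index=0,
#                                  partition_start=8,
#                                  partition_size=12,
#                                  optimizer_state_dict={})
#   # Overlap of lp [3, 19) with hp [8, 20) is [8, 19): 11 elements.
#   assert frag.lp_fragment_address.start == 5 and frag.lp_fragment_address.numel == 11
#   assert frag.hp_fragment_address.start == 0 and frag.hp_fragment_address.numel == 11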