# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from dataclasses import dataclass
from deepspeed import comm as dist
from typing import Dict


@dataclass
class fragment_address:
    numel: int
    start: int


@dataclass
class tensor_fragment:
    lp_fragment: torch.Tensor
    lp_fragment_address: fragment_address
    hp_fragment: torch.Tensor
    hp_fragment_address: fragment_address
    optim_fragment: Dict
    gradient_dict: Dict
    offload_gradient_dict: Dict
    use_offload: bool
    param_group_index: int

    def update_hp(self):
        self.hp_fragment.data.copy_(self.lp_fragment.data)

    def update_lp(self):
        self.lp_fragment.data.copy_(self.hp_fragment.data)

    def get_optim_state_fragment(self, key):
        if key in self.optim_fragment:
            return self.optim_fragment[key]
        else:
            raise ValueError(f'{key} not found in optimizer state fragment')

    def get_hp_fragment_address(self):
        return self.hp_fragment_address

    def get_optim_state_keys(self):
        return list(self.optim_fragment.keys())

    def get_hp_fragment(self, optim_state_key=None):
        if optim_state_key is None:
            return self.hp_fragment
        return self.get_optim_state_fragment(optim_state_key)


def get_full_hp_param(self, optim_state_key=None):
    reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    if self._hp_mapping is not None:
        lp_frag_address = self._hp_mapping.lp_fragment_address
        reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)
        hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key)
        reduce_fragment.data.copy_(hp_fragment.data)
    dist.all_reduce(reduce_buffer, group=self._dp_group)
    return reduce_buffer.reshape_as(self)


def set_full_hp_param(self, value, optim_state_key=None):
    if self._hp_mapping is not None:
        lp_frag_address = self._hp_mapping.lp_fragment_address
        value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
        hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key)
        hp_fragment.data.copy_(value_fragment.data)


def get_full_hp_grad(self):
    reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    if self._hp_mapping is not None:
        hp_mapping = self._hp_mapping
        if hp_mapping.use_offload:
            gradient_dict = hp_mapping.offload_gradient_dict
        else:
            gradient_dict = hp_mapping.gradient_dict

        if hp_mapping.param_group_index not in gradient_dict or gradient_dict[hp_mapping.param_group_index] is None:
            raise ValueError("Gradients are only available immediately after backward and before engine step")

        lp_grad_fragment = gradient_dict[hp_mapping.param_group_index][self._index_in_param_group]
        hp_grad_fragment = lp_grad_fragment.to(torch.float32).flatten()

        lp_frag_address = self._hp_mapping.lp_fragment_address
        reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)

        if self.view(-1).shape == hp_grad_fragment.shape:
            reduce_buffer.data.copy_(hp_grad_fragment.data)
        else:
            reduce_fragment.data.copy_(hp_grad_fragment.data)

    dist.all_reduce(reduce_buffer, group=self._dp_group)
    return reduce_buffer.reshape_as(self)
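

# The three functions above take the low-precision parameter itself as `self`:
# they are not methods of `tensor_fragment`, but are bound onto each parameter
# object when the optimizer links lp and hp state. A minimal sketch of that
# wiring, assuming a mapping built by get_hp_fragment_mapping (`mapping`) and a
# data-parallel process group (`dp_group`); the helper name `link_param_to_hp`
# is only illustrative and does not exist in this file:
#
#     import types
#
#     def link_param_to_hp(lp_param, mapping, dp_group, index_in_param_group):
#         lp_param._hp_mapping = mapping
#         lp_param._dp_group = dp_group
#         lp_param._index_in_param_group = index_in_param_group
#         lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)
#         lp_param.set_full_hp_param = types.MethodType(set_full_hp_param, lp_param)
#         lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param)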


def safe_get_full_fp32_param(param):
    """Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_param()
    return None


def safe_set_full_fp32_param(param, value):
    """Update the partitioned fp32 parameter of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        value (``torch.Tensor``): New value
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        param._z3_optimizer.set_full_hp_param(value, param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        param.set_full_hp_param(value)


def safe_get_full_optimizer_state(param, optim_state_key):
    """Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param, optim_state_key)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_param(optim_state_key)
    return None


def safe_set_full_optimizer_state(param, value, optim_state_key):
    """Update the partitioned fp32 optimizer state of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        value (``torch.Tensor``): New value
        optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        param._z3_optimizer.set_full_hp_param(value, param, optim_state_key)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        param.set_full_hp_param(value, optim_state_key)


# TODO: Figure out the correct return dtype
def safe_get_full_grad(param):
    """Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
    """
    if param.grad is not None:
        return param.grad

    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_fp32_grad_for_param(param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_grad()

    return None
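

# A sketch of how the `safe_*` full-tensor APIs above are typically used from a
# training loop. Assumes a model wrapped by a DeepSpeed engine with ZeRO or the
# bf16 optimizer; `model` and `loss` are placeholders. Gradients must be read
# right after backward and before the engine step, per get_full_hp_grad above.
#
#     from deepspeed.utils import (safe_get_full_fp32_param, safe_get_full_grad,
#                                  safe_get_full_optimizer_state)
#
#     engine.backward(loss)
#     for name, lp_param in model.named_parameters():
#         hp_param = safe_get_full_fp32_param(lp_param)
#         hp_grad = safe_get_full_grad(lp_param)
#         exp_avg = safe_get_full_optimizer_state(lp_param, 'exp_avg')
#     engine.step()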


### Local API START ###
def safe_get_local_grad(param):
    """Get the fp32 gradient of a partitioned parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
    """
    if param.grad is not None:
        return param.grad

    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_local_fp32_grad_for_param(param)

    return None


def safe_get_local_fp32_param(param):
    """Get the fp32 partitioned parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_local_fp32_param(param)

    return None


def safe_get_local_optimizer_state(param, optim_state_key):
    """Get the fp32 optimizer state of a partitioned parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_local_fp32_param(param, optim_state_key)

    return None


def safe_set_local_optimizer_state(param, value, optim_state_key):
    """Update the fp32 optimizer state of a partitioned parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        value (``torch.Tensor``): New value
        optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        param._z3_optimizer.set_local_hp_param(value, param, optim_state_key)


def safe_set_local_fp32_param(param, value):
    """Update the partitioned fp32 parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        value (``torch.Tensor``): New value
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        param._z3_optimizer.set_local_hp_param(value, param)


### Local API END ###
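

# The local variants return or update only this rank's partition of a ZeRO
# stage 3 parameter, with no assembly or collective communication. A minimal
# usage sketch, with `model` as a placeholder for a stage-3 wrapped module:
#
#     from deepspeed.utils import (safe_get_local_fp32_param, safe_get_local_grad,
#                                  safe_get_local_optimizer_state)
#
#     for name, lp_param in model.named_parameters():
#         local_hp = safe_get_local_fp32_param(lp_param)
#         local_grad = safe_get_local_grad(lp_param)
#         local_exp_avg = safe_get_local_optimizer_state(lp_param, 'exp_avg')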


# TODO: Implement API for setting ZeRO partitioned gradients


def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
                            param_group_index, partition_start, partition_size, optimizer_state_dict):
    lp_end = lp_param.numel() + lp_start
    hp_start = partition_start
    hp_end = partition_start + partition_size

    fragment_start = max(lp_start, hp_start)
    fragment_end = min(lp_end, hp_end)
    assert fragment_start < fragment_end, \
        f'fragment start {fragment_start} should be < fragment_end {fragment_end}'

    fragment_numel = fragment_end - fragment_start
    hp_frag_address = fragment_address(start=fragment_start - hp_start, numel=fragment_numel)
    hp_fragment_tensor = flat_hp_partition.narrow(0, hp_frag_address.start, hp_frag_address.numel)
    optim_fragment = {
        key: value.narrow(0, hp_frag_address.start, hp_frag_address.numel)
        for key, value in optimizer_state_dict.items()
        if torch.is_tensor(value) and value.shape == flat_hp_partition.shape
    }

    lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel)
    lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel)

    return tensor_fragment(lp_fragment=lp_fragment_tensor,
                           lp_fragment_address=lp_frag_address,
                           hp_fragment=hp_fragment_tensor,
                           hp_fragment_address=hp_frag_address,
                           optim_fragment=optim_fragment,
                           gradient_dict=gradient_dict,
                           offload_gradient_dict=offload_gradient_dict,
                           use_offload=use_offload,
                           param_group_index=param_group_index)


'''
Logic for lp_param to hp_param mapping

lp      lp0  lp1  lp2        lp3  lp4          <------- indices/names
lp      [  ][   ][          ][  ][     ]       <------- tensors
flat_lp [                              ]       <------- flat lp params
flat_hp           [            ]               <------- flat hp partition on current rank
full_hp [                                   ]  <------- full flat hp params


lp2
    full numel = 16
    lp_frag
        numel = 12
        frag_start = 3
        frag_end = 15
    hp_frag
        numel = 12
        frag_start = 0
        frag_end = 11

    hp_frag.copy_(lp_frag)


lp3:
    full numel = 4
    lp_frag
        numel = 4
        start = 0
        end = 3
    hp_frag
        numel = 4
        start = 12
        end = 15


lp4:
    full numel = 12
    lp_frag
        numel = 4
        start = 0
        end = 3
    hp_frag
        numel = 4
        start = 16
        end = 19


Visual depiction of above

lp              {         }
flat_lp [                                ]
flat_hp            (                 )

flat_lp [       {  (      }          )  ]
                lx hx     ly         hy
                      ly-hx


lp                         {        }
flat_lp [                                ]
flat_hp            (                 )

flat_lp [       (          {  )     }   ]
                hx         lx hy    ly
                           hy-lx


lp                   {      }
flat_lp [                                ]
flat_hp            (                 )

flat_lp [       (    {      }        )  ]
                hx   lx     ly       hy
                        ly-lx

lp -> (lx, hy)
flat_hp -> (hx, hy)
'''
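
# The fragment arithmetic documented above can be sanity-checked with ordinary
# CPU tensors and no distributed setup. The numbers below are made up for
# illustration: an lp param of 16 elements at flattened global offset 4, and a
# rank whose hp partition covers global offsets [8, 28).
if __name__ == "__main__":
    lp_param = torch.arange(16, dtype=torch.bfloat16).reshape(4, 4)
    flat_hp_partition = torch.zeros(20, dtype=torch.float32)
    optimizer_state = {'exp_avg': torch.zeros(20, dtype=torch.float32)}

    frag = get_hp_fragment_mapping(lp_param,
                                   lp_start=4,
                                   flat_hp_partition=flat_hp_partition,
                                   gradient_dict={},
                                   offload_gradient_dict={},
                                   use_offload=False,
                                   param_group_index=0,
                                   partition_start=8,
                                   partition_size=20,
                                   optimizer_state_dict=optimizer_state)

    # Overlap is [8, 20): 12 elements, starting at offset 4 within lp and
    # offset 0 within this rank's hp partition.
    assert frag.lp_fragment_address.numel == 12
    assert frag.lp_fragment_address.start == 4
    assert frag.hp_fragment_address.start == 0

    # hp_frag.copy_(lp_frag): the first 12 hp slots now hold lp values 4..15.
    frag.update_hp()
    assert torch.equal(flat_hp_partition[:12], torch.arange(4, 16, dtype=torch.float32))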