# hybrid_engine.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from abc import ABC, abstractmethod
from typing import List, Tuple

import torch


class HybridEngineContainer(ABC):
    """
    This container identifies which methods need to be overridden in addition to
    the base container to enable use in the RLHF pipeline. These methods are not
    necessary for inference alone.

    NOTE: If you are using this feature with a container that
    also inherits from `MetaTensorContainer`, ensure that `MetaTensorContainer`
    is inherited before `HybridEngineContainer` in the class definition.
    """

    def initialize_tensors(self, enable_training=False):
        """
        Same purposes as the base container, but also grabs the hooks for any LoRA
        parameters. If it's necessary to override specific sub-components of the model,
        it's best to augment the specific `set_[component]` itself rather than modifying
        the `initialize_tensors` method. See the `HybridSplitQKVContainer` for an example.
        """
        super().initialize_tensors(enable_training=enable_training)
        self.set_lora_params()

    def transform_for_training(self):
        """
        If the views on certain parameters are largely incompatible, it may be necessary to do
        more substantial transformations to the parameters. This method should be overridden to
        transform the inference format to what is necessary for training.
        """
        pass

    def transform_for_inference(self):
        """
        If the views on certain parameters are largely incompatible, it may be necessary to do
        more substantial transformations to the parameters. This method should be overridden to
        transform the training format to what is necessary for inference.
        """
        pass

    @abstractmethod
    def set_lora_params(self):
        """
        If available, set the LoRA parameters for the module. An implementation
        for this would iterate over all parameters of the model and use the `maybe_get_lora` helper
        method to check if the parameter does in fact have any LoRA params.
        """
        raise NotImplementedError("A set_lora_params() function must be defined for the relevant parameters.")

    @abstractmethod
    def get_lora_matched_pair(self):
        """Get the pairs of LoRA params and their matched model parameters."""
        raise NotImplementedError("get_lora_matched_pair() must be defined for the relevant parameters.")

    def fuse_lora(self):
        """Fuse the LoRA parameters for the inference mode."""
        for maybe_lora_param, param in self.get_lora_matched_pair():
            if len(maybe_lora_param) == 3:
                lora_right_weight, \
                lora_left_weight, \
                lora_scaling = maybe_lora_param
                # Fold the LoRA delta (scaling * left^T @ right^T) into the base weight.
                param.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())

    def unfuse_lora(self):
        """Unfuse the LoRA parameters for the training mode."""
        for maybe_lora_param, param in self.get_lora_matched_pair():
            if len(maybe_lora_param) == 3:
                lora_right_weight, \
                lora_left_weight, \
                lora_scaling = maybe_lora_param
                # Subtract the same delta to restore the base weight for training.
                param.data -= lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())

    def apply_tensor_parallelism(self, mp_replace, reversed_dim=False):
        """
        Add support for reversed dim in tensor parallelism. If necessary, override
        the called methods to handle partitioned weights (i.e. if qkv is split, override
        the `attention_qkv_mp` method). If the model component is not split, it should
        be safe to use the default implementation.
        """
        # Setup the new Attention module
        self.attention_qkv_mp(mp_replace, reversed_dim=reversed_dim)
        self.attention_o_mp(mp_replace, reversed_dim=reversed_dim)

        # Setup the new MLP module
        self.mlp_inter_mp(mp_replace, reversed_dim=reversed_dim)
        self.mlp_output_mp(mp_replace, reversed_dim=reversed_dim)

        # Apply weight quantization
        # TODO(cmikeh2): Re-enable this once verified
        #self.apply_weight_quantization()

    def _release_params(self, param_pairs: List[Tuple[torch.Tensor, torch.Tensor]]):
        """
        Helper for `release_[component]` methods. Accepts a list of tuples where the first
        element is the module param that needs to be deleted, and the second is the reassignment
        from the container.
        """
        for module_param, container_param in param_pairs:
            if module_param is not None:
                del module_param
            module_param = container_param

    def release_memory(self):
        """
        Delete module parameters if they exist and point them back to the container. The primary
        purpose of this is for TP-inference with ZeRO-3. In this scenario, we need to delete the
        parameters we've created for inference to free their memory.
        """
        general_params = [
            (self.module.attention.attn_ow, self.dense_w),
            (self.module.attention.attn_ob, self.dense_b),
            (self.module.mlp.attn_nw, self.attn_nw),
            (self.module.mlp.attn_nb, self.attn_nb),
            (self.module.norm_w, self.input_nw),
            (self.module.norm_b, self.input_nb),
        ]

        self._release_params(general_params)

        self.release_qkv()
        self.release_mlp()

    def release_qkv(self):
        """
        Release for QKV parameters (as well as any aliases).
        """
        qkv_params = [
            (self.module.attention.attn_qkvw, self.qkvw),
            (self.module.attention.attn_qkvb, self.qkvb),
        ]

        self._release_params(qkv_params)

    def release_mlp(self):
        """
        Release for MLP parameters (as well as any aliases).
        """
        mlp_params = [
            (self.module.mlp.inter_w, self._h4h_w),
            (self.module.mlp.inter_b, self._h4h_b),
            (self.module.mlp.output_w, self._4hh_w),
            (self.module.mlp.output_b, self._4hh_b),
        ]

        self._release_params(mlp_params)

    def reset_params(self):
        """
        The purpose of `reset_params` is to get the weights from the FP16 training
        copy of the model and copy them to the contiguous inference view. This only needs
        to be performed when the container parameters cannot be used directly for inference.
        """
        self.reset_qkv()
        self.reset_mlp()

    def reset_qkv(self):
        """
        Perform any necessary resets of the model parameters for the QKV components.
        """
        pass

    def reset_mlp(self):
        """
        Perform any necessary resets of the model parameters for the MLP components.
        """
        pass

    def get_lora_params(self):
        """
        Return a list of all parameters that would have LoRA for the module.
        """
        if not hasattr(self, "lora_params"):
            self.set_lora_params()
        return self.lora_params

    def set_params_wo_copy(self, Z3_enabled=False):
        """
        Rather than copying into, set the parameters directly. This is necessary to provide
        an inexpensive (low-memory-overhead) view onto the FP16 forward weights.
        """
        self.module.mlp.attn_nw = self.attn_nw
        self.module.mlp.attn_nb = self.attn_nb
        self.module.norm_w = self.input_nw
        self.module.norm_b = self.input_nb
        self.set_attn_params_wo_copy(Z3_enabled=Z3_enabled)
        self.set_mlp_params_wo_copy(Z3_enabled=Z3_enabled)

    def set_attn_params_wo_copy(self, **kwargs):
        """
        Narrower sub-method for finer grained overriding.
        """
        self.module.attention.attn_ow = self.dense_w
        self.module.attention.attn_ob = self.dense_b
        self.module.attention.attn_qkvw = self.qkvw
        self.module.attention.attn_qkvb = self.qkvb

    def set_mlp_params_wo_copy(self, **kwargs):
        """
        Narrower sub-method for finer grained overriding.
        """
        self.module.mlp.inter_w = self._h4h_w
        self.module.mlp.inter_b = self._h4h_b
        self.module.mlp.output_w = self._4hh_w
        self.module.mlp.output_b = self._4hh_b
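

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of the original module: a minimal toy
# subclass showing one way the two abstract methods could be satisfied so that
# `fuse_lora()` / `unfuse_lora()` fold a LoRA delta into a weight and back out.
# All names introduced below (`_ToyLoRAContainer`, `weight`, `lora_right`,
# `lora_left`, `scaling`) are hypothetical; a real container would instead
# gather its LoRA triples from the wrapped model, e.g. via the `maybe_get_lora`
# helper mentioned in the `set_lora_params` docstring.
# ---------------------------------------------------------------------------
class _ToyLoRAContainer(HybridEngineContainer):

    def __init__(self, weight, lora_right, lora_left, scaling):
        # A single dense weight plus its LoRA factors, supplied directly.
        self.weight = weight
        self.lora_params = [(lora_right, lora_left, scaling)]

    def set_lora_params(self):
        # The LoRA triple is already provided in __init__ for this sketch.
        pass

    def get_lora_matched_pair(self):
        # Pair each (right, left, scaling) triple with the weight it augments.
        return [(self.lora_params[0], self.weight)]


if __name__ == "__main__":
    # Round trip: fusing then unfusing should restore the original weight.
    w = torch.nn.Parameter(torch.randn(4, 4))
    right, left, scale = torch.randn(2, 4), torch.randn(2, 4), 0.5
    container = _ToyLoRAContainer(w, right, left, scale)

    original = w.data.clone()
    container.fuse_lora()    # w.data += scale * left.t() @ right.t()
    container.unfuse_lora()  # w.data -= scale * left.t() @ right.t()
    assert torch.allclose(w.data, original, atol=1e-6)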