# replace_policy.py

from abc import ABC

import torch


class DSPolicy(ABC):
    def __init__(self, inference=True, linear_layer=True, scale_attention=True):
        self.inference = inference
        self.linear_layer = linear_layer
        self.scale_attention = scale_attention

    def attention(self):
        """
        Returns attention qkv and dense parameters
        weight: (3*hidden, hidden) and (hidden, hidden)
        bias: (3*hidden) and (hidden)
        """
        raise NotImplementedError

    def get_hidden_heads(self):
        """
        Returns hidden_size and number of heads
        """
        raise NotImplementedError

    def mlp(self):
        """
        Returns mlp intermediate and output
        weight: (intermediate, hidden) and (hidden, intermediate)
        bias: (intermediate) and (hidden)
        """
        raise NotImplementedError

    def layerNorm(self):
        """
        Returns LayerNorms used in transformer layer
        Post-Attention and pre/post layer norm
        gamma and beta with shape: (hidden)
        """
        raise NotImplementedError
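

# Each concrete policy below wraps one framework's transformer layer (the
# "client_module") and maps its parameter attributes onto the DSPolicy
# accessors above. _orig_layer_class records the original layer class, when
# the framework is importable, so a replacement pass can match client modules
# to a policy by type.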
class HFBertLayerPolicy(DSPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=False, preln=False):
        super().__init__(inference)
        self.client_module = client_module
        self.preln = preln
        if HFBertLayerPolicy._orig_layer_class is None:
            try:
                import transformers
                HFBertLayerPolicy._orig_layer_class = transformers.models.bert.modeling_bert.BertLayer
            except:
                HFBertLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.attention.self.query.weight.data.shape[1], \
               self.client_module.attention.self.num_attention_heads

    def attention(self):
        qw = self.client_module.attention.self.query.weight.data
        qb = self.client_module.attention.self.query.bias.data
        kw = self.client_module.attention.self.key.weight.data
        kb = self.client_module.attention.self.key.bias.data
        vw = self.client_module.attention.self.value.weight.data
        vb = self.client_module.attention.self.value.bias.data

        qkvw = torch.cat((qw, kw, vw), dim=0)
        qkvb = torch.cat((qb, kb, vb), dim=0)

        return self.linear_layer, \
               qkvw, \
               qkvb, \
               self.client_module.attention.output.dense.weight.data, \
               self.client_module.attention.output.dense.bias.data, \
               self.scale_attention

    def mlp(self):
        if self.preln:
            intermediate_ff = self.client_module.intermediate.dense_act
        else:
            intermediate_ff = self.client_module.intermediate.dense
        return self.linear_layer, intermediate_ff.weight.data, intermediate_ff.bias.data, \
               self.client_module.output.dense.weight.data, \
               self.client_module.output.dense.bias.data

    def layerNorm(self):
        if self.preln:
            attention_layernorm = self.client_module.PostAttentionLayerNorm
            transformer_layernorm = self.client_module.PreAttentionLayerNorm
        else:
            attention_layernorm = self.client_module.attention.output.LayerNorm
            transformer_layernorm = self.client_module.output.LayerNorm
        return attention_layernorm.weight.data, \
               attention_layernorm.bias.data, \
               transformer_layernorm.weight.data, \
               transformer_layernorm.bias.data
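

# GPT-Neo attention projections carry no bias, so attention() returns None in
# the qkv-bias slot; scale_attention is disabled for this architecture.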
class HFGPTNEOLayerPolicy(DSPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=True):
        super().__init__(inference, scale_attention=False)
        self.client_module = client_module
        try:
            import transformers
            HFGPTNEOLayerPolicy._orig_layer_class = transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock
        except:
            HFGPTNEOLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.attn.attention.q_proj.weight.data.shape[1], \
               self.client_module.attn.attention.num_heads

    def attention(self):
        qw = self.client_module.attn.attention.q_proj.weight.data
        kw = self.client_module.attn.attention.k_proj.weight.data
        vw = self.client_module.attn.attention.v_proj.weight.data

        qkvw = torch.cat((qw, kw, vw), dim=0)

        return self.linear_layer, \
               qkvw, \
               None, \
               self.client_module.attn.attention.out_proj.weight.data, \
               self.client_module.attn.attention.out_proj.bias.data, \
               self.scale_attention

    def mlp(self):
        return self.linear_layer, \
               self.client_module.mlp.c_fc.weight.data, \
               self.client_module.mlp.c_fc.bias.data, \
               self.client_module.mlp.c_proj.weight.data, \
               self.client_module.mlp.c_proj.bias.data

    def layerNorm(self):
        return self.client_module.ln_2.weight.data, \
               self.client_module.ln_2.bias.data, \
               self.client_module.ln_1.weight.data, \
               self.client_module.ln_1.bias.data
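

# `version` selects between the attribute names used by older Megatron-LM
# layers (`attention`) and newer ones (`self_attention`).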
class MegatronLayerPolicy(DSPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, version=0, inference=True):
        super().__init__(inference)
        self.client_module = client_module
        # we use megatron version to differentiate between the old and new
        # megatron-lm source code
        self.version = version
        if MegatronLayerPolicy._orig_layer_class is None:
            try:
                import megatron
                from megatron.model.transformer import ParallelTransformerLayer
                MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer
            except ImportError:
                MegatronLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.attention.query_key_value.weight.data.shape[1], \
               self.client_module.attention.num_attention_heads

    def attention(self):
        if self.inference:
            if self.version == 0:
                attention = self.client_module.attention
            else:
                attention = self.client_module.self_attention

        return self.linear_layer, \
               attention.query_key_value.weight.data, \
               attention.query_key_value.bias.data, \
               attention.dense.weight.data, \
               attention.dense.bias.data, \
               self.scale_attention

    def mlp(self):
        return self.linear_layer, \
               self.client_module.mlp.dense_h_to_4h.weight.data, \
               self.client_module.mlp.dense_h_to_4h.bias.data, \
               self.client_module.mlp.dense_4h_to_h.weight.data, \
               self.client_module.mlp.dense_4h_to_h.bias.data

    def layerNorm(self):
        return self.client_module.post_attention_layernorm.weight.data, \
               self.client_module.post_attention_layernorm.bias.data, \
               self.client_module.input_layernorm.weight.data, \
               self.client_module.input_layernorm.bias.data
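

# GPT-2 stores its projections as transposed Conv1D weights rather than
# nn.Linear, so linear_layer is set to False for this policy.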
class HFGPT2LayerPolicy(DSPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=True):
        # HuggingFace GPT2 uses convolutional layer instead of linear layer
        super().__init__(inference, linear_layer=False)
        self.client_module = client_module
        try:
            import transformers
            HFGPT2LayerPolicy._orig_layer_class = transformers.models.gpt2.modeling_gpt2.GPT2Block
        except ImportError:
            HFGPT2LayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.attn.embed_dim, \
               self.client_module.attn.num_heads

    def attention(self):
        return self.linear_layer, \
               self.client_module.attn.c_attn.weight.data, \
               self.client_module.attn.c_attn.bias.data, \
               self.client_module.attn.c_proj.weight.data, \
               self.client_module.attn.c_proj.bias.data, \
               self.scale_attention

    def mlp(self):
        return self.linear_layer, \
               self.client_module.mlp.c_fc.weight.data, \
               self.client_module.mlp.c_fc.bias.data, \
               self.client_module.mlp.c_proj.weight.data, \
               self.client_module.mlp.c_proj.bias.data

    def layerNorm(self):
        return self.client_module.ln_2.weight.data, \
               self.client_module.ln_2.bias.data, \
               self.client_module.ln_1.weight.data, \
               self.client_module.ln_1.bias.data
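

# Registry of the layer policies defined above.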
replace_policies = [
    HFBertLayerPolicy,
    HFGPTNEOLayerPolicy,
    MegatronLayerPolicy,
    HFGPT2LayerPolicy,
]
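

# A minimal usage sketch, not part of this module: the helper name
# `match_policy` is illustrative, and the real replacement pass that consumes
# `replace_policies` may work differently. It shows how a policy can be looked
# up for an instantiated layer and its parameters read back via the accessors.
def match_policy(module):
    for policy_cls in replace_policies:
        # Instantiating once with a dummy client populates _orig_layer_class
        # as a side effect of __init__ (it is None until then).
        policy_cls(None)
        orig_class = policy_cls._orig_layer_class
        if orig_class is not None and isinstance(module, orig_class):
            return policy_cls(module)
    return None

# Example (assuming `layer` is, e.g., a transformers BertLayer instance):
#   policy = match_policy(layer)
#   if policy is not None:
#       hidden_size, num_heads = policy.get_hidden_heads()
#       linear, qkvw, qkvb, dense_w, dense_b, scale = policy.attention()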