gpt2.py

'''Copyright The Microsoft DeepSpeed Team'''
from .base import *
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
from ..policy import TransformerPolicy


class DS_GPT2Container(BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model-specific things should be defined here instead of the base class.

    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module


class HFGPT2LayerPolicy(TransformerPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=True):
        # HuggingFace GPT2 uses Conv1D layers instead of Linear layers.
        super().__init__(inference, linear_layer=False)
        self.client_module = client_module
        try:
            import transformers
            HFGPT2LayerPolicy._orig_layer_class = transformers.models.gpt2.modeling_gpt2.GPT2Block
        except:
            HFGPT2LayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.attn.embed_dim, \
               self.client_module.attn.num_heads

    def attention(self):
        return self.client_module.attn.c_attn.weight, \
               self.client_module.attn.c_attn.bias, \
               self.client_module.attn.c_proj.weight, \
               self.client_module.attn.c_proj.bias

    def mlp(self):
        return self.client_module.mlp.c_fc.weight, \
               self.client_module.mlp.c_fc.bias, \
               self.client_module.mlp.c_proj.weight, \
               self.client_module.mlp.c_proj.bias

    def layernorm(self):
        return self.client_module.ln_2.weight, \
               self.client_module.ln_2.bias, \
               self.client_module.ln_1.weight, \
               self.client_module.ln_1.bias
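

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original file).
# It shows what the policy accessors above return when pointed at a
# HuggingFace GPT2 block. The 'gpt2' checkpoint and block index 0 are
# assumptions chosen only for the example; run with `python -m` so the
# relative imports at the top of this module resolve.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import GPT2Model

    hf_model = GPT2Model.from_pretrained('gpt2')   # assumed public checkpoint
    block = hf_model.h[0]                          # first GPT2Block

    policy = HFGPT2LayerPolicy(block, inference=True)

    hidden_size, num_heads = policy.get_hidden_heads()
    qkv_w, qkv_b, attn_ow, attn_ob = policy.attention()

    # GPT2's Conv1D stores weights transposed relative to nn.Linear, so the
    # fused QKV weight is [hidden, 3 * hidden] (e.g. [768, 2304] for 'gpt2').
    print(hidden_size, num_heads, tuple(qkv_w.shape))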