distil_bert.py

'''Copyright The Microsoft DeepSpeed Team'''
from .base import *
from deepspeed.model_implementations.transformers.ds_bert import DeepSpeedBERTInference
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy


class DS_DistilBERTContainer(BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model-specific settings should be defined here instead of the base class.
        self.triangular_masking = False
        self.return_single_tuple = True

    def create_module(self, config=None):
        # Build the fused DeepSpeed inference module from the DeepSpeed model config.
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedBERTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module


class HFDistilBertLayerPolicy(TransformerPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=False, preln=False):
        super().__init__(inference)
        self.client_module = client_module
        self.preln = preln
        self.cuda_graph_supported = True

        if HFDistilBertLayerPolicy._orig_layer_class is None:
            try:
                import transformers
                HFDistilBertLayerPolicy._orig_layer_class = [
                    transformers.models.distilbert.modeling_distilbert.TransformerBlock,
                ]
            except:
                HFDistilBertLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        # Hidden size is the input dimension of the query projection.
        return self.client_module.attention.q_lin.weight.shape[1], \
               self.client_module.attention.n_heads

    def attention(self):
        qw = self.client_module.attention.q_lin.weight
        qb = self.client_module.attention.q_lin.bias
        kw = self.client_module.attention.k_lin.weight
        kb = self.client_module.attention.k_lin.bias
        vw = self.client_module.attention.v_lin.weight
        vb = self.client_module.attention.v_lin.bias

        # Fuse the separate q/k/v projections into single qkv parameters.
        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0))
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0))

        return qkvw, \
               qkvb, \
               self.client_module.attention.out_lin.weight, \
               self.client_module.attention.out_lin.bias

    def mlp(self):
        intermediate_ff = self.client_module.ffn.lin1

        return intermediate_ff.weight, intermediate_ff.bias, \
               self.client_module.ffn.lin2.weight, \
               self.client_module.ffn.lin2.bias

    def layernorm(self):
        attention_layernorm = self.client_module.sa_layer_norm
        transformer_layernorm = self.client_module.output_layer_norm
        return attention_layernorm.weight, \
               attention_layernorm.bias, \
               transformer_layernorm.weight, \
               transformer_layernorm.bias
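

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream file): attention() above fuses
# the separate q/k/v projection weights into one parameter by concatenating
# along dim=0. The shapes below are an assumption for a DistilBERT-base sized
# layer (hidden size 768); the resulting fused weight is (3 * 768, 768).
#
#     import torch
#     from torch.nn.parameter import Parameter
#
#     hidden = 768
#     qw = torch.randn(hidden, hidden)
#     kw = torch.randn(hidden, hidden)
#     vw = torch.randn(hidden, hidden)
#     qkvw = Parameter(torch.cat((qw, kw, vw), dim=0))
#     assert qkvw.shape == (3 * hidden, hidden)
# ---------------------------------------------------------------------------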