distil_bert.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from deepspeed.model_implementations.transformers.ds_bert import DeepSpeedBERTInference
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy
import deepspeed  # needed below for deepspeed.HAS_TRITON
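

# DS_DistilBERTContainer holds the DistilBERT-specific injection settings and
# builds the fused DeepSpeedBERTInference module that replaces the original layer.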
class DS_DistilBERTContainer(BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.
        self.triangular_masking = False
        self.return_single_tuple = True
        self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON

    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedBERTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module
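

# HFDistilBertLayerPolicy maps the parameter layout of a HuggingFace
# transformers DistilBERT TransformerBlock onto the tensors DeepSpeed expects:
# fused QKV, attention output projection, the FFN, and the two LayerNorms.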
class HFDistilBertLayerPolicy(TransformerPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=False, preln=False):
        super().__init__(inference)
        self.client_module = client_module
        self.preln = preln
        self.cuda_graph_supported = True
        if HFDistilBertLayerPolicy._orig_layer_class is None:
            try:
                import transformers
                HFDistilBertLayerPolicy._orig_layer_class = [
                    transformers.models.distilbert.modeling_distilbert.TransformerBlock,
                ]
            except:
                HFDistilBertLayerPolicy._orig_layer_class = None
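
    # Returns (hidden_size, num_heads, layernorm_eps, intermediate_size);
    # DEFAULT_INTERMEDIATE_SIZE signals that the intermediate size is not read
    # from the client module here.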
    def get_hidden_heads(self):
        return self.client_module.attention.q_lin.weight.shape[1], \
               self.client_module.attention.n_heads, \
               self.client_module.sa_layer_norm.eps, \
               DEFAULT_INTERMEDIATE_SIZE
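
    # Concatenate the separate q/k/v projections into a single fused QKV weight
    # and bias, as the DeepSpeed attention kernel expects, and return them along
    # with the output projection parameters.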
    def attention(self, enable_training=False):
        qw = self.client_module.attention.q_lin.weight
        qb = self.client_module.attention.q_lin.bias
        kw = self.client_module.attention.k_lin.weight
        kb = self.client_module.attention.k_lin.bias
        vw = self.client_module.attention.v_lin.weight
        vb = self.client_module.attention.v_lin.bias

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)

        return qkvw, \
               qkvb, \
               self.client_module.attention.out_lin.weight, \
               self.client_module.attention.out_lin.bias
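
    # Feed-forward parameters: lin1 is the intermediate (expansion) projection,
    # lin2 projects back to the hidden size.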
    def mlp(self, enable_training=False):
        intermediate_ff = self.client_module.ffn.lin1

        return intermediate_ff.weight, intermediate_ff.bias, \
               self.client_module.ffn.lin2.weight, \
               self.client_module.ffn.lin2.bias
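
    # LayerNorm parameters: sa_layer_norm follows self-attention,
    # output_layer_norm follows the feed-forward block.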
    def layernorm(self):
        attention_layernorm = self.client_module.sa_layer_norm
        transformer_layernorm = self.client_module.output_layer_norm
        return attention_layernorm.weight, \
               attention_layernorm.bias, \
               transformer_layernorm.weight, \
               transformer_layernorm.bias
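

# Usage sketch (illustrative only): a hedged example of how this policy is
# typically exercised, assuming DeepSpeed's kernel injection selects
# HFDistilBertLayerPolicy for DistilBERT models via deepspeed.init_inference:
#
#   import torch
#   import deepspeed
#   from transformers import AutoModel
#
#   model = AutoModel.from_pretrained("distilbert-base-uncased")
#   engine = deepspeed.init_inference(model,
#                                     dtype=torch.float16,
#                                     replace_with_kernel_inject=True)
#   # engine.module now contains the DeepSpeedBERTInference layers created by
#   # DS_DistilBERTContainer.create_module().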