'''Copyright The Microsoft DeepSpeed Team'''

from .base import *
from deepspeed.model_implementations.transformers.ds_bert import DeepSpeedBERTInference
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy


class DS_BERTContainer(BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.
        # BERT is a bidirectional encoder: outputs are returned as tuples and no
        # causal (triangular) attention mask is applied.
        self.return_tuple = True
        self.triangular_masking = False

    def create_module(self, config=None):
        # Build the DeepSpeed inference transformer module that replaces the
        # original client layer.
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedBERTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module


class HFBertLayerPolicy(TransformerPolicy):

    def __init__(self, client_module, inference=False):
        # BERT layers apply LayerNorm after attention, hence pre_attn_norm=False.
        super().__init__(inference, pre_attn_norm=False)
        self.client_module = client_module
        self.cuda_graph_supported = True

        if HFBertLayerPolicy._orig_layer_class is None:
            try:
                import transformers
                # BERT and RoBERTa layers share the same structure, so a single
                # policy covers both layer classes.
                HFBertLayerPolicy._orig_layer_class = [
                    transformers.models.bert.modeling_bert.BertLayer,
                    transformers.models.roberta.modeling_roberta.RobertaLayer
                ]
            except:
                HFBertLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        # Returns (hidden_size, num_attention_heads) read off the client layer.
        return self.client_module.attention.self.query.weight.shape[1], \
               self.client_module.attention.self.num_attention_heads

    def attention(self):
        qw = self.client_module.attention.self.query.weight
        qb = self.client_module.attention.self.query.bias
        kw = self.client_module.attention.self.key.weight
        kb = self.client_module.attention.self.key.bias
        vw = self.client_module.attention.self.value.weight
        vb = self.client_module.attention.self.value.bias

        # Fuse the separate Q/K/V projections into single QKV parameters, as
        # expected by the DeepSpeed inference kernels.
        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False)

        return qkvw, \
               qkvb, \
               self.client_module.attention.output.dense.weight, \
               self.client_module.attention.output.dense.bias

    def mlp(self):
        if self.pre_attn_norm:
            intermediate_ff = self.client_module.intermediate.dense_act
        else:
            intermediate_ff = self.client_module.intermediate.dense

        return intermediate_ff.weight, intermediate_ff.bias, \
               self.client_module.output.dense.weight, \
               self.client_module.output.dense.bias

    def layernorm(self):
        if self.pre_attn_norm:
            attention_layernorm = self.client_module.PostAttentionLayerNorm
            transformer_layernorm = self.client_module.PreAttentionLayerNorm
        else:
            attention_layernorm = self.client_module.attention.output.LayerNorm
            transformer_layernorm = self.client_module.output.LayerNorm

        return attention_layernorm.weight, \
               attention_layernorm.bias, \
               transformer_layernorm.weight, \
               transformer_layernorm.bias
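

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The policy above
# is normally not instantiated by hand: it is matched automatically when a
# Hugging Face BERT/RoBERTa model is handed to deepspeed.init_inference() with
# kernel injection enabled, at which point each BertLayer is swapped for the
# DeepSpeedBERTInference module built by DS_BERTContainer.create_module().
# The model name and the exact keyword arguments below are assumptions and may
# differ across DeepSpeed releases; a CUDA device is required for the injected
# kernels.
if __name__ == "__main__":
    import deepspeed
    from transformers import AutoModel, AutoTokenizer

    model = AutoModel.from_pretrained("bert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # replace_with_kernel_inject=True triggers the automatic policy lookup via
    # HFBertLayerPolicy._orig_layer_class.
    engine = deepspeed.init_inference(model,
                                      mp_size=1,
                                      dtype=torch.half,
                                      replace_with_kernel_inject=True)

    inputs = tokenizer("DeepSpeed BERT inference example",
                       return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = engine(**inputs)
    print(outputs.last_hidden_state.shape)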