# bert.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from deepspeed.model_implementations.transformers.ds_bert import DeepSpeedBERTInference
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy

import deepspeed  # provides deepspeed.HAS_TRITON, referenced below


class DS_BERTContainer(BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.
        self.return_tuple = True
        self.triangular_masking = False
        self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON

    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedBERTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module


class HFBertLayerPolicy(TransformerPolicy):

    def __init__(self, client_module, inference=False):
        super().__init__(inference, pre_attn_norm=False)
        self.client_module = client_module
        self.cuda_graph_supported = True

        if HFBertLayerPolicy._orig_layer_class is None:
            try:
                import transformers
                HFBertLayerPolicy._orig_layer_class = [
                    transformers.models.bert.modeling_bert.BertLayer,
                    transformers.models.roberta.modeling_roberta.RobertaLayer
                ]
            except:
                HFBertLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        if self.pre_attn_norm:
            attention_layernorm = self.client_module.PostAttentionLayerNorm
        else:
            attention_layernorm = self.client_module.attention.output.LayerNorm
        return self.client_module.attention.self.query.weight.shape[1], \
               self.client_module.attention.self.num_attention_heads, \
               attention_layernorm.eps, \
               DEFAULT_INTERMEDIATE_SIZE
    def attention(self, enable_training=False):
        qw = self.client_module.attention.self.query.weight
        qb = self.client_module.attention.self.query.bias
        kw = self.client_module.attention.self.key.weight
        kb = self.client_module.attention.self.key.bias
        vw = self.client_module.attention.self.value.weight
        vb = self.client_module.attention.self.value.bias

        # Fuse the separate Q, K and V projections into single QKV weight/bias tensors.
        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)

        return qkvw, \
               qkvb, \
               self.client_module.attention.output.dense.weight, \
               self.client_module.attention.output.dense.bias
    def mlp(self, enable_training=False):
        if self.pre_attn_norm:
            intermediate_ff = self.client_module.intermediate.dense_act
        else:
            intermediate_ff = self.client_module.intermediate.dense

        return intermediate_ff.weight, intermediate_ff.bias, \
               self.client_module.output.dense.weight, \
               self.client_module.output.dense.bias

    def layernorm(self):
        if self.pre_attn_norm:
            attention_layernorm = self.client_module.PostAttentionLayerNorm
            transformer_layernorm = self.client_module.PreAttentionLayerNorm
        else:
            attention_layernorm = self.client_module.attention.output.LayerNorm
            transformer_layernorm = self.client_module.output.LayerNorm
        return attention_layernorm.weight, \
               attention_layernorm.bias, \
               transformer_layernorm.weight, \
               transformer_layernorm.bias
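

# Illustrative usage sketch (an assumption, not part of the original file): shows how
# HFBertLayerPolicy reads a Hugging Face BertLayer's parameters, including the fused
# QKV weight built by attention(). Assumes `transformers` is installed and that this
# module is imported through its package (the relative imports above require that).
if __name__ == "__main__":
    from transformers.models.bert.configuration_bert import BertConfig
    from transformers.models.bert.modeling_bert import BertLayer

    hf_config = BertConfig(hidden_size=768, num_attention_heads=12)  # arbitrary small config
    layer = BertLayer(hf_config)

    policy = HFBertLayerPolicy(layer, inference=True)
    hidden_size, num_heads, ln_eps, intermediate_size = policy.get_hidden_heads()
    qkvw, qkvb, attn_ow, attn_ob = policy.attention()

    # Q, K and V are concatenated along dim 0, so qkvw has shape (3 * hidden_size, hidden_size).
    print(hidden_size, num_heads, ln_eps, tuple(qkvw.shape), tuple(qkvb.shape))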