bloom.py

'''Copyright The Microsoft DeepSpeed Team'''

from .base import *
from .features.meta_tensor import MetaTensorContainer
from deepspeed.model_implementations.transformers.ds_bloom import DeepSpeedBloomInference
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy

supported_models = {None}


class DS_BloomContainer(MetaTensorContainer, BaseTransformerContainer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.
        self.bigscience_bloom = True

    def create_module(self, config=None):
        # Build the fused DeepSpeed inference module for a single BLOOM layer.
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedBloomInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module

    def attention_qkv_mp(self, mp_replace):
        # Shard the fused query/key/value weight and bias across the
        # model-parallel group.
        self.module.attention.attn_qkvw = mp_replace.copy(
            self.module.attention.attn_qkvw,
            self.qkvw)
        self.module.attention.attn_qkvb = mp_replace.copy(
            self.module.attention.attn_qkvb,
            self.qkvb)

    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        # Checkpoint parameter names for one BLOOM layer, in the order expected
        # by transformer_param_names: fused QKV, attention output projection,
        # MLP, post-attention layer norm, input layer norm.
        param_names = (
            'self_attention.query_key_value.weight', \
            'self_attention.query_key_value.bias', \
            'self_attention.dense.weight', \
            'self_attention.dense.bias', \
            'mlp.dense_h_to_4h.weight', \
            'mlp.dense_h_to_4h.bias', \
            'mlp.dense_4h_to_h.weight', \
            'mlp.dense_4h_to_h.bias', \
            'post_attention_layernorm.weight', \
            'post_attention_layernorm.bias', \
            'input_layernorm.weight', \
            'input_layernorm.bias'
        )
        # Fused QKV weight and bias.
        for i in range(0, 2):
            maybe_copy(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i],
                       qkv=True,
                       megatron_v2=self.policy.is_megatron_v2,
                       split_qkv=self.policy.split_qkv)
        # Attention output projection weight and bias.
        for i in range(2, 4):
            maybe_copy(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i])
        # MLP weights/biases and the post-attention layer norm.
        for i in range(4, 10):
            maybe_copy(module.mlp,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i])
        # Input layer norm, copied onto the top-level module.
        for i in range(10, 12):
            maybe_copy(module,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i])


class BLOOMLayerPolicy(TransformerPolicy):
    # Maps a Hugging Face BloomBlock (the client_module) onto the generic
    # TransformerPolicy interface used by DeepSpeed's module injection.
    _orig_layer_class = None

    def __init__(self,
                 client_module,
                 inference=True,
                 use_load_prefix=True,
                 split_qkv=False):
        super().__init__(inference,
                         linear_layer=True,
                         use_load_prefix=use_load_prefix,
                         split_qkv=split_qkv)
        self.client_module = client_module
        try:
            import transformers
            BLOOMLayerPolicy._orig_layer_class = transformers.models.bloom.modeling_bloom.BloomBlock
            global supported_models
            supported_models.update(
                {transformers.models.bloom.modeling_bloom.BloomModel})
        except Exception as e:
            print(
                f"WARNING! Setting BLOOMLayerPolicy._orig_layer_class to None due to Exception: {e}"
            )
            BLOOMLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.self_attention.hidden_size, \
               self.client_module.self_attention.num_heads

    def attention(self):
        return self.client_module.self_attention.query_key_value.weight, \
               self.client_module.self_attention.query_key_value.bias, \
               self.client_module.self_attention.dense.weight, \
               self.client_module.self_attention.dense.bias

    def mlp(self):
        return self.client_module.mlp.dense_h_to_4h.weight, \
               self.client_module.mlp.dense_h_to_4h.bias, \
               self.client_module.mlp.dense_4h_to_h.weight, \
               self.client_module.mlp.dense_4h_to_h.bias

    def layernorm(self):
        return self.client_module.post_attention_layernorm.weight, \
               self.client_module.post_attention_layernorm.bias, \
               self.client_module.input_layernorm.weight, \
               self.client_module.input_layernorm.bias
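

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): BLOOMLayerPolicy and
# DS_BloomContainer are normally selected automatically by DeepSpeed's kernel
# injection when a Hugging Face BLOOM model is passed to
# deepspeed.init_inference. The checkpoint name, dtype, and mp_size below are
# illustrative assumptions only; a CUDA device is assumed for fp16 inference.
if __name__ == "__main__":
    import torch
    import deepspeed
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "bigscience/bloom-560m"  # example checkpoint, not required by this module
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Kernel injection replaces each BloomBlock with the fused
    # DeepSpeedBloomInference module built by DS_BloomContainer.create_module.
    engine = deepspeed.init_inference(model,
                                      mp_size=1,
                                      dtype=torch.float16,
                                      replace_with_kernel_inject=True)

    inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(engine.module.device)
    outputs = engine.module.generate(inputs.input_ids, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))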