# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from .features.meta_tensor import MetaTensorContainer
from .features.hybrid_engine import HybridEngineContainer
from deepspeed.model_implementations.transformers.ds_bloom import DeepSpeedBloomInference
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_get_lora

supported_models = {None}
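# Note: supported_models is extended with transformers' BloomModel class when a
# BLOOMLayerPolicy is instantiated (see BLOOMLayerPolicy.__init__ below).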


class DS_BloomContainer(MetaTensorContainer, HybridEngineContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model-specific settings should be defined here instead of the base class.
        self.bigscience_bloom = True
        self.triangular_masking = False
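
    # Construct the DeepSpeed BLOOM inference module and carry the container's
    # attention settings over onto its config.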
    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedBloomInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        self.module.config.invert_mask = False
        return self.module
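
    # Copy the fused query-key-value weight and bias into the inference module;
    # mp_replace handles any tensor-parallel partitioning of these tensors.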
    def attention_qkv_mp(self, mp_replace, reversed_dim=False):
        self.module.attention.attn_qkvw = mp_replace.copy(self.module.attention.attn_qkvw, self.qkvw)
        self.module.attention.attn_qkvb = mp_replace.copy(self.module.attention.attn_qkvb, self.qkvb)

    def get_lora_matched_pair(self):
        """
        Necessary to implement for `HybridEngineContainer`
        """
        fc1_lora, fc2_lora, qkv_lora, out_lora = self.get_lora_params()
        ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (qkv_lora, self.qkvw), (out_lora, self.dense_w)]
        return ret

    def set_lora_params(self):
        """
        Necessary to implement for `HybridEngineContainer`
        """
        self.lora_params = [
            maybe_get_lora(p) for p in [
                self.policy.client_module.mlp.dense_h_to_4h,
                self.policy.client_module.mlp.dense_4h_to_h,
                self.policy.client_module.self_attention.query_key_value,
                self.policy.client_module.self_attention.dense,
            ]
        ]
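
    # Load checkpoint tensors from the state dict `sd` into the inference module,
    # applying quantization and tensor-parallel sharding where configured.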
    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        param_names = (
            'self_attention.query_key_value.weight',
            'self_attention.query_key_value.bias',
            'self_attention.dense.weight',
            'self_attention.dense.bias',
            'mlp.dense_h_to_4h.weight',
            'mlp.dense_h_to_4h.bias',
            'mlp.dense_4h_to_h.weight',
            'mlp.dense_4h_to_h.bias',
            'post_attention_layernorm.weight',
            'post_attention_layernorm.bias',
            'input_layernorm.weight',
            'input_layernorm.bias',
        )

        for i in range(0, 2):
            maybe_copy(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i],
                       qkv=True,
                       megatron_v2=self.policy.is_megatron_v2,
                       split_qkv=self.policy.split_qkv)
        for i in range(2, 4):
            maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i],
                       prefix + param_names[i])
        for i in range(4, 10):
            maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i],
                       prefix + param_names[i])
        for i in range(10, 12):
            maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i], prefix + param_names[i])
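

# Policy describing how to read weights and shapes out of a Hugging Face
# transformers BloomBlock (the `client_module`) for kernel injection.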
class BLOOMLayerPolicy(TransformerPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=True, use_load_prefix=True, split_qkv=False):
        super().__init__(inference, linear_layer=True, use_load_prefix=use_load_prefix, split_qkv=split_qkv)
        self.client_module = client_module
        try:
            import transformers
            BLOOMLayerPolicy._orig_layer_class = transformers.models.bloom.modeling_bloom.BloomBlock
            global supported_models
            supported_models.update({transformers.models.bloom.modeling_bloom.BloomModel})
        except Exception as e:
            print(f"WARNING! Setting BLOOMLayerPolicy._orig_layer_class to None due to Exception: {e}")
            BLOOMLayerPolicy._orig_layer_class = None
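
    # Report the hidden size, head count, layernorm epsilon, and (default)
    # intermediate size of the wrapped BloomBlock.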
    def get_hidden_heads(self):
        return self.client_module.self_attention.hidden_size, \
               self.client_module.self_attention.num_heads, \
               self.client_module.input_layernorm.eps, \
               DEFAULT_INTERMEDIATE_SIZE

    def attention(self, enable_training=False):
        return self.client_module.self_attention.query_key_value.weight, \
               self.client_module.self_attention.query_key_value.bias, \
               self.client_module.self_attention.dense.weight, \
               self.client_module.self_attention.dense.bias

    def mlp(self, enable_training=False):
        return self.client_module.mlp.dense_h_to_4h.weight, \
               self.client_module.mlp.dense_h_to_4h.bias, \
               self.client_module.mlp.dense_4h_to_h.weight, \
               self.client_module.mlp.dense_4h_to_h.bias

    def layernorm(self):
        return self.client_module.post_attention_layernorm.weight, \
               self.client_module.post_attention_layernorm.bias, \
               self.client_module.input_layernorm.weight, \
               self.client_module.input_layernorm.bias
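
# Usage sketch (illustrative, not part of this module): this container/policy
# pair is selected by DeepSpeed's kernel-injection path when a BLOOM model is
# passed to `deepspeed.init_inference`, roughly:
#
#   import torch
#   import deepspeed
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
#   model = deepspeed.init_inference(model,
#                                    dtype=torch.float16,
#                                    replace_with_kernel_inject=True)
#
# Exact `init_inference` arguments vary across DeepSpeed releases.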