fusedqkv_utils.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.utils.logging import warning_once
import re


def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
    # Split every tensor in qkv_list into chunks of size `split_size` along `split_dim`,
    # then re-fuse chunk i of each tensor by concatenating along `cat_dim`.
    qkv_split_list = [torch.split(mat, split_size, dim=split_dim) for mat in qkv_list]
    tp_fusedqkv_list = [
        torch.cat([qkv_s[i] for qkv_s in qkv_split_list], dim=cat_dim) for i in range(len(qkv_split_list[0]))
    ]
    return tp_fusedqkv_list
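

# Worked example (illustrative addition, not part of the original module): with
# qkv_list = [A, B], each laid out as [q; k; v] blocks along dim 0, and split_size
# equal to one block's size, the call returns
#   [cat(A_q, B_q), cat(A_k, B_k), cat(A_v, B_v)],
# i.e. chunk i of every input tensor re-fused into a single tensor. The glm/codegen
# helpers below rely on this to regroup q/k/v slices per tensor-parallel rank.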


def require_tp_fused_qkvw(name, mp_size):
    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv']

    if mp_size == 1:
        return False
    for fused_name in fused_qkvw_name_list:
        if fused_name in name:
            return True
    return False
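

# For reference (illustrative addition): with mp_size > 1,
# require_tp_fused_qkvw('transformer.h.0.self_attention.query_key_value.weight', 2)
# returns True because the name contains 'query_key_value'; with mp_size == 1 it always
# returns False, since no tensor-parallel split of the fused weight is needed.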


def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
    if src is None:
        return
    fused_type_dict = {
        'CodeGenBlock': 'codegentype',
        'BloomBlock': 'bloomtype',
        'GLMBlock': 'glmtype',
        'MPTBlock': 'glmtype',
        'MptBlock': 'glmtype',
    }
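
    # Note (illustrative addition): in the codegen layout the fused qkv rows arrive grouped
    # into codegen_mp_num blocks, each holding that block's q, k and v rows (see the layout
    # comments in _transpose_fused_qkvw below). _codegen_type_transpose re-splits each
    # block's q/k/v thirds across mp_size ranks and regroups them so that the first
    # dst_shape rows belong to rank 0, the next dst_shape rows to rank 1, and so on.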

    def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
        # codegen_mp_num defined in https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py
        # TODO: assert num_heads % (mp_size * codegen_mp_num) == 0
        # input: [3*hidden_dim, hidden_dim] (weight) or [3*hidden_dim] (bias)
        shape = input.shape
        dst_shape = shape[0] // mp_size
        num_mp_blocks = input.reshape(codegen_mp_num, shape[0] // codegen_mp_num, shape[1])

        # num_mp_blocks: [codegen_mp_num, 3*hidden_dim/codegen_mp_num, :]
        src_split = list(torch.split(num_mp_blocks, num_mp_blocks.shape[1] // 3, dim=1))
        src_split = [x.reshape(codegen_mp_num * mp_size, -1, shape[1]) for x in src_split]

        split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size, 0, 1)
        tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0).reshape(shape[0], -1)

        return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]

    def _glm_type_transpose(input, mp_size):
        # input: [3*hidden_dim, hidden_dim] (weight) or [3*hidden_dim] (bias)
        shape = input.shape
        dst_shape = shape[0] // mp_size
        src_split = torch.split(input, shape[0] // 3, dim=0)

        split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size)
        tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0)

        return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]

    def _bloom_type_transpose(input, mp_size):
        # bloom weights are already interleaved per head, so a contiguous slice is enough
        shape = input.shape
        dst_shape = shape[0] // mp_size
        return input[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]

    def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):
        # Suppose num_heads = n and q(i)_w is the i-th q head's linear weight; the fused weight layouts are:
        # bloomtype:   [q(1)_w, k(1)_w, v(1)_w, q(2)_w, k(2)_w, v(2)_w, ..., q(n)_w, k(n)_w, v(n)_w]
        # glmtype:     [q(1)_w, q(2)_w, ..., q(n)_w, k(1)_w, k(2)_w, ..., k(n)_w, v(1)_w, v(2)_w, ..., v(n)_w]
        # codegentype: [q(1)_w, q(2)_w, ..., q(n/t)_w, k(1)_w, k(2)_w, ..., k(n/t)_w, v(1)_w, v(2)_w, ..., v(n/t)_w, q(n/t+1)_w, ...],
        #              where t is a constant defined in the model file (codegen_mp_num).
        if fused_qkv_type == 'bloomtype':
            return _bloom_type_transpose(src, mp_size)
        elif fused_qkv_type == 'codegentype':
            return _codegen_type_transpose(src, mp_size)
        elif fused_qkv_type == 'glmtype':
            return _glm_type_transpose(src, mp_size)

        raise ValueError("unknown fused_qkv_type")

    for module_name, fused_type in fused_type_dict.items():
        if re.search(module_name, module_str):
            return _transpose_fused_qkvw(src, mp_size, fused_type)

    warning_once("Unrecognized fused qkv weight type; defaulting to bloom type. "
                 "Please check prepare_tp_fused_qkvw() to avoid potential calculation errors.")
    return _bloom_type_transpose(src, mp_size)
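

# A minimal, self-contained sketch of how these helpers can be exercised (illustrative
# addition; the real callers live in DeepSpeed's tensor-parallel injection path). It
# shards a glm-type fused qkv weight across two hypothetical ranks and checks that each
# shard holds that rank's q, k and v rows stacked as [q_rank; k_rank; v_rank].
if __name__ == "__main__":
    hidden_dim, mp_size = 8, 2
    per_rank = hidden_dim // mp_size
    fused_w = torch.arange(3 * hidden_dim * hidden_dim, dtype=torch.float32).reshape(3 * hidden_dim, hidden_dim)

    shards = [prepare_tp_fused_qkvw("GLMBlock", fused_w, mp_size, rank) for rank in range(mp_size)]

    q, k, v = torch.split(fused_w, hidden_dim, dim=0)
    for rank, shard in enumerate(shards):
        rows = slice(rank * per_rank, (rank + 1) * per_rank)
        assert shard.shape == (3 * per_rank, hidden_dim)
        assert torch.equal(shard, torch.cat([q[rows], k[rows], v[rows]], dim=0))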