# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from abc import abstractmethod

import torch

from .hybrid_engine import HybridEngineContainer


class HybridSplitQKVContainer(HybridEngineContainer):

    def set_attention(self, qkvw, qkvb, dense_w, dense_b):
        super().set_attention(qkvw, qkvb, dense_w, dense_b)
        self.set_q_k_v()

    @abstractmethod
    def set_q_k_v(self):
        """
        In `set_q_k_v`, it is necessary to populate the following variables (where appropriate)
        for the given model:
            self.qw: q weight
            self.qb: q bias
            self.kw: k weight
            self.kb: k bias
            self.vw: v weight
            self.vb: v bias
        """
        raise NotImplementedError("A set_q_k_v() function must be defined in the model container \
            in order to set the unfused q, k, and v tensors.")
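
    # A minimal sketch of a concrete `set_q_k_v` implementation. Illustrative
    # only: the projection attribute names on the client module (e.g. `q_proj`)
    # are hypothetical and differ per model architecture.
    #
    #     def set_q_k_v(self):
    #         attn = self.policy.client_module.attention
    #         self.qw = attn.q_proj.weight
    #         self.qb = attn.q_proj.bias
    #         self.kw = attn.k_proj.weight
    #         self.kb = attn.k_proj.bias
    #         self.vw = attn.v_proj.weight
    #         self.vb = attn.v_proj.bias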

    def attention_qkv_mp(self, mp_replace, reversed_dim=False):
        # Only need to alter the split q/k/v parameters when the fused qkvw has
        # not been set; otherwise defer to the fused path in the base container.
        if self.module.attention.attn_qkvw is None:
            params = [
                (self.module.attention.attn_qw, self.qw),
                (self.module.attention.attn_qb, self.qb),
                (self.module.attention.attn_kw, self.kw),
                (self.module.attention.attn_kb, self.kb),
                (self.module.attention.attn_vw, self.vw),
                (self.module.attention.attn_vb, self.vb),
            ]
            for dst, src in params:
                dst = mp_replace.copy(dst[:self.qw.shape[0] // mp_replace.mp_size],
                                      src,
                                      int8=reversed_dim,
                                      allocate_tensor=reversed_dim) if src is not None else None
        else:
            super().attention_qkv_mp(mp_replace)
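
    # Worked example with hypothetical sizes: if self.qw.shape[0] == 4096 and
    # mp_replace.mp_size == 2, then `dst[:self.qw.shape[0] // mp_replace.mp_size]`
    # selects the first 2048 output rows of each destination tensor, i.e. this
    # rank's model-parallel partition, which `mp_replace.copy` fills from `src`.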

    def release_qkv(self):
        super().release_qkv()
        split_qkv_params = [
            (self.module.attention.attn_qw, self.qw),
            (self.module.attention.attn_qb, self.qb),
            (self.module.attention.attn_kw, self.kw),
            (self.module.attention.attn_kb, self.kb),
            (self.module.attention.attn_vw, self.vw),
            (self.module.attention.attn_vb, self.vb),
        ]
        self._release_params(split_qkv_params)

    def reset_qkv(self):
        self.qkvw.data[:self.qw.shape[0]] = self.qw.data
        self.qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kw.data
        self.qkvw.data[2 * self.qw.shape[0]:] = self.vw.data

        qkv_data = [self.qw.data, self.kw.data, self.vw.data]

        self.qw.data = self.qkvw.data[:self.qw.shape[0]]
        self.kw.data = self.qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]]
        self.vw.data = self.qkvw.data[2 * self.qw.shape[0]:]

        if self.qkvb is not None:
            self.qkvb.data[:self.qw.shape[0]] = self.qb.data
            self.qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kb.data
            self.qkvb.data[2 * self.qw.shape[0]:] = self.vb.data

            qkv_data.extend([self.qb.data, self.kb.data, self.vb.data])

            self.qb.data = self.qkvb.data[:self.qw.shape[0]]
            self.kb.data = self.qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]]
            self.vb.data = self.qkvb.data[2 * self.qw.shape[0]:]

        for data in qkv_data:
            del data
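
    # Layout sketch for the fused buffer used above, assuming q, k, and v all
    # have H = self.qw.shape[0] output rows (which the slicing requires):
    #
    #     qkvw rows [0:H)    <- q weight
    #     qkvw rows [H:2H)   <- k weight
    #     qkvw rows [2H:3H)  <- v weight
    #
    # After the copy, qw/kw/vw (and qb/kb/vb when biases exist) are re-pointed
    # at views of the fused buffers, so the standalone copies can be freed.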

    def reset_qkv_experimental(self):
        """
        WIP - experimental and likely to be changed/improved.
        Unused but kept for now.
        """
        if self.module.attention.attn_qkvw is None:
            self.module.attention.attn_qkvw = torch.empty(self.qw.shape[0] * 3,
                                                          self.qw.shape[0],
                                                          dtype=self.qw.dtype,
                                                          device=self.qw.device)
            self.module.attention.attn_qkvb = torch.empty(self.qw.shape[0] * 3,
                                                          dtype=self.qw.dtype,
                                                          device=self.qw.device)
        self.module.attention.attn_qkvw.data[:self.qw.shape[0]] = self.qw.data
        self.module.attention.attn_qkvb.data[:self.qw.shape[0]] = self.qb.data
        self.module.attention.attn_qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kw.data
        self.module.attention.attn_qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kb.data
        self.module.attention.attn_qkvw.data[2 * self.qw.shape[0]:] = self.vw.data
        self.module.attention.attn_qkvb.data[2 * self.qw.shape[0]:] = self.vb.data

        qkv_data = [self.qw.data,
                    self.qb.data,
                    self.kw.data,
                    self.kb.data,
                    self.vw.data,
                    self.vb.data]

        self.qw.data = self.module.attention.attn_qkvw.data[:self.qw.shape[0]]
        self.qb.data = self.module.attention.attn_qkvb.data[:self.qw.shape[0]]
        self.kw.data = self.module.attention.attn_qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]]
        self.kb.data = self.module.attention.attn_qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]]
        self.vw.data = self.module.attention.attn_qkvw.data[2 * self.qw.shape[0]:]
        self.vb.data = self.module.attention.attn_qkvb.data[2 * self.qw.shape[0]:]

        for data in qkv_data:
            del data

    def set_attn_params_wo_copy(self, Z3_enabled=False):
        self.module.attention.attn_ow = self.dense_w
        self.module.attention.attn_ob = self.dense_b
        if not Z3_enabled:
            # In initialize_tensors, we create a fused qkvw with the appropriate shape
            # and copy the qw, qb, kw, kb, vw, vb into it
            self.module.attention.attn_qkvw = self.qkvw
            self.module.attention.attn_qkvb = self.qkvb

            # We reset the data for qw (which is the original model parameter) to point
            # to the fused weight matrix we have created here
            self.qw.data = self.qkvw[:self.qw.shape[0], :]
            self.kw.data = self.qkvw[self.qw.shape[0]:2 * self.qw.shape[0], :]
            self.vw.data = self.qkvw[self.qw.shape[0] * 2:, :]

            # Assume if one of the biases is not None, then all of them are not None
            if self.qb is not None:
                self.qb.data = self.qkvb[:self.qw.shape[0]]
                self.kb.data = self.qkvb[self.qw.shape[0]:2 * self.qw.shape[0]]
                self.vb.data = self.qkvb[self.qw.shape[0] * 2:]
        else:
            # In ZeRO-3 this will be managed by ZeRO and handled separately in the
            # forward of ds_attention
            self.module.attention.attn_qw = self.qw
            self.module.attention.attn_qb = self.qb
            self.module.attention.attn_kw = self.kw
            self.module.attention.attn_kb = self.kb
            self.module.attention.attn_vw = self.vw
            self.module.attention.attn_vb = self.vb
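
    # Illustrative call order (a sketch, not invoked from this file; only the
    # methods defined in this container are known, the surrounding flow is an
    # assumption):
    #
    #     container.set_attention(qkvw, qkvb, dense_w, dense_b)  # also calls set_q_k_v()
    #     container.set_attn_params_wo_copy(Z3_enabled=False)    # alias q/k/v into the fused buffers
    #     ...                                                    # run inference
    #     container.release_qkv()                                # release the split-parameter views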

    def get_attn_params(self):
        params = super().get_attn_params()
        params.extend([self.qw, self.qb, self.kw, self.kb, self.vw, self.vb])
        return params