hybrid_megatron.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from .hybrid_engine import HybridEngineContainer
from .megatron import MegatronContainer


class HybridMegatronContainer(MegatronContainer, HybridEngineContainer):

    def _align_qkv(self, x: torch.Tensor):
        """
        Internal helper that takes a head-contiguous (Megatron-layout) QKV tensor and
        rewrites it in place so that the query, key, and value components each form a
        contiguous block.
        """
        attention_head_size = x.shape[0] // self.num_attention_heads
        new_x_shape = (self.num_attention_heads, attention_head_size) + x.size()[1:]
        x_1 = x.view(*new_x_shape)
        div_dim = len(x_1.size()) - 2 if len(x.shape) == 2 else -1
        (q, k, v) = torch.split(x_1, (x_1.shape[div_dim] // 3), dim=div_dim)
        if len(q.shape) > 2:
            x.data.copy_(
                torch.cat((q.reshape(-1, q.shape[-1]), k.reshape(-1, q.shape[-1]), v.reshape(-1, q.shape[-1])),
                          dim=0).reshape(x.shape))
        else:
            x.data.copy_(torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)), dim=-1).reshape(x.shape))
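
    # Illustrative example (assuming 2 heads with head size 2, bias case):
    # _align_qkv rewrites the head-contiguous Megatron layout
    #     [q0 q1 k0 k1 v0 v1 | q2 q3 k2 k3 v2 v3]
    # in place into the contiguous layout used for inference
    #     [q0 q1 q2 q3 | k0 k1 k2 k3 | v0 v1 v2 v3]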

    def transform_for_inference(self) -> None:
        """
        Overrides the HybridEngineContainer implementation.

        The alternative layout of the QKV matrix for Megatron is such that each head's
        Q, K, and V are sequential in memory. This is different from the default layout,
        in which all of the Qs are sequential, followed by all of the Ks, and then all
        of the Vs. Here, we take the Megatron layout and transform it to the default
        layout used for inference.
        """
        if hasattr(self.qkvw, 'ds_id'):
            from deepspeed.runtime.zero import GatheredParameters
            from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

            # Gather any ZeRO-3 partitioned parameters so the in-place rewrite
            # sees the full tensors.
            param_list = [self.qkvw, self.qkvb]
            non_active_params = [
                param for param in param_list
                if (hasattr(param, 'ds_id') and param.ds_status == ZeroParamStatus.NOT_AVAILABLE)
            ]
            with GatheredParameters(non_active_params):
                self._align_qkv(self.qkvw)
                self._align_qkv(self.qkvb)
        else:
            self._align_qkv(self.qkvw)
            self._align_qkv(self.qkvb)
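
    # Usage sketch (hypothetical flow; the actual call sites live in the hybrid
    # engine and are not shown in this file):
    #     container.transform_for_inference()   # Megatron -> [Q; K; V]
    #     ... run generation ...
    #     container.transform_for_training()    # [Q; K; V] -> Megatron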

    def _partition_qkv(self, x: torch.Tensor):
        """
        Internal helper that takes a contiguous [Q; K; V] tensor and rewrites it in
        place so that each head's Q, K, and V are contiguous (the Megatron layout).
        """
        q_k_v = torch.split(x, (x.shape[0] // 3), dim=0)
        attention_head_size = q_k_v[0].shape[0] // self.num_attention_heads
        new_x_shape = (self.num_attention_heads, attention_head_size) + x.size()[1:]
        q, k, v = [data.view(*new_x_shape) for data in q_k_v]
        if len(q.shape) > 2:
            x.data.copy_(torch.cat((q, k, v), dim=-2).reshape(-1, q.shape[-1]))
        else:
            x.data.copy_(torch.cat((q, k, v), dim=-1).reshape(-1))
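
    # _partition_qkv is the exact inverse of _align_qkv: for the same 2-head,
    # head-size-2 bias example it rewrites
    #     [q0 q1 q2 q3 | k0 k1 k2 k3 | v0 v1 v2 v3]
    # back into the head-contiguous Megatron layout
    #     [q0 q1 k0 k1 v0 v1 | q2 q3 k2 k3 v2 v3]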

    def transform_for_training(self) -> None:
        """
        Overrides the HybridEngineContainer implementation.

        The alternative layout of the QKV matrix for Megatron is such that each head's
        Q, K, and V are sequential in memory. This is different from the default layout,
        in which all of the Qs are sequential, followed by all of the Ks, and then all
        of the Vs. This function takes the inference layout and reverts it back to the
        Megatron layout.
        """
        # If the parameters are ZeRO-3 partitioned, gather them before rewriting.
        if hasattr(self.qkvw, 'ds_id'):
            from deepspeed.runtime.zero import GatheredParameters
            from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

            param_list = [self.qkvw, self.qkvb]
            non_active_params = [
                param for param in param_list
                if (hasattr(param, 'ds_id') and param.ds_status == ZeroParamStatus.NOT_AVAILABLE)
            ]
            with GatheredParameters(non_active_params):
                self._partition_qkv(self.qkvw)
                self._partition_qkv(self.qkvb)
        else:
            self._partition_qkv(self.qkvw)
            self._partition_qkv(self.qkvb)
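

if __name__ == "__main__":
    # Minimal round-trip sketch (illustrative only, not part of DeepSpeed).
    # The two helpers read nothing from `self` beyond `num_attention_heads`,
    # so a SimpleNamespace can stand in for a fully constructed container.
    # Note: because of the relative imports above, this block only runs when
    # the module is imported as part of the deepspeed package.
    from types import SimpleNamespace

    fake = SimpleNamespace(num_attention_heads=4)
    hidden = 8
    qkvw = torch.arange(3 * hidden * hidden, dtype=torch.float32).reshape(3 * hidden, hidden)
    original = qkvw.clone()

    HybridMegatronContainer._align_qkv(fake, qkvw)      # Megatron -> [Q; K; V]
    HybridMegatronContainer._partition_qkv(fake, qkvw)  # [Q; K; V] -> Megatron
    assert torch.equal(qkvw, original)  # the two transforms are inverses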