# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from .hybrid_engine import HybridEngineContainer
from .megatron import MegatronContainer


class HybridMegatronContainer(MegatronContainer, HybridEngineContainer):

    def _align_qkv(self, x: torch.Tensor):
        """
        Internal helper for accepting the head-contiguous weight matrix and chunking
        the query, key, and value components.
        """
        attention_head_size = x.shape[0] // self.num_attention_heads
        new_x_shape = (self.num_attention_heads, attention_head_size) + x.size()[1:]
        x_1 = x.view(*new_x_shape)
        # For a 2-D weight, split along the per-head dimension; for a 1-D bias,
        # split along the last (only) remaining dimension.
        div_dim = len(x_1.size()) - 2 if len(x.shape) == 2 else -1
        (q, k, v) = torch.split(x_1, (x_1.shape[div_dim] // 3), dim=div_dim)
        if len(q.shape) > 2:
            x.data.copy_(
                torch.cat((q.reshape(-1, q.shape[-1]), k.reshape(-1, q.shape[-1]), v.reshape(-1, q.shape[-1])),
                          dim=0).reshape(x.shape))
        else:
            x.data.copy_(torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)), dim=-1).reshape(x.shape))
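
    # Worked example (illustrative): with num_attention_heads == 2 and head size d,
    # `_align_qkv` rewrites the leading dimension from the head-contiguous order
    #     [h0_q (d rows), h0_k, h0_v, h1_q, h1_k, h1_v]
    # to the blocked order used for inference
    #     [h0_q, h1_q, h0_k, h1_k, h0_v, h1_v]
    # i.e. all Q rows first, then all K rows, then all V rows. The 1-D bias is
    # rearranged the same way via the `div_dim == -1` branch.
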
    def transform_for_inference(self) -> None:
        """
        Overrides the HybridEngineContainer implementation.

        The Megatron layout of the QKV matrix is such that each head's Q, K, and V
        are sequential in memory. This differs from the default layout, in which all
        of the Qs are sequential, followed by all of the Ks, and then all of the Vs.
        Here, we take the head-contiguous Megatron layout and transform it into the
        blocked layout used for inference.
        """
        if hasattr(self.qkvw, 'ds_id'):
            # The parameter is managed by ZeRO-3; gather any partitions that are
            # not currently materialized before rewriting the tensor in place.
            from deepspeed.runtime.zero import GatheredParameters
            from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
            param_list = [self.qkvw, self.qkvb]
            non_active_params = [
                param for param in param_list
                if hasattr(param, 'ds_id') and param.ds_status == ZeroParamStatus.NOT_AVAILABLE
            ]
            with GatheredParameters(non_active_params):
                self._align_qkv(self.qkvw)
                self._align_qkv(self.qkvb)
        else:
            self._align_qkv(self.qkvw)
            self._align_qkv(self.qkvb)
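
    # Hypothetical call pattern (the real call sites live in DeepSpeed's hybrid
    # engine rather than in user code): transform before generation, then revert
    # before the next training step.
    #
    #     container.transform_for_inference()   # head-contiguous -> [Q; K; V]
    #     ...generate with the inference kernels...
    #     container.transform_for_training()    # [Q; K; V] -> head-contiguous
    #     ...resume training on the original layout...
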
    def _partition_qkv(self, x: torch.Tensor):
        """
        Internal helper for taking the contiguous [Q; K; V] matrix and re-partitioning
        it into head-contiguous form.
        """
        q_k_v = torch.split(x, (x.shape[0] // 3), dim=0)
        attention_head_size = q_k_v[0].shape[0] // self.num_attention_heads
        new_x_shape = (self.num_attention_heads, attention_head_size) + x.size()[1:]
        q, k, v = [data.view(*new_x_shape) for data in q_k_v]
        if len(q.shape) > 2:
            x.data.copy_(torch.cat((q, k, v), dim=-2).reshape(-1, q.shape[-1]))
        else:
            x.data.copy_(torch.cat((q, k, v), dim=-1).reshape(-1))
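
    # Worked example (illustrative): `_partition_qkv` is the exact inverse of
    # `_align_qkv`. For num_attention_heads == 2 and head size d, it maps
    #     [h0_q, h1_q, h0_k, h1_k, h0_v, h1_v]
    # back to the head-contiguous order
    #     [h0_q, h0_k, h0_v, h1_q, h1_k, h1_v]
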
    def transform_for_training(self) -> None:
        """
        Overrides the HybridEngineContainer implementation.

        The Megatron layout of the QKV matrix is such that each head's Q, K, and V
        are sequential in memory. This differs from the default layout, in which all
        of the Qs are sequential, followed by all of the Ks, and then all of the Vs.
        This function takes the inference format and reverts it back to the
        head-contiguous Megatron format for training.
        """
        # If the parameter is distributed under ZeRO-3, gather it before rewriting.
        if hasattr(self.qkvw, 'ds_id'):
            from deepspeed.runtime.zero import GatheredParameters
            from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
            param_list = [self.qkvw, self.qkvb]
            non_active_params = [
                param for param in param_list
                if hasattr(param, 'ds_id') and param.ds_status == ZeroParamStatus.NOT_AVAILABLE
            ]
            with GatheredParameters(non_active_params):
                self._partition_qkv(self.qkvw)
                self._partition_qkv(self.qkvb)
        else:
            self._partition_qkv(self.qkvw)
            self._partition_qkv(self.qkvb)
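

# Minimal round-trip sanity check (illustrative sketch, not part of the container
# API): a bare stub stands in for a fully-constructed container, since the two
# helpers only need `num_attention_heads`. The stub and the sizes below are
# assumptions for demonstration only.
if __name__ == "__main__":

    class _Stub:
        num_attention_heads = 4
        _align_qkv = HybridMegatronContainer._align_qkv
        _partition_qkv = HybridMegatronContainer._partition_qkv

    stub = _Stub()
    hidden_size = 4 * 8  # num_attention_heads * head_size
    qkvw = torch.randn(3 * hidden_size, hidden_size)  # 2-D weight path
    qkvb = torch.randn(3 * hidden_size)  # 1-D bias path
    for tensor in (qkvw, qkvb):
        original = tensor.clone()
        stub._align_qkv(tensor)  # head-contiguous -> [Q; K; V]
        stub._partition_qkv(tensor)  # [Q; K; V] -> head-contiguous
        assert torch.equal(tensor, original), "align/partition should round-trip"
    print("QKV layout round-trip OK")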