12345678910111213141516171819202122232425262728293031 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import torch
- from abc import ABC
- class MegatronContainer(ABC):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self.megatron_v2 = self.policy.is_megatron_v2
- def _align_qkv_transposed(self, x):
- attention_head_size = x.shape[-1] // self.num_attention_heads
- new_x_shape = x.size()[:-1] + (self.num_attention_heads, attention_head_size)
- x_1 = x.view(*new_x_shape)
- (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1))
- if len(q.shape) > 2:
- return torch.cat((q.reshape(q.shape[0], -1), k.reshape(q.shape[0], -1), v.reshape(q.shape[0], -1)),
- dim=-1).reshape(x.shape)
- else:
- return torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)), dim=-1).reshape(x.shape)
- def transpose(self):
- super().transpose()
- if self.megatron_v2:
- self.qkvw = torch.nn.parameter.Parameter(self._align_qkv_transposed(self.qkvw).contiguous())
- self.qkvb = torch.nn.parameter.Parameter(self._align_qkv_transposed(self.qkvb).contiguous())
|