# megatron.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from abc import ABC
  6. class MegatronContainer(ABC):
  7. def __init__(self, **kwargs):
  8. super().__init__(**kwargs)
  9. self.megatron_v2 = self.policy.is_megatron_v2
  10. def _align_qkv_transposed(self, x):
  11. attention_head_size = x.shape[-1] // self.num_attention_heads
  12. new_x_shape = x.size()[:-1] + (self.num_attention_heads, attention_head_size)
  13. x_1 = x.view(*new_x_shape)
  14. (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1))
  15. if len(q.shape) > 2:
  16. return torch.cat((q.reshape(q.shape[0], -1), k.reshape(q.shape[0], -1), v.reshape(q.shape[0], -1)),
  17. dim=-1).reshape(x.shape)
  18. else:
  19. return torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)), dim=-1).reshape(x.shape)
  20. def transpose(self):
  21. super().transpose()
  22. if self.megatron_v2:
  23. self.qkvw = torch.nn.parameter.Parameter(self._align_qkv_transposed(self.qkvw).contiguous())
  24. self.qkvb = torch.nn.parameter.Parameter(self._align_qkv_transposed(self.qkvb).contiguous())