# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed import comm as dist
from torch import nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from deepspeed.accelerator import get_accelerator


class LinearAllreduce(nn.Module):

    def __init__(self, weight, bias=None, mp_group=None):
        super(LinearAllreduce, self).__init__()
        self.weight = weight
        self.bias = bias
        self.mp_group = mp_group

    def forward(self, input):
        output = torch.matmul(input, self.weight.transpose(-1, -2))
        if self.mp_group is not None:
            # Sum the partial results from each tensor-parallel rank so every
            # rank holds the full linear output before the bias is added.
            dist.inference_all_reduce(output, group=self.mp_group)
        if self.bias is not None:
            output += self.bias
        return output
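

# Illustrative sketch, not part of the original DeepSpeed module: on a single
# process, summing the partial matmuls of an input-dimension (row-parallel)
# shard is exactly what dist.inference_all_reduce does across tensor-parallel
# ranks in LinearAllreduce.forward. The helper name below is hypothetical.
def _linear_allreduce_sketch():
    x = torch.randn(2, 8)  # [batch, in_features]
    w = torch.randn(4, 8)  # [out_features, in_features]
    full = torch.matmul(x, w.transpose(-1, -2))  # full (unsharded) linear output
    # Pretend two ranks each hold half of the input features.
    part0 = torch.matmul(x[:, :4], w[:, :4].transpose(-1, -2))
    part1 = torch.matmul(x[:, 4:], w[:, 4:].transpose(-1, -2))
    # The all-reduce sums the partial outputs, recovering the full result.
    assert torch.allclose(full, part0 + part1, atol=1e-5)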


class LinearLayer(nn.Module):

    def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
        super(LinearLayer, self).__init__()
        if weight is not None:
            self.weight = weight
            self.bias = bias
        else:
            # No pre-existing tensors were passed in: allocate empty weight (and
            # optionally bias) parameters directly on the current accelerator device.
            self.weight = Parameter(
                torch.empty(weight_shape, dtype=dtype, device=get_accelerator().current_device_name()))
            self.bias = Parameter(
                torch.empty(weight_shape[0],
                            dtype=dtype,
                            device=get_accelerator().current_device_name())) \
                if bias is not None else None

    def forward(self, input):
        output = torch.matmul(input, self.weight.transpose(-1, -2))
        if self.bias is not None:
            output += self.bias
        return output


class Normalize(nn.Module):

    def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None, bias=None):
        super(Normalize, self).__init__()
        if weight is not None:
            self.weight = weight
            self.bias = bias
        else:
            self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(get_accelerator().current_device_name())
            self.weight = self.norm.weight
            self.bias = self.norm.bias

        self.eps = eps

    def forward(self, input):
        return nn.functional.layer_norm(input, input.shape[-1:], self.weight, self.bias, eps=self.eps)


class EmbeddingLayer(nn.Module):

    def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
        super(EmbeddingLayer, self).__init__()
        if weight is None:
            self.weight = Parameter(
                torch.empty(weight_shape[0],
                            weight_shape[1],
                            dtype=dtype,
                            device=get_accelerator().current_device_name()))
        else:
            self.weight = weight

    def forward(self, input):
        return F.embedding(input, self.weight)


class OPTEmbedding(EmbeddingLayer):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, weight_shape=None, weight=None, bias=None):
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(weight_shape, weight=weight)

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
        """`attention_mask` is expected to be [bsz x seqlen]."""
        attention_mask = attention_mask.long()

        # create positions depending on attention_mask
        positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1

        # cut positions if `past_key_values_length` is > 0
        positions = positions[:, past_key_values_length:]

        return super().forward(positions + self.offset)
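

# Illustrative sketch, not part of the original DeepSpeed module: how OPT
# position ids are derived from an attention mask in OPTEmbedding.forward.
# The cumulative sum counts real tokens, multiplying by the mask zeroes out
# padding, and the fixed offset of 2 matches the OPT embedding table layout.
# The helper name below is hypothetical.
def _opt_positions_sketch():
    attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
    positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
    assert positions.tolist() == [[0, 1, 2, -1]]  # padding slots land at -1
    # After adding the offset, these are the ids looked up in the embedding table.
    assert (positions + 2).tolist() == [[2, 3, 4, 1]]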


class RMSNormalize(nn.Module):

    def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
        super(RMSNormalize, self).__init__()
        if weight is not None:
            self.weight = weight
        else:
            self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=get_accelerator().current_device_name()))

        self.eps = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return hidden_states * self.weight
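

# Illustrative sketch, not part of the original DeepSpeed module: RMSNormalize
# computes x * rsqrt(mean(x^2) + eps) * weight, i.e. RMSNorm (no mean
# subtraction and no bias, unlike LayerNorm). The helper name below is
# hypothetical.
def _rms_norm_sketch():
    x = torch.randn(2, 4, dtype=torch.float32)
    weight = torch.ones(4, dtype=torch.float32)
    eps = 1e-5
    expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
    module = RMSNormalize(dim=4, dtype=torch.float32, eps=eps, weight=weight)
    assert torch.allclose(module(x), expected, atol=1e-6)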