layers.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed import comm as dist
from torch import nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter

from deepspeed.accelerator import get_accelerator


class LinearAllreduce(nn.Module):

    def __init__(self, weight, bias=None, mp_group=None):
        super(LinearAllreduce, self).__init__()
        self.weight = weight
        self.bias = bias
        self.mp_group = mp_group

    def forward(self, input):
        output = torch.matmul(input, self.weight.transpose(-1, -2))
        if self.mp_group is not None:
            # Sum the partial outputs across the model-parallel group before adding the bias.
            dist.inference_all_reduce(output, group=self.mp_group)
        if self.bias is not None:
            output += self.bias
        return output
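
# Illustrative usage sketch (added note, not part of the original file): LinearAllreduce is
# typically used as the "row-parallel" half of tensor parallelism, where each rank holds a
# slice of the weight along the input dimension and the partial products are summed with an
# all-reduce. The shard construction and `mp_group` below are hypothetical.
#
#   shard = full_weight.chunk(world_size, dim=1)[rank]            # [out, in // world_size]
#   layer = LinearAllreduce(nn.Parameter(shard), bias=full_bias, mp_group=mp_group)
#   y = layer(x_shard)                      # partial matmul per rank, all-reduce, then bias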


class LinearLayer(nn.Module):

    def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
        super(LinearLayer, self).__init__()
        if weight is not None:
            self.weight = weight
            self.bias = bias
        else:
            self.weight = Parameter(
                torch.empty(weight_shape, dtype=dtype, device=get_accelerator().current_device_name()))
            # Only allocate a bias tensor when a bias was passed in.
            self.bias = Parameter(
                torch.empty(weight_shape[0],
                            dtype=dtype,
                            device=get_accelerator().current_device_name())) \
                if bias is not None else None

    def forward(self, input):
        output = torch.matmul(input, self.weight.transpose(-1, -2))
        if self.bias is not None:
            output += self.bias
        return output


class Normalize(nn.Module):

    def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None, bias=None):
        super(Normalize, self).__init__()
        if weight is not None:
            self.weight = weight
            self.bias = bias
        else:
            self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(get_accelerator().current_device_name())
            self.weight = self.norm.weight
            self.bias = self.norm.bias
        # eps is needed in forward() regardless of which branch was taken above.
        self.eps = eps

    def forward(self, input):
        return nn.functional.layer_norm(input, input.shape[-1:], self.weight, self.bias, eps=self.eps)


class EmbeddingLayer(nn.Module):

    def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
        super(EmbeddingLayer, self).__init__()
        if weight is None:
            self.weight = Parameter(
                torch.empty(weight_shape[0],
                            weight_shape[1],
                            dtype=dtype,
                            device=get_accelerator().current_device_name()))
        else:
            self.weight = weight

    def forward(self, input):
        return F.embedding(input, self.weight)


class OPTEmbedding(EmbeddingLayer):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, weight_shape=None, weight=None, bias=None):
        # OPT is set up so that if padding_idx is specified, the embedding ids are offset by 2
        # and num_embeddings is adjusted accordingly. Other models don't have this hack.
        self.offset = 2
        super().__init__(weight_shape, weight=weight)

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
        """`attention_mask` is expected to be [bsz x seqlen]."""
        attention_mask = attention_mask.long()

        # create positions depending on attention_mask
        positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1

        # cut positions if `past_key_values_length` is > 0
        positions = positions[:, past_key_values_length:]

        return super().forward(positions + self.offset)
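
# Worked example (added note, not part of the original file): how the position ids above are
# derived from an attention mask, with offset = 2 and past_key_values_length = 0.
#
#   attention_mask            = [[1, 1, 1, 0]]
#   cumsum(mask) * mask - 1   = [[0, 1, 2, -1]]   # padded slot collapses to -1
#   positions + offset        = [[2, 3, 4, 1]]    # ids looked up in the position embedding table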


class RMSNormalize(nn.Module):

    def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
        super(RMSNormalize, self).__init__()
        if weight is not None:
            self.weight = weight
        else:
            self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=get_accelerator().current_device_name()))
        self.eps = eps

    def forward(self, hidden_states):
        # The RMS statistic is computed in fp32 for numerical stability, then the result is
        # cast back to the weight dtype before scaling.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
        return hidden_states * self.weight
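
# Reference note (added, not part of the original file): RMSNormalize follows the RMSNorm
# formulation, scaling by the root mean square of the features instead of subtracting the
# mean as LayerNorm does:
#
#   y = x / sqrt(mean(x ** 2, dim=-1) + eps) * weight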