'''Copyright The Microsoft DeepSpeed Team'''

import torch
from deepspeed import comm as dist
from torch import nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from deepspeed.accelerator import get_accelerator

# Tensor-parallel linear layer: each rank multiplies by its local shard of the weight, and
# the partial outputs are summed across the model-parallel group with an all-reduce.
class LinearAllreduce(nn.Module):
    def __init__(self, weight, bias=None, mp_group=None):
        super(LinearAllreduce, self).__init__()
        self.weight = weight
        self.bias = bias
        self.mp_group = mp_group

    def forward(self, input):
        output = torch.matmul(input, self.weight.transpose(-1, -2))
        if self.mp_group is not None:
            dist.all_reduce(output, group=self.mp_group)
        if self.bias is not None:
            output += self.bias
        return output
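
# Illustrative sketch (not from the original DeepSpeed source): why the sum-reduce above
# reproduces the full result. When the weight is split along the input dimension across
# ranks, each rank's partial matmul covers a disjoint slice of the inner product, so adding
# the partial outputs (what dist.all_reduce does) equals the unsharded matmul. The function
# below demonstrates this on a single device with two hypothetical shards; all names and
# shapes here are made up for the example.
def _linear_allreduce_sketch():
    torch.manual_seed(0)
    full_weight = torch.randn(8, 16)   # [out_features, in_features]
    x = torch.randn(4, 16)             # [batch, in_features]
    # Shard the input dimension across two "ranks".
    w0, w1 = full_weight[:, :8], full_weight[:, 8:]
    x0, x1 = x[:, :8], x[:, 8:]
    partial_sum = (torch.matmul(x0, w0.transpose(-1, -2)) +
                   torch.matmul(x1, w1.transpose(-1, -2)))
    reference = torch.matmul(x, full_weight.transpose(-1, -2))
    assert torch.allclose(partial_sum, reference, atol=1e-5)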

# Plain linear layer: it either wraps weight/bias tensors handed to it, or allocates
# uninitialized parameters of the requested shape on the current accelerator.
class LinearLayer(nn.Module):
    def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
        super(LinearLayer, self).__init__()
        if weight is not None:
            self.weight = weight
            self.bias = bias
        else:
            self.weight = Parameter(
                torch.empty(weight_shape,
                            dtype=dtype,
                            device=get_accelerator().current_device_name()))
            self.bias = Parameter(
                torch.empty(weight_shape[0],
                            dtype=dtype,
                            device=get_accelerator().current_device_name())) \
                if bias is not None else None

    def forward(self, input):
        output = torch.matmul(input, self.weight.transpose(-1, -2))
        if self.bias is not None:
            output += self.bias
        return output
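
# Illustrative usage sketch (not from the original DeepSpeed source): LinearLayer can wrap
# existing tensors, which avoids touching the accelerator. The shapes below are arbitrary
# and only show the [out_features, in_features] weight convention used by forward().
def _linear_layer_sketch():
    w = Parameter(torch.randn(8, 16, dtype=torch.float))
    b = Parameter(torch.zeros(8, dtype=torch.float))
    layer = LinearLayer(weight=w, bias=b)
    out = layer(torch.randn(2, 16, dtype=torch.float))
    assert out.shape == (2, 8)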

class Normalize(nn.Module):
    def __init__(self, dim, dtype=torch.float, eps=1e-5):
        super(Normalize, self).__init__()
        self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(
            get_accelerator().current_device_name())
        self.weight = self.norm.weight
        self.bias = self.norm.bias

    def forward(self, input):
        return self.norm(input)

class EmbeddingLayer(nn.Module):
    def __init__(self, weight_shape, dtype=torch.half):
        super(EmbeddingLayer, self).__init__()
        self.weight = Parameter(
            torch.empty(weight_shape[0],
                        weight_shape[1],
                        dtype=dtype,
                        device=get_accelerator().current_device_name()))

    def forward(self, input):
        return F.embedding(input, self.weight)

class OPTEmbedding(EmbeddingLayer):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, weight_shape):
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(weight_shape)

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        attention_mask = attention_mask.long()

        # create positions depending on attention_mask
        positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) *
                     attention_mask).long() - 1

        # cut positions if `past_key_values_length` is > 0
        positions = positions[:, past_key_values_length:]

        return super().forward(positions + self.offset)
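
# Illustrative sketch (not from the original DeepSpeed source): how OPTEmbedding turns an
# attention mask into position ids. cumsum counts real tokens, multiplying by the mask
# zeroes the padded slots, and the -1 makes counting start at 0 (padding collapses to -1).
# The +2 offset is applied afterwards in forward(). The mask below is a made-up,
# left-padded example.
def _opt_position_sketch():
    attention_mask = torch.tensor([[0, 0, 1, 1, 1]])
    positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) *
                 attention_mask).long() - 1
    # positions == tensor([[-1, -1, 0, 1, 2]])
    return positions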