optimized_linear.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import is_dataclass
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist

from .config import LoRAConfig, QuantizationConfig
from .quantization import QuantizedParameter, QuantizedLinear


class OptimizedLinear(nn.Module):
    """
    Optimized version of nn.Linear that adds features such as:
      * LoRA w. base weight sharding
      * FP [6,8,12] quantization

    Arguments:
        input_dim: Required: size of each input sample
        output_dim: Required: size of each output sample
        bias: Optional: If set to False, the layer will not learn an additive bias. Default: False
        lora_config: Optional: LoRAConfig defining lora features and base-weight-sharding degree
        quantization_config: Optional: QuantizationConfig defining quantization features
        dtype: Optional: parameter dtype, only supports bfloat16 currently

    Returns:
        Returns a new nn.Module depending on the input config. Either native
        torch.nn.Linear, QuantizedLinear, or the full-featured DSOptimizedLinear.
    """

    def __new__(self,
                input_dim: int,
                output_dim: int,
                bias: bool = False,
                lora_config: LoRAConfig = None,
                quantization_config: QuantizationConfig = None,
                dtype=torch.bfloat16):

        if quantization_config is not None and not is_dataclass(quantization_config):
            raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}")
        if lora_config is not None and not is_dataclass(lora_config):
            raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}")

        if lora_config is None and quantization_config is None:
            # Everything disabled, fall back to normal nn.Linear
            self = nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype)

        elif lora_config:
            # lora enabled, quantization may or may not be
            self = LoRAOptimizedLinear(input_dim=input_dim,
                                       output_dim=output_dim,
                                       bias=bias,
                                       lora_config=lora_config,
                                       quantization_config=quantization_config,
                                       dtype=dtype)

        elif quantization_config:
            # only quantization enabled, no lora
            self = QuantizedLinear(input_dim=input_dim,
                                   output_dim=output_dim,
                                   bias=bias,
                                   quantization_config=quantization_config,
                                   dtype=dtype)

        return self
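

# Example (illustrative sketch, not from the original file): the OptimizedLinear
# factory above returns a different module type depending on which configs are
# passed, e.g. (assuming LoRAConfig / QuantizationConfig are constructible with
# their defaults from .config):
#
#   OptimizedLinear(1024, 1024)                                            -> torch.nn.Linear
#   OptimizedLinear(1024, 1024, quantization_config=QuantizationConfig())  -> QuantizedLinear
#   OptimizedLinear(1024, 1024, lora_config=LoRAConfig())                  -> LoRAOptimizedLinear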


class LoRAOptimizedLinear(nn.Module):

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 bias: bool = False,
                 lora_config: LoRAConfig = None,
                 quantization_config: QuantizationConfig = None,
                 device=None,
                 dtype=torch.bfloat16):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias = bias
        self.lora_config = lora_config
        self.quantization_config = quantization_config
        device = get_accelerator().current_device_name() if device is None else device
        assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config"

        self.zero_shards = self.lora_config.base_weight_sharding
        self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards)
        w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype))
        torch.nn.init.xavier_uniform_(w)

        if self.quantization_config is not None:
            assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
            self.base_weight = QuantizedParameter(w, quantization_config=quantization_config)
        else:
            self.base_weight = w

        self.base_weight.requires_grad = False

        # Use rank-stabilized LoRA (RS-LoRA) scaling for now: alpha / sqrt(r)
        # rather than the standard LoRA scaling of alpha / r.
        self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r)

        # Keeping lora weights in bf16 precision for ease of training.
        self.lora_weight_1 = nn.Linear(self.input_dim,
                                       self.lora_config.lora_r,
                                       bias=self.bias,
                                       device=device,
                                       dtype=dtype)
        self.lora_weight_2 = nn.Linear(self.lora_config.lora_r,
                                       self.output_dim,
                                       bias=self.bias,
                                       device=device,
                                       dtype=dtype)
        self.lora_weight_1.weight.requires_grad = True
        self.lora_weight_2.weight.requires_grad = True
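
    # Sharding note: with base_weight_sharding=N, each rank holds a
    # (output_dim, input_dim // N) slice of the frozen base weight. full_weight()
    # below all-gathers the N slices and concatenates them along dim=1 to rebuild
    # the complete (output_dim, input_dim) matrix before the base projection.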
    def full_weight(self):
        # This assumes the weights are evenly sharded across GPUs, which might not
        # be correct; in that case, we should flatten before the all_gather.
        local_weight = self.base_weight.dequantized() if isinstance(self.base_weight,
                                                                    QuantizedParameter) else self.base_weight
        tensor_list = [
            torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype)
            for _ in range(self.zero_shards)
        ]
        dist.all_gather(tensor_list, local_weight)
        weight = nn.Parameter(torch.cat(tensor_list, dim=1))
        return weight

    def linear_without_F_linear(self, input, weight):
        # Plain matmul fallback that avoids F.linear; note that `weight` is
        # expected in (in_features, out_features) layout, i.e. already transposed.
        output = torch.mm(input.reshape(-1, input.shape[-1]), weight)
        output = output.view(*input.shape[:-1], weight.shape[1])
        return output

    def forward(self, input_tensor):
        # Gather the sharded base weight
        if self.zero_shards > 1:
            with torch.no_grad():
                base_weight = self.full_weight()
        elif self.quantization_config:
            base_weight = self.base_weight.dequantized()
        else:
            base_weight = self.base_weight

        # Frozen base projection plus the scaled low-rank LoRA update:
        # y = x @ base_weight.T + lora_scaling_factor * lora_weight_2(lora_weight_1(x))
        base_weight_output = F.linear(input_tensor, base_weight)
        lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor))
        return base_weight_output + self.lora_scaling_factor * lora_output
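

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). Assumes LoRAConfig
    # accepts the lora_r, lora_alpha and base_weight_sharding fields referenced
    # above, and uses a single base-weight shard so no all_gather is required.
    device = get_accelerator().current_device_name()
    layer = OptimizedLinear(input_dim=128,
                            output_dim=64,
                            lora_config=LoRAConfig(lora_r=8, lora_alpha=16, base_weight_sharding=1)).to(device)
    x = torch.randn(4, 128, dtype=torch.bfloat16, device=device)
    y = layer(x)  # (4, 64): frozen base projection + scaled LoRA update
    print(type(layer).__name__, tuple(y.shape))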