# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import is_dataclass

from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist

from .config import LoRAConfig, QuantizationConfig
from .quantization import QuantizedParameter, QuantizedLinear


class OptimizedLinear(nn.Module):
    """
    Optimized version of nn.Linear that adds features such as:
      * LoRA with base weight sharding
      * FP [6, 8, 12] quantization

    Arguments:
        input_dim: Required: size of each input sample
        output_dim: Required: size of each output sample
        bias: Optional: If set to False, the layer will not learn an additive bias. Default: False
        lora_config: Optional: LoRAConfig defining LoRA features and base-weight-sharding degree
        quantization_config: Optional: QuantizationConfig defining quantization features
        dtype: Optional: parameter dtype, only supports bfloat16 currently

    Returns:
        Returns a new nn.Module depending on the input configs: either a native
        torch.nn.Linear, a QuantizedLinear, or the full-featured LoRAOptimizedLinear.
    """

    def __new__(self,
                input_dim: int,
                output_dim: int,
                bias: bool = False,
                lora_config: LoRAConfig = None,
                quantization_config: QuantizationConfig = None,
                dtype=torch.bfloat16):

        if quantization_config is not None and not is_dataclass(quantization_config):
            raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}")
        if lora_config is not None and not is_dataclass(lora_config):
            raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}")

        if lora_config is None and quantization_config is None:
            # Everything disabled, fall back to normal nn.Linear
            self = nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype)
        elif lora_config:
            # lora enabled, quantization may or may not be
            self = LoRAOptimizedLinear(input_dim=input_dim,
                                       output_dim=output_dim,
                                       bias=bias,
                                       lora_config=lora_config,
                                       quantization_config=quantization_config,
                                       dtype=dtype)
        elif quantization_config:
            # only quantization enabled, no lora
            self = QuantizedLinear(input_dim=input_dim,
                                   output_dim=output_dim,
                                   bias=bias,
                                   quantization_config=quantization_config,
                                   dtype=dtype)
        return self
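

# Usage sketch: how the OptimizedLinear factory above dispatches on its configs.
# The LoRAConfig constructor arguments shown here (lora_r, lora_alpha,
# base_weight_sharding) mirror the attributes read in LoRAOptimizedLinear.__init__
# below, but whether the dataclass accepts them as keyword arguments is an
# assumption about .config.
def _optimized_linear_usage_sketch():
    # Assumed LoRAConfig field names (see the hedge above).
    lora_cfg = LoRAConfig(lora_r=64, lora_alpha=16, base_weight_sharding=1)

    # Both configs None -> plain torch.nn.Linear
    plain = OptimizedLinear(input_dim=1024, output_dim=1024)

    # LoRA config provided (quantization optional) -> LoRAOptimizedLinear;
    # passing only a QuantizationConfig instead would return a QuantizedLinear.
    lora_layer = OptimizedLinear(input_dim=1024, output_dim=1024, lora_config=lora_cfg)

    return plain, lora_layer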


class LoRAOptimizedLinear(nn.Module):

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 bias: bool = False,
                 lora_config: LoRAConfig = None,
                 quantization_config: QuantizationConfig = None,
                 device=None,
                 dtype=torch.bfloat16):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias = bias
        self.lora_config = lora_config
        self.quantization_config = quantization_config
        device = get_accelerator().current_device_name() if device is None else device
        assert self.lora_config is not None, "LoRAOptimizedLinear requires a LoRA config"

        # Each rank only materializes a 1/base_weight_sharding slice of the frozen base weight
        self.zero_shards = self.lora_config.base_weight_sharding
        self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards)
        w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype))
        torch.nn.init.xavier_uniform_(w)

        if self.quantization_config is not None:
            assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
            self.base_weight = QuantizedParameter(w, quantization_config=quantization_config)
        else:
            self.base_weight = w

        self.base_weight.requires_grad = False

        # Use rank-stabilized LoRA (RSLoRA) scaling for now: lora_alpha / sqrt(lora_r)
        self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r)

        # Keep the LoRA weights in bf16 precision for ease of training.
        self.lora_weight_1 = nn.Linear(self.input_dim,
                                       self.lora_config.lora_r,
                                       bias=self.bias,
                                       device=device,
                                       dtype=dtype)
        self.lora_weight_2 = nn.Linear(self.lora_config.lora_r,
                                       self.output_dim,
                                       bias=self.bias,
                                       device=device,
                                       dtype=dtype)

        # Only the LoRA matrices are trainable; the base weight above stays frozen.
        self.lora_weight_1.weight.requires_grad = True
        self.lora_weight_2.weight.requires_grad = True

    def full_weight(self):
        # This assumes the weight is evenly sharded across GPUs, which might not be correct;
        # in that case we should flatten before the all_gather.
        local_weight = self.base_weight.dequantized() if isinstance(self.base_weight,
                                                                    QuantizedParameter) else self.base_weight
        tensor_list = [
            torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype)
            for _ in range(self.zero_shards)
        ]
        dist.all_gather(tensor_list, local_weight)
        weight = nn.Parameter(torch.cat(tensor_list, dim=1))
        return weight
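
    # Shape sketch for the gather above (assuming input_dim divides evenly by the
    # sharding degree): every rank holds a (output_dim, input_dim // zero_shards)
    # shard, all_gather collects zero_shards such shards, and concatenating them
    # along dim=1 rebuilds the full (output_dim, input_dim) weight that F.linear
    # expects in forward().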

    def linear_without_F_linear(self, input, weight):
        # Plain matmul fallback for F.linear; note `weight` is expected here in
        # (in_features, out_features) layout, i.e. already transposed.
        output = torch.mm(input.reshape(-1, input.shape[-1]), weight)
        output = output.view(*input.shape[:-1], weight.shape[1])
        return output

    def forward(self, input_tensor):
        # Gather the sharded base weight
        if self.zero_shards > 1:
            with torch.no_grad():
                base_weight = self.full_weight()
        elif self.quantization_config:
            base_weight = self.base_weight.dequantized()
        else:
            base_weight = self.base_weight

        base_weight_output = F.linear(input_tensor, base_weight)
        lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor))
        return base_weight_output + self.lora_scaling_factor * lora_output
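

# Effective computation of LoRAOptimizedLinear.forward written out as plain tensor
# algebra (a sketch; the optional bias terms of the two LoRA nn.Linear layers are
# omitted, they only exist when bias=True):
#
#   y = x @ W_base.T + (lora_alpha / sqrt(lora_r)) * ((x @ A.T) @ B.T)
#
# where W_base is the frozen (possibly quantized and/or gathered) base weight of
# shape (output_dim, input_dim), A = lora_weight_1.weight with shape
# (lora_r, input_dim), and B = lora_weight_2.weight with shape (output_dim, lora_r).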