# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import math
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
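
# LoRA (Hu et al., 2021) freezes a pre-trained weight W and learns a low-rank update:
#     W_eff = W + (lora_alpha / r) * (lora_B @ lora_A)
# where lora_A has shape (r, fan_in) and lora_B has shape (fan_out, r). Since one of
# the two factors is initialized to zero, every layer below starts as an exact no-op
# on top of the frozen weights.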


class LoRALayer:
    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout on the LoRA path
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights


class LoRAEmbedding(nn.Embedding, LoRALayer):
    # LoRA implemented in an embedding layer
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
                           merge_weights=merge_weights)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A to zero and B from a normal distribution,
            # so the update B @ A starts as an exact zero
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            result = nn.Embedding.forward(self, x)
            after_A = F.embedding(
                x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse
            )
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return nn.Embedding.forward(self, x)
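
# A minimal usage sketch for LoRAEmbedding (sizes below are illustrative, not part
# of the library): only lora_A / lora_B receive gradients.
#
#   emb = LoRAEmbedding(num_embeddings=1000, embedding_dim=64, r=4, lora_alpha=8)
#   tokens = torch.randint(0, 1000, (2, 16))
#   out = emb(tokens)   # shape (2, 16, 64)
#   assert not emb.weight.requires_grad and emb.lora_A.requires_grad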


class LoRALinear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero;
            # this is different than what is described in the paper but should not affect performance
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)
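
# A minimal usage sketch for LoRALinear (sizes are illustrative). Calling .eval()
# folds B @ A into the frozen weight when merge_weights=True, so merged inference
# matches the unmerged training-time path:
#
#   layer = LoRALinear(128, 64, r=8, lora_alpha=16)
#   x = torch.randn(4, 128)
#   y = layer(x)        # unmerged: frozen path + low-rank path
#   layer.eval()        # merges the update into layer.weight
#   assert torch.allclose(y, layer(x), atol=1e-5)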


class MergedLoRALinear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer whose output is a fusion of several blocks
    # (e.g. a qkv projection); enable_lora selects which blocks get a LoRA update
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
            )  # weights for Conv1D with groups=sum(enable_lora)
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Compute the boolean index of output rows that carry a LoRA update
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def zero_pad(self, x):
        # Scatter the LoRA rows back to the full out_features dimension,
        # leaving the blocks without LoRA as zeros
        result = x.new_zeros((len(self.lora_ind), *x.shape[1:]))
        result[self.lora_ind] = x
        return result

    def merge_AB(self):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        # A grouped 1x1 convolution computes the block-diagonal product B @ A in one call
        delta_w = F.conv1d(
            self.lora_A.unsqueeze(0),
            self.lora_B.unsqueeze(-1),
            groups=sum(self.enable_lora)
        ).squeeze(0)
        return T(self.zero_pad(delta_w))

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0 and any(self.enable_lora):
                    self.weight.data -= self.merge_AB() * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0 and any(self.enable_lora):
                    self.weight.data += self.merge_AB() * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.merged:
            return F.linear(x, T(self.weight), bias=self.bias)
        else:
            result = F.linear(x, T(self.weight), bias=self.bias)
            if self.r > 0:
                result += self.lora_dropout(x) @ T(self.merge_AB().T) * self.scaling
            return result
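
# A minimal usage sketch for MergedLoRALinear (sizes are illustrative): a fused qkv
# projection where only the q and v blocks are adapted, as in the LoRA paper's
# attention experiments:
#
#   qkv = MergedLoRALinear(768, 3 * 768, r=8, enable_lora=[True, False, True])
#   x = torch.randn(2, 10, 768)
#   out = qkv(x)        # shape (2, 10, 2304); the k block stays frozen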


class ConvLoRA(nn.Module, LoRALayer):
    def __init__(self, conv_module, in_channels, out_channels, kernel_size,
                 r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
        super(ConvLoRA, self).__init__()
        self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
        self.weight = self.conv.weight
        self.bias = self.conv.bias
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert isinstance(kernel_size, int)
        # Actual trainable parameters; these shapes flatten a 2D kernel so that
        # (lora_B @ lora_A) can be .view()-ed back to conv.weight.shape
        if r > 0:
            self.lora_A = nn.Parameter(
                self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.conv.weight.new_zeros((out_channels // self.conv.groups * kernel_size, r * kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.conv.weight.requires_grad = False
        self.reset_parameters()
        self.merged = False

    def reset_parameters(self):
        self.conv.reset_parameters()
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode=True):
        super(ConvLoRA, self).train(mode)
        if mode:
            if self.merge_weights and self.merged:
                if self.r > 0:
                    # Make sure that the weights are not merged
                    self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                if self.r > 0:
                    # Merge the weights and mark it
                    self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
                self.merged = True

    def forward(self, x):
        if self.r > 0 and not self.merged:
            return self.conv._conv_forward(
                x,
                self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
                self.conv.bias
            )
        return self.conv(x)


class LoRAConv2d(ConvLoRA):
    # The base ConvLoRA shapes assume a 2D kernel: they are correct here, but lora_A
    # would be kernel_size times too large for Conv1d and kernel_size times too small
    # for Conv3d, so those subclasses override __init__ below.
    def __init__(self, *args, **kwargs):
        super(LoRAConv2d, self).__init__(nn.Conv2d, *args, **kwargs)


class LoRAConv1d(ConvLoRA):
    def __init__(self, in_channels, out_channels, kernel_size,
                 r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
        # Deliberately skip ConvLoRA.__init__ (its LoRA shapes assume a 2D kernel)
        # and set up 1D-appropriate shapes instead
        super(ConvLoRA, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, **kwargs)
        self.weight = self.conv.weight
        self.bias = self.conv.bias
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert isinstance(kernel_size, int)
        # Actual trainable parameters: (lora_B @ lora_A) now has exactly as many
        # elements as the (out_channels, in_channels // groups, kernel_size) weight
        if r > 0:
            self.lora_A = nn.Parameter(
                self.conv.weight.new_zeros((r * kernel_size, in_channels))
            )
            self.lora_B = nn.Parameter(
                self.conv.weight.new_zeros((out_channels // self.conv.groups * kernel_size, r * kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.conv.weight.requires_grad = False
        self.reset_parameters()
        self.merged = False


# Other dimensionalities can be extended in the same way:
class LoRAConv3d(ConvLoRA):
    def __init__(self, in_channels, out_channels, kernel_size,
                 r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
        # As with LoRAConv1d, skip ConvLoRA.__init__ and use 3D-appropriate shapes
        super(ConvLoRA, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, **kwargs)
        self.weight = self.conv.weight
        self.bias = self.conv.bias
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert isinstance(kernel_size, int)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size * kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.conv.weight.new_zeros((out_channels // self.conv.groups * kernel_size, r * kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.conv.weight.requires_grad = False
        self.reset_parameters()
        self.merged = False


if __name__ == '__main__':
    # Smoke test: a LoRA-wrapped Conv1d runs a forward pass end to end
    conv = LoRAConv1d(3, 32, kernel_size=3, stride=1, r=8)
    x = torch.randn(4, 3, 16)
    print(conv(x).shape)  # torch.Size([4, 32, 14])
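
    # Extra check (a sketch, not part of the original file): merging at eval time
    # should leave outputs unchanged.
    lin = LoRALinear(16, 8, r=4, lora_alpha=8)
    nn.init.normal_(lin.lora_B)  # make the update non-zero so the check is non-trivial
    z = torch.randn(2, 16)
    out_unmerged = lin(z)
    lin.eval()  # folds B @ A into the frozen weight
    assert torch.allclose(out_unmerged, lin(z), atol=1e-5)
    print('merge/unmerge round trip OK')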