# layers.py
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, List
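
# LoRA (Low-Rank Adaptation) layers in the style of Microsoft's loralib: each wrapper freezes
# the pre-trained weight and learns a low-rank update lora_B @ lora_A, scaled by
# lora_alpha / r, which can optionally be folded into the frozen weight for inference
# (merge_weights=True) and unfolded again when switching back to training.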


class LoRALayer():
    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights


class LoRAEmbedding(nn.Embedding, LoRALayer):
    # LoRA implemented in an embedding layer
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
                           merge_weights=merge_weights)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A to zeros and B with a normal distribution
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            result = nn.Embedding.forward(self, x)
            after_A = F.embedding(
                x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse
            )
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return nn.Embedding.forward(self, x)


class LoRALinear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            # this is different than what is described in the paper but should not affect performance
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)


class MergedLoRALinear(nn.Linear, LoRALayer):
    # LoRA for a dense layer whose output concatenates several projections (e.g. fused q/k/v);
    # only the output slices marked True in enable_lora receive a low-rank update
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
            )  # weights for Conv1D with groups=sum(enable_lora)
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Compute the indices
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def zero_pad(self, x):
        result = x.new_zeros((len(self.lora_ind), *x.shape[1:]))
        result[self.lora_ind] = x
        return result

    def merge_AB(self):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        delta_w = F.conv1d(
            self.lora_A.unsqueeze(0),
            self.lora_B.unsqueeze(-1),
            groups=sum(self.enable_lora)
        ).squeeze(0)
        return T(self.zero_pad(delta_w))

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0 and any(self.enable_lora):
                    self.weight.data -= self.merge_AB() * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0 and any(self.enable_lora):
                    self.weight.data += self.merge_AB() * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.merged:
            return F.linear(x, T(self.weight), bias=self.bias)
        else:
            result = F.linear(x, T(self.weight), bias=self.bias)
            if self.r > 0 and any(self.enable_lora):
                result += self.lora_dropout(x) @ T(self.merge_AB().T) * self.scaling
            return result


class ConvLoRA(nn.Module, LoRALayer):
    def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
        super(ConvLoRA, self).__init__()
        self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
        self.weight = self.conv.weight
        self.bias = self.conv.bias
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
        assert isinstance(kernel_size, int)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.conv.weight.new_zeros((out_channels // self.conv.groups * kernel_size, r * kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.conv.weight.requires_grad = False
        self.reset_parameters()
        self.merged = False

    def reset_parameters(self):
        self.conv.reset_parameters()
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode=True):
        super(ConvLoRA, self).train(mode)
        if mode:
            if self.merge_weights and self.merged:
                if self.r > 0:
                    # Make sure that the weights are not merged
                    self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                if self.r > 0:
                    # Merge the weights and mark it
                    self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
                self.merged = True

    def forward(self, x):
        if self.r > 0 and not self.merged:
            return self.conv._conv_forward(
                x,
                self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
                self.conv.bias
            )
        return self.conv(x)
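

# Shape note: ConvLoRA's low-rank product lora_B @ lora_A has
# (out_channels // groups * kernel_size) x (in_channels * kernel_size) entries, i.e. a factor of
# kernel_size**2, while an N-dimensional conv weight has
# out_channels * (in_channels // groups) * kernel_size**N entries. The .view() onto the conv
# weight therefore only works for N == 2; for 1D the product has kernel_size times too many
# entries and for 3D kernel_size times too few. LoRAConv1d and LoRAConv3d below redefine lora_A
# so the entry counts match.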
class LoRAConv2d(ConvLoRA):
    # Note: ConvLoRA's factor shapes are only valid for 2D convolutions (see the shape note above)
    def __init__(self, *args, **kwargs):
        super(LoRAConv2d, self).__init__(nn.Conv2d, *args, **kwargs)


class LoRAConv1d(ConvLoRA):
    def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
        # Bypass ConvLoRA.__init__ (its factor shapes assume 2D) and build Conv1d-shaped factors directly
        super(ConvLoRA, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, **kwargs)
        self.weight = self.conv.weight
        self.bias = self.conv.bias
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
        assert isinstance(kernel_size, int)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.conv.weight.new_zeros((r * kernel_size, in_channels))
            )
            self.lora_B = nn.Parameter(
                self.conv.weight.new_zeros((out_channels // self.conv.groups * kernel_size, r * kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.conv.weight.requires_grad = False
        self.reset_parameters()
        self.merged = False


# Other convolution types can be extended in the same way
class LoRAConv3d(ConvLoRA):
    def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
        # Bypass ConvLoRA.__init__ (its factor shapes assume 2D) and build Conv3d-shaped factors directly
        super(ConvLoRA, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, **kwargs)
        self.weight = self.conv.weight
        self.bias = self.conv.bias
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
        assert isinstance(kernel_size, int)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size * kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.conv.weight.new_zeros((out_channels // self.conv.groups * kernel_size, r * kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.conv.weight.requires_grad = False
        self.reset_parameters()
        self.merged = False


if __name__ == '__main__':
    # Quick smoke test: build a LoRA-wrapped Conv1d and inspect it
    conv = LoRAConv1d(3, 32, kernel_size=3, stride=1, r=8)
    print(conv)
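
    # A minimal sanity sketch: check that folding the low-rank update into the weight on eval()
    # leaves LoRALinear's output unchanged. lora_B is randomized first only so the LoRA term is
    # non-zero and the comparison is meaningful.
    lin = LoRALinear(16, 8, r=4, lora_alpha=8)
    nn.init.normal_(lin.lora_B)  # B is zero right after init, so the update would otherwise vanish
    x = torch.randn(2, 16)
    lin.train()
    y_unmerged = lin(x)          # W x + (dropout(x) @ A^T) @ B^T * scaling, computed on the fly
    lin.eval()                   # train(False) merges B @ A * scaling into the frozen weight
    y_merged = lin(x)
    print("merge consistent:", torch.allclose(y_unmerged, y_merged, atol=1e-5))

    # Typical fine-tuning setup: train only the low-rank factors. The upstream loralib package
    # provides mark_only_lora_as_trainable for this; the loop below is an equivalent sketch.
    for name, param in lin.named_parameters():
        param.requires_grad = 'lora_' in name
    print("trainable:", [n for n, p in lin.named_parameters() if p.requires_grad])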