# model.py

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
import math, json, time, types, copy, sys, os
import torch
from torch.nn import functional as F
import torch.nn as nn
from transformers import PreTrainedTokenizerFast

RUN_DEVICE = 'cpu' # cpu cuda
ctx_len = 768
n_layer = 12
n_embd = 768
# n_layer = 24
# n_embd = 1024

# ---> download RWKV-3 169M model from https://huggingface.co/BlinkDL/rwkv-3-pile-169m/tree/main
# MODEL_NAME = '/data1/ckw/RWKV-3-Pile-430M-20220817-10602'
MODEL_NAME = '/data1/ckw/RWKV-3-Pile-20220720-10704'
K_EPS = 1e-8
vocab_size = 50277
VOCAB_NAME = '20B_tokenizer.json'

print(f'\n* running on {RUN_DEVICE}')
################################################################################################################

class RWKV_ChannelMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        # shift the sequence one step along T (zero-padded at t=0)
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))

        hidden_sz = 4 * n_embd
        self.key = nn.Linear(n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, n_embd, bias=False)

    def forward(self, x):
        xx = self.time_shift(x)  # previous token
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        k = torch.square(torch.relu(k))  # squared ReLU
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv
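
# Illustration (not in the original file): a minimal sketch of what the
# ZeroPad2d((0, 0, 1, -1)) "time_shift" above does. It shifts the sequence
# one step along T, so each position is mixed with the previous token.
# The helper name _demo_time_shift is just for this sketch.
def _demo_time_shift():
    x = torch.arange(1, 5, dtype=torch.float32).reshape(1, 4, 1)  # (B=1, T=4, C=1)
    shifted = nn.ZeroPad2d((0, 0, 1, -1))(x)
    # x       along T: [1, 2, 3, 4]
    # shifted along T: [0, 1, 2, 3]   (zero at t=0, then previous tokens)
    return shifted
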
class RWKV_TimeMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.time_decay = nn.Parameter(torch.ones(n_embd, 1))
        self.time_curve = torch.tensor([-(ctx_len - 2 - i) for i in range(ctx_len - 1)]).unsqueeze(0)
        self.time_first = nn.Parameter(torch.ones(n_embd, 1) * math.log(0.3))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))

        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.output = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x):
        B, T, C = x.size()

        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk).transpose(-1, -2)    # (B, C, T)
        v = self.value(xv).transpose(-1, -2)  # (B, C, T)
        r = self.receptance(xr)

        k = torch.clamp(k, max=60)            # avoid overflow in exp
        k = torch.exp(k)
        kv = k * v

        # per-channel time weights: decayed weights for past positions plus a
        # separate weight (time_first) for the current token
        self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)
        w = torch.exp(self.time_w)
        w = w[:, -T:].unsqueeze(1)            # (C, 1, T) depthwise conv kernel

        # causal depthwise conv1d computes the decayed sums over the past in parallel
        wkv = F.conv1d(nn.ZeroPad2d((T - 1, 0, 0, 0))(kv), w, groups=C)
        wk = F.conv1d(nn.ZeroPad2d((T - 1, 0, 0, 0))(k), w, groups=C) + K_EPS

        rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
        rwkv = self.output(rwkv)
        return rwkv
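
# Illustration (not in the original file): a single-channel reference for the
# causal conv above. _wkv_reference is a hypothetical helper; it computes the
# same decayed sum over past k*v, with exp(time_first) as the extra weight on
# the current token, that RWKV_RNN.SA below keeps in its aa/bb state.
def _wkv_reference(k, v, decay, first):
    # k, v: (T,) tensors for one channel
    # decay = exp(-exp(time_decay)), first = exp(time_first),
    # i.e. the preprocessed forms used by RWKV_RNN
    T = k.shape[0]
    out = torch.zeros(T)
    aa = torch.zeros(())  # running decayed sum of k*v
    bb = torch.zeros(())  # running decayed sum of k
    for t in range(T):
        out[t] = (aa + first * k[t] * v[t]) / (bb + first * k[t] + K_EPS)
        aa = decay * aa + k[t] * v[t]
        bb = decay * bb + k[t]
    return out
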
class Block(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(n_embd)

        self.att = RWKV_TimeMix(layer_id)
        self.ffn = RWKV_ChannelMix(layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
class RWKV_GPT(nn.Module):
    def __init__(self, MODEL_NAME=MODEL_NAME):
        super().__init__()
        print('\nloading RWKV-GPT', MODEL_NAME)
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)

        self.emb = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])
        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.ctx_len = ctx_len

        self.load_state_dict(torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE)))
        self.eval()

    def forward(self, idx):
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)
        x = self.head(x)
        return x
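
# Usage sketch (not in the original file): assumes the checkpoint at
# MODEL_NAME and the 20B_tokenizer.json file are available locally; the
# helper name _demo_gpt and the example prompt are just for illustration.
# Runs the parallel GPT formulation over a whole prompt and prints the
# greedy prediction for the next token.
def _demo_gpt(prompt='The quick brown fox'):
    model = RWKV_GPT(MODEL_NAME)
    with torch.no_grad():
        idx = torch.tensor([model.tokenizer.encode(prompt)], dtype=torch.long)
        logits = model(idx)                          # (1, T, vocab_size)
        next_id = int(torch.argmax(logits[0, -1]))
        print(prompt + model.tokenizer.decode([next_id]))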
################################################################################################################

time_buf = {}

class RWKV_RNN():
    def __init__(self, MODEL_NAME=MODEL_NAME):
        print('\nloading RWKV-RNN', MODEL_NAME)
        self.ctx_len = ctx_len
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)

        self.w = types.SimpleNamespace()
        w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))
        for x in w.keys():
            if '.time_' in x:
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                w[x] = torch.exp(-torch.exp(w[x]))  # per-step decay factor in (0, 1)
            if '.time_first' in x:
                w[x] = torch.exp(w[x])

            # rebuild the nested structure, e.g. 'blocks.0.att.key.weight'
            # becomes self.w.blocks[0].att.key.weight
            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i+1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()

    def clear(self):
        self.xx = {}
        self.aa = {}
        self.bb = {}

    def save(self, target):
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)

    def load(self, target):
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)

    def LN(self, xx, w):
        return F.layer_norm(xx, (n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        if name not in self.xx:
            self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = torch.square(torch.relu(w.key.weight @ xk))
        kv = w.value.weight @ k
        return r * kv

    def SA(self, xx, w, name):
        if name not in self.xx:
            self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
            self.aa[name] = torch.zeros(n_embd, device=RUN_DEVICE)
            self.bb[name] = torch.zeros(n_embd, device=RUN_DEVICE)
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = torch.exp(torch.clamp(w.key.weight @ xk, max=60))
        v = w.value.weight @ xv
        kv = k * v

        # aa/bb hold the decayed running sums of k*v and k; time_first gives
        # the current token its own weight before the state is updated
        a = self.aa[name] + w.time_first * kv
        b = self.bb[name] + w.time_first * k
        self.aa[name] = w.time_decay * self.aa[name] + kv
        self.bb[name] = w.time_decay * self.bb[name] + k

        rwkv = r * a / (b + K_EPS)
        return w.output.weight @ rwkv
    def run(self, ctx):
        w = self.w
        x = w.emb.weight[ctx[-1]]
        x = self.LN(x, w.blocks[0].ln0)  # compared with the v2 version, an initial LayerNorm is added here

        for i in range(n_layer):
            x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
            x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)
        x = w.head.weight @ x

        x = x.tolist()
        return x
################################################################################################################
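
# Usage sketch (not in the original file): assumes the same checkpoint and
# tokenizer files as above; the helper name _demo_rnn and the prompt are just
# for illustration. Feeds the prompt token by token to build up the recurrent
# state, then continues greedily; run() only looks at ctx[-1] plus the state.
def _demo_rnn(prompt='The quick brown fox', n_tokens=8):
    rnn = RWKV_RNN(MODEL_NAME)
    ctx = rnn.tokenizer.encode(prompt)
    rnn.clear()
    out = None
    for i in range(len(ctx)):          # warm up the state on the prompt
        out = rnn.run(ctx[:i + 1])
    for _ in range(n_tokens):          # greedy continuation
        ctx.append(int(np.argmax(out)))
        out = rnn.run(ctx)
    print(rnn.tokenizer.decode(ctx))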