# model_v2.py

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import math, json, time, types, copy, sys, os
import torch
from torch.nn import functional as F
import torch.nn as nn
from transformers import PreTrainedTokenizerFast

# RUN_DEVICE = 'cpu' # cpu cuda
# ctx_len = 768
# n_layer = 12
# n_embd = 768

RUN_DEVICE = 'cpu'  # 'cpu' or 'cuda'
ctx_len = 768       # maximum context length
n_layer = 24        # number of blocks
n_embd = 1024       # embedding / channel dimension

MODEL_NAME = '/data1/ckw/20220615-10803'
vocab_size = 50277
VOCAB_NAME = '20B_tokenizer.json'

print(f'\n* running on {RUN_DEVICE}')

################################################################################################################
class RWKV_ChannelMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        # shift the sequence right by one token (see the sketch after this class)
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_mix = nn.Parameter(torch.ones(1, 1, n_embd))

        hidden_sz = 4 * n_embd
        self.key = nn.Linear(n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, n_embd, bias=False)

    def forward(self, x):
        # blend each token with the previous one, channel by channel
        x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)

        k = self.key(x)
        k = torch.square(torch.relu(k))     # squared-ReLU activation
        kv = self.value(k)

        # receptance: a sigmoid gate on the output
        rkv = torch.sigmoid(self.receptance(x)) * kv
        return rkv
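# A minimal sketch (not part of the original file) illustrating what
# nn.ZeroPad2d((0, 0, 1, -1)) does to a (B, T, C) tensor: it shifts the
# sequence right by one token, so time_mix blends each position with its
# predecessor. The function name and shapes are illustrative only.
def _check_time_shift(T=4, C=2):
    shift = nn.ZeroPad2d((0, 0, 1, -1))     # pad one zero step in front, drop the last
    x = torch.arange(float(T * C)).reshape(1, T, C)
    y = shift(x)
    assert torch.equal(y[0, 0], torch.zeros(C))   # first position sees zeros
    assert torch.equal(y[0, 1:], x[0, :-1])       # others see the previous token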
class RWKV_TimeMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        # per-channel decay rates, plus an extra weight for the current token
        self.time_decay = nn.Parameter(torch.ones(n_embd, 1))
        self.time_curve = torch.tensor([-(ctx_len - 2 - i) for i in range(ctx_len - 1)]).unsqueeze(0)
        self.time_first = nn.Parameter(torch.ones(n_embd, 1) * math.log(0.3))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_mix = nn.Parameter(torch.ones(1, 1, n_embd))

        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.output = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x):
        B, T, C = x.size()

        x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)

        k = self.key(x).transpose(-1, -2)       # (B, C, T)
        v = self.value(x).transpose(-1, -2)     # (B, C, T)
        r = self.receptance(x)

        k = torch.clamp(k, max=60)              # avoid overflow in exp()
        k = torch.exp(k)
        kv = k * v

        # build the causal time-weight kernel: exponentially decaying weights
        # for past tokens, exp(time_first) for the current token
        self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)
        w = torch.exp(self.time_w)
        w = w[:, -T:].unsqueeze(1)              # (C, 1, T) depthwise kernel

        # causal depthwise conv1d computes, for every t, the weighted sums
        # sum_{u<=t} w[t-u] * k[u]*v[u] and sum_{u<=t} w[t-u] * k[u]
        wkv = F.conv1d(nn.ZeroPad2d((T - 1, 0, 0, 0))(kv), w, groups=C)
        wk = F.conv1d(nn.ZeroPad2d((T - 1, 0, 0, 0))(k), w, groups=C) + 1e-9

        rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
        rwkv = self.output(rwkv)
        return rwkv
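# A minimal sketch (not part of the original file) checking that the causal
# depthwise conv1d above equals the explicit weighted sum over past tokens.
# All names and shapes here are illustrative only.
def _check_wkv(T=5, C=3):
    k = torch.rand(1, C, T)
    v = torch.rand(1, C, T)
    w = torch.rand(C, 1, T)                  # w[:, 0, -1] weights the current token
    kv = k * v
    fast = F.conv1d(nn.ZeroPad2d((T - 1, 0, 0, 0))(kv), w, groups=C)
    slow = torch.zeros(1, C, T)
    for t in range(T):
        for u in range(t + 1):               # only positions u <= t contribute
            slow[0, :, t] += w[:, 0, T - 1 - (t - u)] * kv[0, :, u]
    assert torch.allclose(fast, slow, atol=1e-5)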
class Block(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.att = RWKV_TimeMix(layer_id)
        self.ffn = RWKV_ChannelMix(layer_id)

    def forward(self, x):
        # note: LayerNorm is applied to the residual stream itself here,
        # matching the RNN-mode run() below
        x = self.ln1(x)
        x = x + self.att(x)
        x = self.ln2(x)
        x = x + self.ffn(x)
        return x
class RWKV_GPT(nn.Module):
    def __init__(self, MODEL_NAME=MODEL_NAME):
        super().__init__()
        print('\nloading RWKV-GPT', MODEL_NAME)
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)

        self.emb = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])
        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.ctx_len = ctx_len

        self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
        self.eval()

    def forward(self, idx):
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward: input is longer than model ctx_len."
        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)
        x = self.head(x)
        return x
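# A minimal usage sketch (not part of the original file): GPT-mode inference
# over a whole prompt at once. Assumes MODEL_NAME + '.pth' and VOCAB_NAME
# exist on disk; the function name and prompt are illustrative only.
def _demo_gpt(prompt='\nIn the'):
    model = RWKV_GPT(MODEL_NAME).to(RUN_DEVICE)
    idx = torch.tensor([model.tokenizer.encode(prompt)], device=RUN_DEVICE)
    with torch.no_grad():
        logits = model(idx)                   # (1, T, vocab_size)
    return int(torch.argmax(logits[0, -1]))   # greedy next-token id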
################################################################################################################

time_buf = {}
class RWKV_RNN():
    def __init__(self, MODEL_NAME=MODEL_NAME):
        print('\nloading RWKV-RNN', MODEL_NAME)
        self.ctx_len = ctx_len
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)

        self.w = types.SimpleNamespace()
        w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))
        for x in w.keys():
            # pre-compute the RNN-mode forms of the time parameters
            if '.time_' in x:
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                w[x] = torch.exp(-torch.exp(w[x]))   # decay factor in (0, 1)
            if '.time_first' in x:
                w[x] = torch.exp(w[x])

            # mount 'blocks.0.att.key.weight' as self.w.blocks[0].att.key.weight
            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i + 1].isdigit():
                            setattr(here, xx[i], {})   # numeric children live in a dict
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()
    def clear(self):
        # per-layer hidden state: token shift (xx) and the WKV numerator /
        # denominator accumulators (aa / bb)
        self.xx = {}
        self.aa = {}
        self.bb = {}

    def save(self, target):
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)

    def load(self, target):
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)
    def LN(self, xx, w):
        return F.layer_norm(xx, (n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        # RNN-mode channel mixing: self.xx[name] holds the previous token
        if name not in self.xx:
            self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ x)
        k = torch.square(torch.relu(w.key.weight @ x))
        kv = w.value.weight @ k
        return r * kv
    def SA(self, xx, w, name):
        # RNN-mode time mixing: aa / bb accumulate the decayed sums of
        # k*v and k over all past tokens
        if name not in self.xx:
            self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
            self.aa[name] = torch.zeros(n_embd, device=RUN_DEVICE)
            self.bb[name] = torch.zeros(n_embd, device=RUN_DEVICE)
        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ x)
        k = torch.exp(torch.clamp(w.key.weight @ x, max=60))
        v = w.value.weight @ x
        kv = k * v

        # the current token gets the extra time_first weight; the running
        # state is then decayed and updated for the next step
        a = self.aa[name] + w.time_first * kv
        b = self.bb[name] + w.time_first * k
        self.aa[name] = w.time_decay * self.aa[name] + kv
        self.bb[name] = w.time_decay * self.bb[name] + k

        rwkv = r * a / (b + 1e-9)
        return w.output.weight @ rwkv
    def run(self, ctx):
        # consumes only the newest token; the recurrent state carries the rest
        w = self.w
        x = w.emb.weight[ctx[-1]]
        for i in range(n_layer):
            x = self.LN(x, w.blocks[i].ln1)
            x = x + self.SA(x, w.blocks[i].att, f'att.{i}')
            x = self.LN(x, w.blocks[i].ln2)
            x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}')
        x = self.LN(x, w.ln_out)
        x = w.head.weight @ x
        return x.tolist()

################################################################################################################
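# A minimal usage sketch (not part of the original file): RNN-mode generation,
# feeding the prompt one token at a time and then decoding greedily. Assumes
# the checkpoint and tokenizer files exist; names are illustrative only.
def _demo_rnn(prompt='\nIn the', n_tokens=20):
    model = RWKV_RNN(MODEL_NAME)
    ctx = model.tokenizer.encode(prompt)
    model.clear()
    for i in range(len(ctx)):                # warm up the recurrent state
        out = model.run(ctx[:i + 1])
    for _ in range(n_tokens):                # greedy decoding
        ctx.append(int(np.argmax(out)))
        out = model.run(ctx)
    return model.tokenizer.decode(ctx)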