model_v5.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # In[1]:
  4. ########################################################################################################
  5. # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
  6. ########################################################################################################
  7. import numpy as np
  8. np.set_printoptions(precision=4, suppress=True, linewidth=200)
  9. import types, torch
  10. import torch.nn as nn
  11. from torch.nn import functional as F
  12. MyModule = torch.jit.ScriptModule
  13. MyFunction = torch.jit.script_method
  14. # In[2]:
  15. import torch
  16. # In[3]:
  17. class RWKV_TOKENIZER():
  18. table: list[list[list[bytes]]]
  19. good: list[set[int]]
  20. wlen: list[int]
  21. def __init__(self, file_name):
  22. self.idx2token = {}
  23. sorted = [] # must be already sorted
  24. lines = open(file_name, "r", encoding="utf-8").readlines()
  25. for l in lines:
  26. idx = int(l[:l.index(' ')])
  27. x = eval(l[l.index(' '):l.rindex(' ')])
  28. x = x.encode("utf-8") if isinstance(x, str) else x
  29. assert isinstance(x, bytes)
  30. assert len(x) == int(l[l.rindex(' '):])
  31. sorted += [x]
  32. self.idx2token[idx] = x
  33. self.token2idx = {}
  34. for k, v in self.idx2token.items():
  35. self.token2idx[v] = int(k)
  36. # precompute some tables for fast matching
  37. self.table = [[[] for j in range(256)] for i in range(256)]
  38. self.good = [set() for i in range(256)]
  39. self.wlen = [0 for i in range(256)]
  40. for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
  41. s = sorted[i]
  42. if len(s) >= 2:
  43. s0 = int(s[0])
  44. s1 = int(s[1])
  45. self.table[s0][s1] += [s]
  46. self.wlen[s0] = max(self.wlen[s0], len(s))
  47. self.good[s0].add(s1)
  48. def encodeBytes(self, src: bytes) -> list[int]:
  49. src_len: int = len(src)
  50. tokens: list[int] = []
  51. i: int = 0
  52. while i < src_len:
  53. s: bytes = src[i : i + 1]
  54. if i < src_len - 1:
  55. s1: int = int(src[i + 1])
  56. s0: int = int(src[i])
  57. if s1 in self.good[s0]:
  58. sss: bytes = src[i : i + self.wlen[s0]]
  59. try:
  60. s = next(filter(sss.startswith, self.table[s0][s1]))
  61. except:
  62. pass
  63. tokens.append(self.token2idx[s])
  64. i += len(s)
  65. return tokens
  66. def decodeBytes(self, tokens):
  67. return b''.join(map(lambda i: self.idx2token[i], tokens))
  68. def encode(self, src: str):
  69. return self.encodeBytes(src.encode("utf-8"))
  70. def decode(self, tokens):
  71. return self.decodeBytes(tokens).decode('utf-8')
  72. def printTokens(self, tokens):
  73. for i in tokens:
  74. s = self.idx2token[i]
  75. try:
  76. s = s.decode('utf-8')
  77. except:
  78. pass
  79. print(f'{repr(s)}{i}', end=' ')
  80. # print(repr(s), i)
  81. print()
  82. ########################################################################################################
  83. # In[4]:
  84. def sample_logits(out, temperature=1.0, top_p=0.8):
  85. probs = F.softmax(out, dim=-1).numpy()
  86. sorted_probs = np.sort(probs)[::-1]
  87. cumulative_probs = np.cumsum(sorted_probs)
  88. cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
  89. probs[probs < cutoff] = 0
  90. if temperature != 1.0:
  91. probs = probs.pow(1.0 / temperature)
  92. probs = probs / np.sum(probs)
  93. out = np.random.choice(a=len(probs), p=probs)
  94. return out
  95. ########################################################################################################
# 可以从这个链接下载模型 (The model can be downloaded from the following links):
# https://www.modelscope.cn/models/AI-ModelScope/rwkv-5-world/files
# https://www.modelscope.cn/api/v1/models/AI-ModelScope/rwkv-5-world/repo?Revision=master&FilePath=RWKV-5-World-0.1B-v1-20230803-ctx4096.pth
  99. # In[68]:
  100. tokenizer = RWKV_TOKENIZER("./rwkv_vocab_v20230424.txt")
  101. # THIS IS NOW UPDATED TO SUPPORT LATEST RWKV-5 WORLD v2 MODELS
  102. args = types.SimpleNamespace()
  103. args.MODEL_NAME = '/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096' #这里不用有后缀.pth
  104. args.n_layer = 24
  105. args.n_embd = 1024
  106. args.vocab_size = 65536
  107. # In[69]:
  108. # N_LAYER="12"
  109. # N_EMBD="768"
  110. N_LAYER="24"
  111. N_EMBD="1024"
  112. # In[70]:
  113. # context = "\nElon Musk has"
  114. # context = "\n我们发现"
  115. context = "Q:Do you know datawhalechina?\nA:"
  116. NUM_TRIALS = 3
  117. LENGTH_PER_TRIAL = 100
  118. LENGTH_PER_TRIAL = 4096
  119. TEMPERATURE = 1.0
  120. TOP_P = 0.7
  121. # In[80]:
  122. class RWKV_RNN(MyModule):
  123. def __init__(self, args):
  124. super().__init__()
  125. self.args = args
  126. self.eval() # set torch to inference mode
  127. w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
  128. for k in w.keys():
  129. w[k] = w[k].float() # convert to f32 type
  130. if '.time_' in k: w[k] = w[k].squeeze()
  131. if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)
  132. if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
  133. self.n_head = w['blocks.0.att.time_decay'].shape[0]
  134. self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
  135. self.w = types.SimpleNamespace() # set self.w from w
  136. self.w.blocks = {}
  137. for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
  138. parts = k.split('.')
  139. last = parts.pop()
  140. here = self.w
  141. for p in parts:
  142. if p.isdigit():
  143. p = int(p)
  144. if p not in here: here[p] = types.SimpleNamespace()
  145. here = here[p]
  146. else:
  147. if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
  148. here = getattr(here, p)
  149. setattr(here, last, w[k])
  150. def layer_norm(self, x, w):
  151. return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
  152. @MyFunction
  153. def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
  154. i0 = (2+self.head_size)*i+0
  155. xk = x * time_mix_k + state[i0] * (1 - time_mix_k)
  156. xr = x * time_mix_r + state[i0] * (1 - time_mix_r)
  157. state[i0] = x
  158. r = torch.sigmoid(rw @ xr)
  159. k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
  160. return r * (vw @ k)
  161. @MyFunction
  162. def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
  163. H = self.n_head
  164. S = self.head_size
  165. i1 = (2+S)*i+1
  166. xk = x * time_mix_k + state[i1] * (1 - time_mix_k)
  167. xv = x * time_mix_v + state[i1] * (1 - time_mix_v)
  168. xr = x * time_mix_r + state[i1] * (1 - time_mix_r)
  169. xg = x * time_mix_g + state[i1] * (1 - time_mix_g)
  170. state[i1] = x
  171. r = (rw @ xr).view(H, 1, S)
  172. k = (kw @ xk).view(H, S, 1)
  173. v = (vw @ xv).view(H, 1, S)
  174. g = F.silu(gw @ xg)
  175. s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)
  176. x = torch.zeros(H, S)
  177. a = k @ v
  178. x = r @ (time_first * a + s)
  179. s = a + time_decay * s
  180. state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
  181. x = x.flatten()
  182. x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)
  183. return ow @ x
  184. def forward(self, token, state):
  185. with torch.no_grad():
  186. if state == None:
  187. state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)
  188. x = self.w.emb.weight[token]
  189. x = self.layer_norm(x, self.w.blocks[0].ln0)
  190. for i in range(self.args.n_layer):
  191. # print(i)
  192. att = self.w.blocks[i].att
  193. x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
  194. att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_mix_g, att.time_faaaa, att.time_decay,
  195. att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
  196. att.ln_x.weight, att.ln_x.bias)
  197. ffn = self.w.blocks[i].ffn
  198. x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
  199. ffn.time_mix_k, ffn.time_mix_r,
  200. ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
  201. x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
  202. return x.float(), state
  203. # In[81]:
  204. context = "Q:Do you know datawhalechina?\nA:"
  205. # In[82]:
  206. args.MODEL_NAME
  207. # In[83]:
  208. args.n_layer,args.n_embd
  209. # In[84]:
  210. # args.n_layer = 24
  211. # args.n_embd = 1024
  212. # In[85]:
  213. # args.n_layer = 12
  214. # args.n_embd = 768
  215. # In[86]:
  216. # args.MODEL_NAME='../models/rwkv-5-world-1b5'
  217. # In[87]:
  218. print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
  219. model = RWKV_RNN(args)
  220. print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
  221. init_state = None
  222. # In[88]:
  223. init_state = None
  224. # In[89]:
  225. LENGTH_PER_TRIAL=1024
  226. # In[90]:
  227. for token in tokenizer.encode(context):
  228. init_out, init_state = model.forward(token, init_state)
  229. for TRIAL in range(NUM_TRIALS):
  230. print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
  231. all_tokens = []
  232. out_last = 0
  233. out, state = init_out.clone(), init_state.clone()
  234. for i in range(LENGTH_PER_TRIAL):
  235. token = sample_logits(out, TEMPERATURE, TOP_P)
  236. all_tokens += [token]
  237. try:
  238. tmp = tokenizer.decode(all_tokens[out_last:])
  239. if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
  240. print(tmp, end="", flush=True)
  241. out_last = i + 1
  242. except:
  243. pass
  244. out, state = model.forward(token, state)
  245. print('\n')
  246. # 显然datawhale这个数据没有训练过哈哈。不过速度是蛮快的,这个没得说,在cpu上跑,资源消耗也很小。