model_v6.py

#!/usr/bin/env python
# coding: utf-8
# In[5]:
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
import torch.nn as nn
from torch.nn import functional as F
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
# In[6]:
class RWKV_TOKENIZER():
    table: list[list[list[bytes]]]
    good: list[set[int]]
    wlen: list[int]
    def __init__(self, file_name):
        self.idx2token = {}
        sorted = [] # must be already sorted
        lines = open(file_name, "r", encoding="utf-8").readlines()
        for l in lines:
            idx = int(l[:l.index(' ')])
            x = eval(l[l.index(' '):l.rindex(' ')])
            x = x.encode("utf-8") if isinstance(x, str) else x
            assert isinstance(x, bytes)
            assert len(x) == int(l[l.rindex(' '):])
            sorted += [x]
            self.idx2token[idx] = x
        self.token2idx = {}
        for k, v in self.idx2token.items():
            self.token2idx[v] = int(k)
        # precompute some tables for fast matching
        self.table = [[[] for j in range(256)] for i in range(256)]
        self.good = [set() for i in range(256)]
        self.wlen = [0 for i in range(256)]
        for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
            s = sorted[i]
            if len(s) >= 2:
                s0 = int(s[0])
                s1 = int(s[1])
                self.table[s0][s1] += [s]
                self.wlen[s0] = max(self.wlen[s0], len(s))
                self.good[s0].add(s1)
    def encodeBytes(self, src: bytes) -> list[int]:
        src_len: int = len(src)
        tokens: list[int] = []
        i: int = 0
        while i < src_len:
            s: bytes = src[i : i + 1]
            if i < src_len - 1:
                s1: int = int(src[i + 1])
                s0: int = int(src[i])
                if s1 in self.good[s0]:
                    sss: bytes = src[i : i + self.wlen[s0]]
                    try:
                        s = next(filter(sss.startswith, self.table[s0][s1]))
                    except StopIteration: # no longer token matches here; fall back to the single byte
                        pass
            tokens.append(self.token2idx[s])
            i += len(s)
        return tokens
    def decodeBytes(self, tokens):
        return b''.join(map(lambda i: self.idx2token[i], tokens))
    def encode(self, src: str):
        return self.encodeBytes(src.encode("utf-8"))
    def decode(self, tokens):
        return self.decodeBytes(tokens).decode('utf-8')
    def printTokens(self, tokens):
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except UnicodeDecodeError: # token is not valid UTF-8 on its own; print the raw bytes
                pass
            print(f'{repr(s)}{i}', end=' ')
            # print(repr(s), i)
        print()
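
# Usage sketch (added note, not in the original): with the vocab file the script
# loads further below, encode/decode round-trips arbitrary UTF-8 text, e.g.:
#   tok = RWKV_TOKENIZER("./rwkv_vocab_v20230424.txt")
#   ids = tok.encode("Hello, RWKV!")   # greedy longest-match over byte-level tokens
#   assert tok.decode(ids) == "Hello, RWKV!"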
########################################################################################################
# In[7]:
def sample_logits(out, temperature=1.0, top_p=0.8):
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        probs = probs ** (1.0 / temperature) # numpy arrays have no .pow(); use ** for the temperature scaling
    probs = probs / np.sum(probs)
    out = np.random.choice(a=len(probs), p=probs)
    return out
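
# Worked example (illustrative numbers, not from the original): if softmax yields
# probs [0.5, 0.3, 0.15, 0.05] with top_p=0.7, the cumulative sums are
# [0.5, 0.8, 0.95, 1.0]; the first entry above 0.7 is 0.8, at the token with prob 0.3,
# so cutoff=0.3, the last two tokens are zeroed, and sampling draws from
# [0.5, 0.3] renormalized to [0.625, 0.375].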
########################################################################################################
# Model download: https://hf-mirror.com/BlinkDL/rwkv-6-world/resolve/main/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth
# In[13]:
tokenizer = RWKV_TOKENIZER("./rwkv_vocab_v20230424.txt")
args = types.SimpleNamespace()
args.MODEL_NAME = '/data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096'
args.n_layer = 24
args.n_embd = 2048
args.vocab_size = 65536
context = "\nDatawhale is "
# context = "\n我们发现"
NUM_TRIALS = 3
LENGTH_PER_TRIAL = 100
TEMPERATURE = 1.0
TOP_P = 0.7
# In[14]:
class RWKV_RNN(MyModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval() # set torch to inference mode
        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            w[k] = w[k].float() # convert to f32 type
            if '.time_' in k: w[k] = w[k].squeeze()
            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
        self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
        self.w = types.SimpleNamespace() # set self.w from w
        self.w.blocks = {}
        for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])
    def layer_norm(self, x, w):
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
    @MyFunction
    def channel_mixing(self, x, state, i:int, time_maa_k, time_maa_r, kw, vw, rw):
        i0 = (2+self.head_size)*i+0
        sx = state[i0] - x
        xk = x + sx * time_maa_k
        xr = x + sx * time_maa_r
        state[i0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # squared relu, Primer paper
        return r * (vw @ k)
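
    # Added note: state[i0] caches the previous token's input to this layer (the
    # "token shift"), so xk/xr interpolate between the current and previous inputs;
    # the block itself is a gated FFN: sigmoid(R @ xr) * (V @ relu(K @ xk)^2).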
    @MyFunction
    def time_mixing(self, x, state, i:int, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
        H = self.n_head
        S = self.head_size
        i1 = (2+S)*i+1
        sx = state[i1] - x
        state[i1] = x
        xxx = x + sx * x_maa
        xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)
        xxx = torch.bmm(xxx, tm_w2).view(5, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)
        xw = x + sx * (w_maa + mw)
        xk = x + sx * (k_maa + mk)
        xv = x + sx * (v_maa + mv)
        xr = x + sx * (r_maa + mr)
        xg = x + sx * (g_maa + mg)
        w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, S, 1)
        w = torch.exp(-torch.exp(w.float()))
        r = (rw @ xr).view(H, 1, S)
        k = (kw @ xk).view(H, S, 1)
        v = (vw @ xv).view(H, 1, S)
        g = F.silu(gw @ xg)
        s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)
        a = k @ v # outer product k v^T per head
        x = r @ (time_first * a + s)
        s = a + w * s
        state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
        x = x.flatten()
        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps=64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)
        return ow @ x
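
    # Added note: per head, the code above is the loop-free recurrence
    #   s' = k v^T + diag(w) s,   out = r @ (time_first * k v^T + s)
    # i.e. RWKV-6 linear attention with a data-dependent per-channel decay w in (0, 1).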
    def forward(self, token, state):
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)
            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                    att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,
                    att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,
                    att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
                    att.ln_x.weight, att.ln_x.bias)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                    ffn.time_maa_k, ffn.time_maa_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state
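
# Added note on the state layout: the state is n_layer * (2 + head_size) rows of width
# n_embd. Within each layer's slab, row 0 is the channel-mixing token shift, row 1 the
# time-mixing token shift, and the remaining head_size rows hold that layer's
# (n_head, head_size, head_size) WKV state flattened to (head_size, n_embd).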
# In[15]:
print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)
print('\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
init_state = None
# In[16]:
for token in tokenizer.encode(context):
    init_out, init_state = model.forward(token, init_state)
# In[17]:
for TRIAL in range(NUM_TRIALS):
    print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
    all_tokens = []
    out_last = 0
    out, state = init_out.clone(), init_state.clone()
    for i in range(LENGTH_PER_TRIAL):
        token = sample_logits(out, TEMPERATURE, TOP_P)
        all_tokens += [token]
        try:
            tmp = tokenizer.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
                print(tmp, end="", flush=True)
                out_last = i + 1
        except UnicodeDecodeError: # partial multi-byte character; wait for more tokens before printing
            pass
        out, state = model.forward(token, state)
    print('\n')
# Compared with v5, v6 seems to like using emoji more, haha