#!/usr/bin/env python
# coding: utf-8

# Model download link:
# https://hf-mirror.com/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth?download=true

# In[1]:

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import types

import numpy as np
import torch
from torch.nn import functional as F

np.set_printoptions(precision=4, suppress=True, linewidth=200)

# In[2]:

# Model / generation configuration. MODEL_NAME is the checkpoint path without
# the ".pth" suffix (the suffix is appended when the weights are loaded).
args = types.SimpleNamespace()
args.MODEL_NAME = '/data1/ckw/RWKV-4-Pile-430M-20220808-8066'
args.n_layer = 24
args.n_embd = 1024

context = "\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence."
NUM_TRIALS = 3
LENGTH_PER_TRIAL = 100
TEMPERATURE = 1.0
TOP_P = 0.85

########################################################################################################

# In[3]:


class RWKV_RNN(torch.jit.ScriptModule):
    """RWKV language model evaluated one token at a time (RNN mode).

    The recurrent state is a ``(n_layer * 5, n_embd)`` tensor. For layer ``i``
    the five rows are used as follows:

    * ``5*i+0`` -- previous input of the channel-mixing (FFN) sub-block
    * ``5*i+1`` -- previous input of the time-mixing (attention) sub-block
    * ``5*i+2`` -- ``aa``, running numerator of the wkv weighted average
    * ``5*i+3`` -- ``bb``, running denominator of the wkv weighted average
    * ``5*i+4`` -- ``pp``, exponent offset that ``aa``/``bb`` are stored
      relative to (initialised to ``-1e30``)
    """

    def __init__(self, args):
        """Load ``args.MODEL_NAME + '.pth'`` on CPU and unpack it into ``self.w``.

        Args:
            args: namespace providing ``MODEL_NAME``, ``n_layer`` and ``n_embd``.
        """
        super().__init__()
        self.args = args
        self.eval()  # set torch to inference mode

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w:
            if '.time_' in k:
                w[k] = w[k].squeeze()
            if '.time_decay' in k:
                # the real time decay is like e^{-e^x}
                w[k] = -torch.exp(w[k].float())
            else:
                w[k] = w[k].float()  # convert to f32 type

        # Rebuild the flat "a.b.c" keys as nested namespaces under self.w;
        # numeric path components index the self.w.blocks dict.
        # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
        self.w = types.SimpleNamespace()
        self.w.blocks = {}
        for k in w:
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here:
                        here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p):
                        setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

    def layer_norm(self, x, w):
        """Apply layer normalisation over ``n_embd`` using ``w.weight`` / ``w.bias``."""
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

    # Channel mixing (feed-forward sub-block) of layer ``i``.
    # Blends the current input ``x`` with the previous input kept in
    # state[5*i+0] (separately for key and receptance), stores ``x`` back into
    # that state row, and returns sigmoid(rw @ xr) * (vw @ relu(kw @ xk)^2).
    # ``state`` is modified in place.
    @torch.jit.script_method
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
        state[5*i+0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

    # Time mixing (attention-like sub-block) of layer ``i``.
    # Blends ``x`` with the previous input kept in state[5*i+1] to form the
    # key / value / receptance inputs, then:
    #   1. computes wkv = a / b, where a and b combine the stored accumulators
    #      (aa, bb, scaled by exp(pp - qq)) with the current token (weight
    #      exp(time_first + k - qq)); qq is the max of both exponents so that
    #      neither exp() overflows;
    #   2. writes back the accumulators after applying ``time_decay`` to the
    #      stored exponent and adding the current token with weight exp(k).
    # Returns ow @ (sigmoid(rw @ xr) * wkv). ``state`` is modified in place.
    @torch.jit.script_method
    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
        xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
        state[5*i+1] = x
        r = torch.sigmoid(rw @ xr)
        k = kw @ xk
        v = vw @ xv

        aa = state[5*i+2]
        bb = state[5*i+3]
        pp = state[5*i+4]
        ww = time_first + k
        qq = torch.maximum(pp, ww)
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        wkv = a / b
        ww = pp + time_decay
        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)
        state[5*i+2] = e1 * aa + e2 * v
        state[5*i+3] = e1 * bb + e2
        state[5*i+4] = qq
        return ow @ (r * wkv)

    def forward(self, token, state):
        """Feed one token through the model.

        Args:
            token: index into the embedding table ``self.w.emb.weight``.
            state: recurrent state from the previous call, or ``None`` to
                start from a fresh state. A passed-in state is updated in
                place.

        Returns:
            Tuple ``(logits, state)``: the float32 output of the head
            projection and the (updated) state tensor.
        """
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)
                for i in range(self.args.n_layer):
                    state[5*i+4] = -1e30  # -infinity

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(
                    self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                    att.time_mix_k, att.time_mix_v, att.time_mix_r,
                    att.time_first, att.time_decay,
                    att.key.weight, att.value.weight,
                    att.receptance.weight, att.output.weight)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(
                    self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                    ffn.time_mix_k, ffn.time_mix_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state

##########################################################################################################

# In[4]:


def sample_logits(out, temperature=1.0, top_p=0.8):
    """Sample a token index from logits with top-p filtering and temperature.

    The logits are turned into probabilities with softmax. Probabilities are
    sorted in descending order and the cutoff is the probability at the first
    position where the cumulative sum exceeds ``top_p``; every probability
    below the cutoff is zeroed. If no cumulative value exceeds ``top_p`` the
    cutoff falls back to the largest probability. When ``temperature`` is not
    1.0 the surviving probabilities are raised to ``1 / temperature``. The
    result is renormalised and sampled with ``np.random.choice``.

    Args:
        out: 1-D torch tensor of logits.
        temperature: values below 1 sharpen, values above 1 flatten the
            filtered distribution.
        top_p: cumulative-probability threshold used to pick the cutoff.

    Returns:
        The sampled index, as returned by ``np.random.choice``.
    """
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        # probs is a NumPy array here, so use the ** operator: ndarray has no
        # .pow() method (that is a torch.Tensor API).
        probs = probs ** (1.0 / temperature)
    probs = probs / np.sum(probs)
    out = np.random.choice(a=len(probs), p=probs)
    return out

########################################################################################################

# Everything below loads files and runs generation, so it only executes when
# the file is run as a script; importing the module just provides the
# definitions above. The names stay module-level globals as before.
if __name__ == "__main__":
    # Imported here so the definitions above can be imported without the
    # tokenizer dependency or the tokenizer/model files being present.
    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_file("20B_tokenizer.json")

    # In[6]:

    print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
    model = RWKV_RNN(args)

    # In[7]:

    print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
    init_state = None
    for token in tokenizer.encode(context).ids:
        init_out, init_state = model.forward(token, init_state)

    # In[8]:

    for TRIAL in range(NUM_TRIALS):
        print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
        all_tokens = []
        out_last = 0
        # Clone so every trial restarts from the state right after the context
        # (forward() updates the state it is given in place).
        out, state = init_out.clone(), init_state.clone()
        for i in range(LENGTH_PER_TRIAL):
            token = sample_logits(out, TEMPERATURE, TOP_P)
            all_tokens += [token]
            tmp = tokenizer.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:  # only print when we have a valid utf-8 string
                print(tmp, end="", flush=True)
                out_last = i + 1
            out, state = model.forward(token, state)
        print('\n')

    # In[ ]: