#!/usr/bin/env python
# coding: utf-8
# Model download link: https://hf-mirror.com/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth?download=true

# In[1]:
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
from torch.nn import functional as F
from tokenizers import Tokenizer

# In[2]:
tokenizer = Tokenizer.from_file("20B_tokenizer.json")

args = types.SimpleNamespace()
args.MODEL_NAME = '/data1/ckw/RWKV-4-Pile-430M-20220808-8066'
args.n_layer = 24
args.n_embd = 1024

context = "\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence."
NUM_TRIALS = 3          # number of independent generations from the same prompt
LENGTH_PER_TRIAL = 100  # tokens to generate per trial
TEMPERATURE = 1.0       # sampling temperature
TOP_P = 0.85            # nucleus (top-p) cutoff used by sample_logits

########################################################################################################

# In[3]:
class RWKV_RNN(torch.jit.ScriptModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval() # set torch to inference mode

        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            if '.time_' in k: w[k] = w[k].squeeze()
            if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}
            else: w[k] = w[k].float() # convert to f32 type

        self.w = types.SimpleNamespace() # set self.w from w
        self.w.blocks = {}
        for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])
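
        # After this loop, a checkpoint key such as "blocks.0.att.time_first" is
        # reachable as self.w.blocks[0].att.time_first: numeric path components
        # become dict keys, everything else becomes a SimpleNamespace attribute.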

    def layer_norm(self, x, w):
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

    @torch.jit.script_method
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
        state[5*i+0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)
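
    # Channel mixing is a gated two-layer MLP over the current token mixed with
    # the previous one: out = sigmoid(R @ xr) * (V @ relu(K @ xk)^2). It uses one
    # state slot per layer (state[5*i+0]) to remember the previous token's input.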

    @torch.jit.script_method
    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
        xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
        state[5*i+1] = x
        r = torch.sigmoid(rw @ xr)
        k = kw @ xk
        v = vw @ xv

        aa = state[5*i+2] # exp-weighted sum of past values (WKV numerator)
        bb = state[5*i+3] # exp-weighted sum of past weights (WKV denominator)
        pp = state[5*i+4] # running max exponent, for numerical stability
        ww = time_first + k
        qq = torch.maximum(pp, ww)
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        wkv = a / b
        ww = pp + time_decay
        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)
        state[5*i+2] = e1 * aa + e2 * v
        state[5*i+3] = e1 * bb + e2
        state[5*i+4] = qq
        return ow @ (r * wkv)
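
    # Ignoring the max-exponent trick, the recurrence above computes
    #   wkv_t = (a_{t-1} + exp(u + k_t) * v_t) / (b_{t-1} + exp(u + k_t))
    # with a_t = exp(w) * a_{t-1} + exp(k_t) * v_t and b_t = exp(w) * b_{t-1} + exp(k_t),
    # where u = time_first and w = time_decay (negative, so exp(w) < 1). Tracking
    # the running max pp and subtracting it before every exp() keeps the
    # exponentials from overflowing.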

    def forward(self, token, state):
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)
                for i in range(self.args.n_layer): state[5*i+4] = -1e30 # -infinity

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                    att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay,
                    att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                    ffn.time_mix_k, ffn.time_mix_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state
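
# State layout: one row of width n_embd per entry; for each layer i,
#   state[5*i+0] previous x seen by channel mixing
#   state[5*i+1] previous x seen by time mixing
#   state[5*i+2] WKV numerator accumulator (aa)
#   state[5*i+3] WKV denominator accumulator (bb)
#   state[5*i+4] running max exponent (pp), initialized to -1e30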

##########################################################################################################

# In[4]:
def sample_logits(out, temperature=1.0, top_p=0.8):
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0 # zero out everything outside the nucleus
    if temperature != 1.0:
        probs = np.power(probs, 1.0 / temperature) # numpy arrays have no .pow(); use np.power
    probs = probs / np.sum(probs)
    out = np.random.choice(a=len(probs), p=probs)
    return out
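
# A quick sanity check for sample_logits (hypothetical toy logits, not part of
# the original script): with these four logits and top_p=0.9, only the two most
# likely tokens survive the nucleus cutoff, so every draw is token 0 or 1.
_toy_logits = torch.tensor([4.0, 3.0, 0.1, 0.05])
assert all(sample_logits(_toy_logits, TEMPERATURE, top_p=0.9) in (0, 1) for _ in range(20))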

########################################################################################################

# In[6]:
print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)

# In[7]:
print('\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
init_state = None
for token in tokenizer.encode(context).ids:
    init_out, init_state = model.forward(token, init_state)
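
# init_out / init_state now hold the logits and recurrent state after the full
# prompt; each trial below clones them, so the prompt is only processed once.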
- # In[8]:
- for TRIAL in range(NUM_TRIALS):
- print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
- all_tokens = []
- out_last = 0
- out, state = init_out.clone(), init_state.clone()
- for i in range(LENGTH_PER_TRIAL):
- token = sample_logits(out, TEMPERATURE, TOP_P)
- all_tokens += [token]
- tmp = tokenizer.decode(all_tokens[out_last:])
- if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
- print(tmp, end="", flush=True)
- out_last = i + 1
- out, state = model.forward(token, state)
- print('\n')