model_v4.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # Model download link: https://hf-mirror.com/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth?download=true
  4. # In[1]:
  5. ########################################################################################################
  6. # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
  7. ########################################################################################################
  8. import numpy as np
  9. np.set_printoptions(precision=4, suppress=True, linewidth=200)
  10. import types, torch
  11. from torch.nn import functional as F
  12. from tokenizers import Tokenizer
  13. # In[2]:
  14. tokenizer = Tokenizer.from_file("20B_tokenizer.json")
  15. args = types.SimpleNamespace()
  16. args.MODEL_NAME = '/data1/ckw/RWKV-4-Pile-430M-20220808-8066'
  17. args.n_layer = 24
  18. args.n_embd = 1024
  19. context = "\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence."
  20. NUM_TRIALS = 3
  21. LENGTH_PER_TRIAL = 100
  22. TEMPERATURE = 1.0
  23. TOP_P = 0.85
  24. ########################################################################################################
  25. # In[3]:
  26. class RWKV_RNN(torch.jit.ScriptModule):
  27. def __init__(self, args):
  28. super().__init__()
  29. self.args = args
  30. self.eval() # set torch to inference mode
  31. w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
  32. for k in w.keys():
  33. if '.time_' in k: w[k] = w[k].squeeze()
  34. if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}
  35. else: w[k] = w[k].float() # convert to f32 type
  36. self.w = types.SimpleNamespace() # set self.w from w
  37. self.w.blocks = {}
  38. for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
  39. parts = k.split('.')
  40. last = parts.pop()
  41. here = self.w
  42. for p in parts:
  43. if p.isdigit():
  44. p = int(p)
  45. if p not in here: here[p] = types.SimpleNamespace()
  46. here = here[p]
  47. else:
  48. if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
  49. here = getattr(here, p)
  50. setattr(here, last, w[k])
  51. def layer_norm(self, x, w):
  52. return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
  53. @torch.jit.script_method
  54. def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
  55. xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
  56. xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
  57. state[5*i+0] = x
  58. r = torch.sigmoid(rw @ xr)
  59. k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
  60. return r * (vw @ k)
  61. @torch.jit.script_method
  62. def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
  63. xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
  64. xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
  65. xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
  66. state[5*i+1] = x
  67. r = torch.sigmoid(rw @ xr)
  68. k = kw @ xk
  69. v = vw @ xv
  70. aa = state[5*i+2]
  71. bb = state[5*i+3]
  72. pp = state[5*i+4]
  73. ww = time_first + k
  74. qq = torch.maximum(pp, ww)
  75. e1 = torch.exp(pp - qq)
  76. e2 = torch.exp(ww - qq)
  77. a = e1 * aa + e2 * v
  78. b = e1 * bb + e2
  79. wkv = a / b
  80. ww = pp + time_decay
  81. qq = torch.maximum(ww, k)
  82. e1 = torch.exp(ww - qq)
  83. e2 = torch.exp(k - qq)
  84. state[5*i+2] = e1 * aa + e2 * v
  85. state[5*i+3] = e1 * bb + e2
  86. state[5*i+4] = qq
  87. return ow @ (r * wkv)
  88. def forward(self, token, state):
  89. with torch.no_grad():
  90. if state == None:
  91. state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)
  92. for i in range(self.args.n_layer): state[5*i+4] = -1e30 # -infinity
  93. x = self.w.emb.weight[token]
  94. x = self.layer_norm(x, self.w.blocks[0].ln0)
  95. for i in range(self.args.n_layer):
  96. att = self.w.blocks[i].att
  97. x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
  98. att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay,
  99. att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)
  100. ffn = self.w.blocks[i].ffn
  101. x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
  102. ffn.time_mix_k, ffn.time_mix_r,
  103. ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
  104. x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
  105. return x.float(), state
  106. ##########################################################################################################
  107. # In[4]:
  108. def sample_logits(out, temperature=1.0, top_p=0.8):
  109. probs = F.softmax(out, dim=-1).numpy()
  110. sorted_probs = np.sort(probs)[::-1]
  111. cumulative_probs = np.cumsum(sorted_probs)
  112. cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
  113. probs[probs < cutoff] = 0
  114. if temperature != 1.0:
  115. probs = probs.pow(1.0 / temperature)
  116. probs = probs / np.sum(probs)
  117. out = np.random.choice(a=len(probs), p=probs)
  118. return out
  119. ########################################################################################################
  120. # In[6]:
  121. print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
  122. model = RWKV_RNN(args)
  123. # In[7]:
  124. print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
  125. init_state = None
  126. for token in tokenizer.encode(context).ids:
  127. init_out, init_state = model.forward(token, init_state)
  128. # In[8]:
  129. for TRIAL in range(NUM_TRIALS):
  130. print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
  131. all_tokens = []
  132. out_last = 0
  133. out, state = init_out.clone(), init_state.clone()
  134. for i in range(LENGTH_PER_TRIAL):
  135. token = sample_logits(out, TEMPERATURE, TOP_P)
  136. all_tokens += [token]
  137. tmp = tokenizer.decode(all_tokens[out_last:])
  138. if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
  139. print(tmp, end="", flush=True)
  140. out_last = i + 1
  141. out, state = model.forward(token, state)
  142. print('\n')
  143. # In[ ]: