# model_run.py

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import types
import copy
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

RWKV_K_CLAMP = 60        # clamp k before exp(k) to avoid overflow
RWKV_K_EPS = 1e-8        # added to the wk denominator to avoid division by zero
RWKV_HEAD_QK_DIM = 256   # extra head-QK "copy" mechanism; set to 0 to disable it
print(f'\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')

DEBUG_TIME = False  # True False - show trained time-coeffs

############################################################################################################

# module-level config, populated by RWKV_GPT.__init__ and read by the blocks below
RWKV_CFG = types.SimpleNamespace()
class RWKV_ChannelMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        # shift the sequence right by one token (the first position is padded with zero)
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))

        hidden_sz = 4 * RWKV_CFG.n_embd
        self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        xx = self.time_shift(x)  # x at the previous time step
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        k = torch.square(torch.relu(k))                 # squared-ReLU activation
        kv = self.value(k)
        rkv = torch.sigmoid(self.receptance(xr)) * kv   # gate the output by receptance
        return rkv
class RWKV_TimeMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd, 1))
        # relative positions -(ctx_len-2) ... 0 for the decay curve
        self.time_curve = torch.tensor([-(RWKV_CFG.ctx_len - 2 - i) for i in range(RWKV_CFG.ctx_len - 1)]).unsqueeze(0)
        self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd, 1) * math.log(0.3))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))

        self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        B, T, C = x.size()

        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk).transpose(-1, -2)     # (B, C, T)
        v = self.value(xv).transpose(-1, -2)   # (B, C, T)
        r = self.receptance(xr)

        k = torch.clamp(k, max=RWKV_K_CLAMP)   # keep exp(k) from overflowing
        k = torch.exp(k)
        kv = k * v

        # per-channel weights: exp(-(t-1-s) * exp(time_decay)) for past tokens,
        # exp(time_first) for the current token
        self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)
        w = torch.exp(self.time_w)
        w = w[:, -T:].unsqueeze(1)             # (C, 1, T) depthwise conv kernels

        # causal weighted sums over the past, computed in parallel as a grouped conv1d
        wkv = F.conv1d(nn.ZeroPad2d((T - 1, 0, 0, 0))(kv), w, groups=C)
        wk = F.conv1d(nn.ZeroPad2d((T - 1, 0, 0, 0))(k), w, groups=C) + RWKV_K_EPS

        rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
        rwkv = self.output(rwkv)
        return rwkv
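# Hedged sanity-check sketch (added here for illustration; not part of the upstream file).
# The grouped conv1d in RWKV_TimeMix.forward is the parallel form of a per-channel recurrence:
#   wkv[t] = exp(time_first) * k[t]*v[t] + sum_{s<t} exp(-(t-1-s) * exp(time_decay)) * k[s]*v[s]
# which is exactly what RWKV_RNN.SA below computes step by step with its aa/bb accumulators.
# The helper name and the toy sizes are assumptions chosen to mirror forward() above.
def _wkv_equivalence_check(T=8, C=4):
    torch.manual_seed(0)
    time_decay = torch.randn(C, 1)
    time_first = torch.randn(C, 1)
    k = torch.exp(torch.clamp(torch.randn(1, C, T), max=RWKV_K_CLAMP))
    v = torch.randn(1, C, T)
    kv = k * v
    # parallel (conv) form, as in RWKV_TimeMix.forward
    time_curve = torch.tensor([-(T - 2 - i) for i in range(T - 1)], dtype=torch.float32).unsqueeze(0)
    w = torch.exp(torch.cat([torch.exp(time_decay) * time_curve, time_first], dim=-1))
    w = w[:, -T:].unsqueeze(1)  # (C, 1, T)
    wkv = F.conv1d(nn.ZeroPad2d((T - 1, 0, 0, 0))(kv), w, groups=C)
    # sequential (RNN) form, as in RWKV_RNN.SA (time_decay/time_first pre-transformed
    # the same way RWKV_RNN.__init__ transforms them when loading the checkpoint)
    decay = torch.exp(-torch.exp(time_decay.squeeze(-1)))
    first = torch.exp(time_first.squeeze(-1))
    aa = torch.zeros(C)
    out = torch.zeros(C, T)
    for t in range(T):
        out[:, t] = aa + first * kv[0, :, t]
        aa = decay * aa + kv[0, :, t]
    print('max |conv - rnn| =', (wkv[0] - out).abs().max().item())  # should be ~0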
class Block(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd)
        self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd)
        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd)  # extra LayerNorm right after the embedding

        if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            # 'RWKV-ffnPre' replaces layer 0's TimeMix with an extra ChannelMix
            self.ffnPre = RWKV_ChannelMix(layer_id + 1000)
        else:
            self.att = RWKV_TimeMix(layer_id)
        self.ffn = RWKV_ChannelMix(layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(self.ln1(x))
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
class RWKV_GPT(nn.Module):
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
        global RWKV_CFG
        super().__init__()

        # the blocks read their sizes from the module-level config
        RWKV_CFG.RUN_DEVICE = RUN_DEVICE
        RWKV_CFG.model_type = model_type
        RWKV_CFG.vocab_size = vocab_size
        RWKV_CFG.n_layer = n_layer
        RWKV_CFG.n_embd = n_embd
        RWKV_CFG.ctx_len = ctx_len

        print('\nloading RWKV-GPT', MODEL_NAME)

        self.emb = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])
        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            self.register_buffer("copy_mask", torch.tril(torch.ones(ctx_len, ctx_len)))

        self.ctx_len = ctx_len
        self.eval()
        self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
        self.eval()

    def forward(self, idx):
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward: input length exceeds model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            # head-QK: let the model copy tokens from its own context
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
            c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).float()
            x = self.head(x) + c
        else:
            x = self.head(x)

        return x
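# Hedged usage sketch (illustration only; not part of the upstream file). Loads the GPT-mode
# model to score a whole context in parallel. The checkpoint name and all sizes here are
# assumptions and must match whatever the checkpoint was trained with.
#
# model = RWKV_GPT(MODEL_NAME='trained-1', RUN_DEVICE='cpu', model_type='RWKV',
#                  vocab_size=6064, n_layer=6, n_embd=512, ctx_len=1024)
# idx = torch.randint(0, 6064, (1, 16))   # (B, T) token ids
# logits = model(idx)                     # (B, T, vocab_size)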
############################################################################################################
class RWKV_RNN:
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
        self.RUN_DEVICE = RUN_DEVICE
        self.model_type = model_type
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.ctx_len = ctx_len

        self.w = types.SimpleNamespace()
        w = torch.load(MODEL_NAME + '.pth',
                       map_location=torch.device(RUN_DEVICE))
        for x in w.keys():
            if '.time_' in x:
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                # pre-compute the per-step decay factor used by the recurrence
                w[x] = torch.exp(-torch.exp(w[x]))
            if '.time_first' in x:
                w[x] = torch.exp(w[x])
            if DEBUG_TIME and '.time_' in x:
                print(x, w[x].squeeze().cpu().numpy())

            # turn flat checkpoint keys like 'blocks.0.att.key.weight' into nested attributes
            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i + 1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()
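    # Hedged note (added for illustration; not part of the upstream file): after the loop above,
    # checkpoint tensors are reachable by attribute/index access, e.g.
    #   self.w.emb.weight                 # token embedding matrix
    #   self.w.blocks[0].att.key.weight   # layer-0 TimeMix key projection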
    def clear(self):
        # per-layer recurrent state: last input (xx) and the WKV accumulators (aa, bb)
        self.xx = {}
        self.aa = {}
        self.bb = {}
        self.hk = None  # history of head_k outputs for the head-QK mechanism

    def save(self, target):
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)
        target.hk = copy.deepcopy(self.hk)

    def load(self, target):
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)
        self.hk = copy.deepcopy(target.hk)

    def LN(self, xx, w):
        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        # RNN form of RWKV_ChannelMix (one token at a time)
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = torch.square(torch.relu(w.key.weight @ xk))
        kv = w.value.weight @ k
        return r * kv
    def SA(self, xx, w, name):
        # RNN form of RWKV_TimeMix: aa/bb carry the decayed running sums of k*v and k
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = torch.exp(torch.clamp(w.key.weight @ xk, max=RWKV_K_CLAMP))
        v = w.value.weight @ xv
        kv = k * v

        # the current token enters with weight time_first; history then decays by time_decay
        a = self.aa[name] + w.time_first * kv
        b = self.bb[name] + w.time_first * k
        self.aa[name] = w.time_decay * self.aa[name] + kv
        self.bb[name] = w.time_decay * self.bb[name] + k

        rwkv = r * a / (b + RWKV_K_EPS)
        return w.output.weight @ rwkv
    def run(self, ctx):
        w = self.w
        x = w.emb.weight[ctx[-1]]  # only the newest token is consumed per call

        for i in range(self.n_layer):
            if i == 0:
                x = self.LN(x, w.blocks[i].ln0)
            if i == 0 and self.model_type == 'RWKV-ffnPre':
                x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
            else:
                x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
            x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)

        if RWKV_HEAD_QK_DIM > 0:
            # accumulate head_k outputs so the copy mechanism can look back over the context
            if self.hk is None:
                self.hk = (w.head_k.weight @ x).unsqueeze(0)
            else:
                self.hk = torch.cat(
                    [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
            if self.hk.shape[0] > self.ctx_len:
                self.hk = self.hk[-self.ctx_len:, :]

            q = w.head_q.weight @ x
            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

            c = (self.hk @ q) / RWKV_HEAD_QK_DIM
            for i in range(len(c)):
                x[ctx[i]] += c[i]
        else:
            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

        return x
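# Hedged usage sketch (illustration only; not part of the upstream file). Greedy token-by-token
# generation with the RNN-mode model: run() consumes only ctx[-1] and keeps its own hidden state,
# so tokens are fed one at a time. Checkpoint name, sizes, and the seed token are assumptions.
#
# rnn = RWKV_RNN(MODEL_NAME='trained-1', RUN_DEVICE='cpu', model_type='RWKV',
#                n_layer=6, n_embd=512, ctx_len=1024)
# ctx = [0]                                # seed token id (assumption)
# for _ in range(16):
#     out = rnn.run(ctx)                   # vocab_size logits for the next token
#     ctx.append(max(range(len(out)), key=lambda j: float(out[j])))
# print(ctx)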