model.py

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import math
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F

logger = logging.getLogger(__name__)

########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################

def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module
    for m in module.modules():
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            continue
        with torch.no_grad():
            name = '[unknown weight]'
            for name, parameter in module.named_parameters(): # find the name of the weight
                if id(m.weight) == id(parameter):
                    break

            shape = m.weight.data.shape
            gain = 1.0  # positive: gain for orthogonal, negative: std for normal
            scale = 1.0 # extra scale for gain

            if isinstance(m, nn.Linear):
                if m.bias is not None:
                    m.bias.data.zero_()
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
                    scale = config.rwkv_emb_scale

            if isinstance(m, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
                    scale = config.rwkv_emb_scale

            if hasattr(m, 'scale_init'):
                scale = m.scale_init

            print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)

            gain *= scale
            if gain == 0:
                nn.init.zeros_(m.weight) # zero init is great for some RWKV matrices
            elif gain > 0:
                nn.init.orthogonal_(m.weight, gain=gain)
            else:
                nn.init.normal_(m.weight, mean=0, std=-gain)
class RWKV_TimeMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        assert config.n_attn % config.n_head == 0
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_head = config.n_head
        self.head_size = config.n_attn // config.n_head

        with torch.no_grad(): # initial time_w curves for better convergence
            ww = torch.ones(config.n_head, config.ctx_len)
            curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance
            for h in range(config.n_head):
                if h < config.n_head - 1:
                    decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
                else:
                    decay_speed = 0.0
                ww[h] = torch.exp(curve * decay_speed)
                # print('layer', layer_id, 'head', h, 'decay_speed', round(decay_speed, 4), ww[h][:5].numpy(), '...', ww[h][-5:].numpy())
        self.time_w = nn.Parameter(ww)

        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
        self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))

        self.key = nn.Linear(config.n_embd, config.n_attn)
        self.value = nn.Linear(config.n_embd, config.n_attn)
        self.receptance = nn.Linear(config.n_embd, config.n_attn)

        # if config.rwkv_tiny_attn > 0:
        #     self.tiny_att = RWKV_TinyAttn(config)

        self.output = nn.Linear(config.n_attn, config.n_embd)

        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    def forward(self, x):
        B, T, C = x.size()
        TT = self.ctx_len
        w = F.pad(self.time_w, (0, TT))
        w = torch.tile(w, [TT])
        w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
        w = w[:, :, TT-1:] # w is now a circulant matrix
        w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]

        x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) # time-shift mixing: the first half of the channels comes from the previous token
        # if hasattr(self, 'tiny_att'):
        #     tiny_att = self.tiny_att(x, self.mask)

        k = self.key(x)
        v = self.value(x)
        r = self.receptance(x)

        k = torch.clamp(k, max=30, min=-60) # clamp extreme values. e^30 = 10^13
        k = torch.exp(k)
        sum_k = torch.cumsum(k, dim=1)

        kv = (k * v).view(B, T, self.n_head, self.head_size)

        wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, -1)

        rwkv = torch.sigmoid(r) * wkv / sum_k # receptance gate on the time-weighted sum of k*v, normalized by the running sum of k

        rwkv = self.output(rwkv)
        # if hasattr(self, 'tiny_att'):
        #     rwkv += tiny_att

        return rwkv * self.time_gamma[:T, :]
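# Illustrative sketch (not part of the original model): the pad/tile/reshape trick in
# RWKV_TimeMix.forward turns the per-head decay row time_w into a causal T x T weight matrix whose
# entry (t, u) equals time_w[TT-1-(t-u)] for u <= t and 0 for u > t, i.e. the weight depends only on
# the distance between the current position t and the earlier position u. The helper below reproduces
# the construction for a single head so the intermediate shapes can be inspected; the function name
# and the tiny ctx_len are assumptions made for this demo only.
def _demo_time_w_matrix(ctx_len=4, decay_speed=1.0):
    curve = torch.tensor([-(ctx_len - 1 - i) for i in range(ctx_len)]).float()
    time_w = torch.exp(curve * decay_speed).unsqueeze(0)   # (1, TT), like one head of self.time_w
    TT = ctx_len
    w = F.pad(time_w, (0, TT))                             # (1, 2*TT): decay row followed by TT zeros
    w = torch.tile(w, [TT])                                # (1, 2*TT*TT): the padded row repeated TT times
    w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)             # (1, TT, 2*TT-1): each row shifts by one step
    w = w[:, :, TT - 1:]                                   # (1, TT, TT): entry (t, u) = exp(-(t-u)*decay) for u <= t, else 0
    return w[0]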
class RWKV_ChannelMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))

        hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of receptance gating
        self.key = nn.Linear(config.n_embd, hidden_sz)
        self.value = nn.Linear(config.n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, config.n_embd)
        self.receptance = nn.Linear(config.n_embd, config.n_embd)

        self.receptance.scale_init = 0
        self.weight.scale_init = 0

    def forward(self, x):
        B, T, C = x.size()

        x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
        k = self.key(x)
        v = self.value(x)
        r = self.receptance(x)

        wkv = self.weight(F.mish(k) * v) # I find Mish is a bit better than GELU here
        rwkv = torch.sigmoid(r) * wkv

        return rwkv
class RWKV_TinyAttn(nn.Module): # extra tiny attention
    def __init__(self, config):
        super().__init__()
        self.d_attn = config.rwkv_tiny_attn
        self.n_head = config.rwkv_tiny_head
        self.head_size = self.d_attn // self.n_head

        self.qkv = nn.Linear(config.n_embd, self.d_attn * 3)
        self.out = nn.Linear(self.d_attn, config.n_embd)

    def forward(self, x, mask):
        B, T, C = x.size()
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim = -1)

        if self.n_head > 1:
            q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
            k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
            v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)

        qk = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size)) # (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        qk = qk.masked_fill(mask == 0, float('-inf'))
        qk = F.softmax(qk, dim = -1)
        qkv = qk @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)

        if self.n_head > 1:
            qkv = qkv.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        return self.out(qkv)
########################################################################################################
# MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
########################################################################################################

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_len=None):
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) # match inv_freq dtype so einsum sees consistent types
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()
        return self.cos_cached, self.sin_cached

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), -1)

@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    cos, sin = cos[...,:q.shape[-2],:], sin[...,:q.shape[-2],:]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
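# Illustrative sketch (not part of the original model): the rotary helpers above expect q and k shaped
# (B, n_head, T, rotary_ndims) and cos/sin shaped (T, rotary_ndims), which broadcast over the batch and
# head dimensions. The function name and the shapes below are assumptions made for this demo only.
def _demo_rotary(B=2, n_head=4, T=8, rotary_ndims=16):
    rotary_emb = RotaryEmbedding(rotary_ndims)
    q = torch.randn(B, n_head, T, rotary_ndims)
    k = torch.randn(B, n_head, T, rotary_ndims)
    cos, sin = rotary_emb(q, seq_len=T)                # cos/sin are cached per sequence length
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    return q_rot.shape, k_rot.shape                    # both stay (B, n_head, T, rotary_ndims)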
class MHA_rotary(nn.Module):
    def __init__(self, config, layer_id, time_shift = False):
        super().__init__()
        self.layer_id = layer_id
        assert config.n_attn % config.n_head == 0
        self.n_head = config.n_head
        self.ctx_len = config.ctx_len
        self.head_size = config.n_attn // config.n_head

        if time_shift:
            self.time_shift = nn.ZeroPad2d((0,0,1,-1))

        self.query = nn.Linear(config.n_embd, config.n_attn)
        self.key = nn.Linear(config.n_embd, config.n_attn)
        self.value = nn.Linear(config.n_embd, config.n_attn)

        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.rotary_ndims = int(self.head_size * 0.5)
        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(config.n_attn, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()
        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)

        q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
        cos, sin = self.rotary_emb(q, seq_len=T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
        att = F.softmax(att, dim = -1) # softmax

        x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        x = self.output(x)
        return x

class GeGLU(torch.nn.Module):
    def __init__(self, config, layer_id, time_shift = False):
        super().__init__()
        self.layer_id = layer_id

        if time_shift:
            self.time_shift = nn.ZeroPad2d((0,0,1,-1))

        hidden_sz = 3 * config.n_ffn
        self.key = nn.Linear(config.n_embd, hidden_sz)
        self.value = nn.Linear(config.n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()
        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)

        k = self.key(x)
        v = self.value(x)
        y = self.weight(F.gelu(k) * v)
        return y
########################################################################################################
# MHA_pro: with more tricks
########################################################################################################

class MHA_pro(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        assert config.n_attn % config.n_head == 0
        self.n_head = config.n_head
        self.ctx_len = config.ctx_len
        self.head_size = config.n_attn // config.n_head

        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
        self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        self.query = nn.Linear(config.n_embd, config.n_attn)
        self.key = nn.Linear(config.n_embd, config.n_attn)
        self.value = nn.Linear(config.n_embd, config.n_attn)

        self.rotary_ndims = int(self.head_size * 0.5)
        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False) # talking heads

        self.output = nn.Linear(config.n_attn, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()
        TT = self.ctx_len
        w = F.pad(self.time_w, (0, TT))
        w = torch.tile(w, [TT])
        w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
        w = w[:, :, TT-1:] # w is now a circulant matrix
        w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]

        x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) # time-shift mixing
        q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
        cos, sin = self.rotary_emb(q, seq_len=T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
        att = F.softmax(att, dim = -1) # softmax
        att = att * w # time-weighting
        att = self.head_mix(att) # talking heads

        x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        x = self.output(x) * self.time_gamma[:T, :]
        return x
########################################################################################################
# The GPT Model with our blocks
########################################################################################################

class RMSNorm(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.dd = d ** (-1. / 2)
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        x_normed = x / (norm_x * self.dd + 1e-12)
        return self.weight * x_normed

class FixedNorm(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.dd = d ** (-1. / 2)

    def forward(self, x):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        x_normed = x / (norm_x * self.dd + 1e-12)
        return x_normed

########################################################################################################

class GPTConfig:
    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        for k,v in kwargs.items():
            setattr(self, k, v)

class Block(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        if config.model_type == 'RWKV':
            # self.ln1 = FixedNorm(config.n_embd)
            # self.ln2 = FixedNorm(config.n_embd)
            self.attn = RWKV_TimeMix(config, layer_id)
            self.mlp = RWKV_ChannelMix(config, layer_id)
        elif config.model_type == 'MHA_rotary':
            self.attn = MHA_rotary(config, layer_id)
            self.mlp = GeGLU(config, layer_id)
        elif config.model_type == 'MHA_shift':
            self.attn = MHA_rotary(config, layer_id, time_shift=True)
            self.mlp = GeGLU(config, layer_id, time_shift=True)
        elif config.model_type == 'MHA_pro':
            self.attn = MHA_pro(config, layer_id)
            self.mlp = RWKV_ChannelMix(config, layer_id)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)

        self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)])

        self.ln_f = nn.LayerNorm(config.n_embd)
        self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.head_q = nn.Linear(config.n_embd, 256)
        self.head_q.scale_init = 0.01
        self.head_k = nn.Linear(config.n_embd, 256)
        self.head_k.scale_init = 0.01
        self.register_buffer("copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.ctx_len = config.ctx_len

        if self.config.model_type == 'RWKV':
            RWKV_Init(self, config)
        else:
            self.apply(self._init_weights)

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_ctx_len(self):
        return self.ctx_len

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()

        whitelist_weight_modules = (nn.Linear, )
        blacklist_weight_modules = (RMSNorm, nn.LayerNorm, nn.Embedding)

        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                if pn.endswith('bias') or ('time' in fpn) or ('head' in fpn):
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
        return optimizer

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.tok_emb(idx)

        x = self.blocks(x)

        x = self.ln_f(x)

        # token-copy head: q/k similarity over past positions, projected onto the one-hot input tokens
        # and added directly to the logits
        q = self.head_q(x)[:,:T,:]
        k = self.head_k(x)[:,:T,:]
        c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
        c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
        c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()

        x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
        x = self.head(x) + c

        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))

        return x, loss
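
# Illustrative usage sketch (not part of the original file): how the pieces above are typically wired
# together. The hyperparameter values and the dummy batch below are assumptions made for this demo;
# the real settings live in the repository's training script.
if __name__ == '__main__':
    config = GPTConfig(vocab_size=100, ctx_len=64,
                       model_type='RWKV',  # or 'MHA_rotary', 'MHA_shift', 'MHA_pro'
                       n_layer=2, n_head=4, n_embd=128, n_attn=128, n_ffn=128,
                       rwkv_emb_scale=1.0)
    model = GPT(config)

    idx = torch.randint(0, config.vocab_size, (2, 32))     # dummy token ids (B=2, T=32)
    targets = torch.randint(0, config.vocab_size, (2, 32))
    logits, loss = model(idx, targets)
    print(logits.shape, loss.item())                       # (2, 32, vocab_size) and a scalar loss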