utils.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. ########################################################################################################
  2. # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
  3. ########################################################################################################
  4. import json
  5. import random
  6. import time
  7. import math
  8. import numpy as np
  9. import torch
  10. import torch.nn as nn
  11. from torch.nn import functional as F
  12. from torch.utils.data import Dataset
  13. class Dataset(Dataset):
  14. def __init__(self, data, ctx_len, epoch_length_fixed):
  15. print('building token list...', end=' ')
  16. unique = sorted(list(set(data)))
  17. # print()
  18. # for u in unique:
  19. # print(u, end=' ')
  20. # print('\n\n')
  21. xx = 0
  22. xxObj = {}
  23. for u in unique:
  24. xxObj[xx] = u
  25. xx += 1
  26. with open('vocab.json', "w", encoding="utf-16") as vocab_file:
  27. vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
  28. data_size, vocab_size = len(data), len(unique)
  29. print('data has %d tokens, %d unique.' % (data_size, vocab_size))
  30. self.stoi = {ch: i for i, ch in enumerate(unique)}
  31. self.itos = {i: ch for i, ch in enumerate(unique)}
  32. self.ctx_len = ctx_len
  33. self.epoch_length_fixed = epoch_length_fixed
  34. self.vocab_size = vocab_size
  35. self.data = data
  36. def __len__(self):
  37. return self.epoch_length_fixed
  38. def __getitem__(self, idx):
  39. # cheat: pick a random spot in dataset
  40. i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
  41. chunk = self.data[i:i+self.ctx_len+1]
  42. dix = [self.stoi[s] for s in chunk]
  43. x = torch.tensor(dix[:-1], dtype=torch.long,
  44. device=torch.device('cuda'))
  45. y = torch.tensor(dix[1:], dtype=torch.long,
  46. device=torch.device('cuda'))
  47. return x, y
  48. class TOKENIZER():
  49. def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
  50. with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
  51. self.word_table = json.load(result_file)
  52. self.vocab_size = len(self.word_table)
  53. self.stoi = {v: int(k) for k, v in self.word_table.items()}
  54. self.itos = {int(k): v for k, v in self.word_table.items()}
  55. self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
  56. def refine_context(self, context):
  57. context = context.strip().split('\n')
  58. for c in range(len(context)):
  59. context[c] = context[c].strip().strip('\u3000').strip('\r')
  60. context = list(filter(lambda c: c != '', context))
  61. context = '\n' + ('\n'.join(context)).strip()
  62. if context == '':
  63. context = '\n'
  64. return context
  65. def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
  66. # out[self.UNKNOWN_CHAR] = -float('Inf')
  67. lastChar = int(x[-1])
  68. probs = F.softmax(torch.tensor(out), dim=-1)
  69. if self.itos[lastChar] == '\n':
  70. top_p = top_p_newline
  71. else:
  72. top_p = top_p_usual
  73. sorted_probs, s_index = torch.sort(probs, descending=True)
  74. # for j in range(30):
  75. # pp = sorted_probs[j].item()
  76. # if pp < 0.005:
  77. # break
  78. # ss = self.itos[int(s_index[j])].replace('\n','_')
  79. # print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
  80. # print('')
  81. cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
  82. cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
  83. probs[probs < cutoff] = 0
  84. # print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")
  85. if temperature != 1.0:
  86. probs = probs.pow(1.0 / temperature)
  87. return torch.multinomial(probs, num_samples=1)[0]
  88. def to_float(x):
  89. return x.cpu().detach().numpy().flatten()[0].astype(float)
  90. def set_seed(seed):
  91. random.seed(seed)
  92. np.random.seed(seed)
  93. torch.manual_seed(seed)
  94. torch.cuda.manual_seed_all(seed)