deepcontext_utils.py

# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: Context2vec model, negative-sampling loss and data utilities for the deep context module.
"""
import json
import math
from codecs import open
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
from loguru import logger

# Define constants associated with the usual special tokens.
SOS_TOKEN = '<sos>'
EOS_TOKEN = '<eos>'
UNK_TOKEN = '<unk>'
PAD_TOKEN = '<pad>'


class NegativeSampling(nn.Module):
    """Word2vec-style negative-sampling loss over a unigram distribution raised to `power`."""

    def __init__(self,
                 embed_size,
                 counter,
                 n_negatives,
                 power,
                 device,
                 ignore_index):
        super(NegativeSampling, self).__init__()
        self.counter = counter
        self.n_negatives = n_negatives
        self.power = power
        self.device = device
        self.W = nn.Embedding(
            num_embeddings=len(counter),
            embedding_dim=embed_size,
            padding_idx=ignore_index
        )
        self.W.weight.data.zero_()
        self.logsigmoid = nn.LogSigmoid()
        self.sampler = WalkerAlias(np.power(counter, power))

    def negative_sampling(self, shape):
        if self.n_negatives > 0:
            return torch.tensor(self.sampler.sample(shape=shape), dtype=torch.long, device=self.device)
        else:
            raise NotImplementedError

    def forward(self, sentence, context):
        batch_size, seq_len = sentence.size()
        emb = self.W(sentence)
        pos_loss = self.logsigmoid((emb * context).sum(2))
        neg_samples = self.negative_sampling(shape=(batch_size, seq_len, self.n_negatives))
        neg_emb = self.W(neg_samples)
        neg_loss = self.logsigmoid((-neg_emb * context.unsqueeze(2)).sum(3)).sum(2)
        return -(pos_loss + neg_loss).sum()
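

def _negative_sampling_example():
    """Illustrative sketch only (not part of the original module): exercises the
    NegativeSampling loss on random data to show the expected tensor shapes.
    The toy sizes and the uniform counter below are assumptions for the example;
    it relies on WalkerAlias, which is defined further down in this file."""
    vocab_size, embed_size, batch_size, seq_len = 20, 8, 2, 5
    counter = [1] * vocab_size  # toy unigram counts, one entry per vocab id
    criterion = NegativeSampling(embed_size, counter, n_negatives=3, power=0.75,
                                 device='cpu', ignore_index=0)
    sentence = torch.randint(1, vocab_size, (batch_size, seq_len))  # target word ids
    context = torch.randn(batch_size, seq_len, embed_size)          # context vectors
    loss = criterion(sentence, context)  # scalar loss
    return loss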


class WalkerAlias:
    """
    This is from Chainer's implementation.
    You can find the original code at
    https://github.com/chainer/chainer/blob/v4.4.0/chainer/utils/walker_alias.py
    This class is
    Copyright (c) 2015 Preferred Infrastructure, Inc.
    Copyright (c) 2015 Preferred Networks, Inc.
    """

    def __init__(self, probs):
        prob = np.array(probs, np.float32)
        prob /= np.sum(prob)
        threshold = np.ndarray(len(probs), np.float32)
        values = np.ndarray(len(probs) * 2, np.int32)
        il, ir = 0, 0
        pairs = list(zip(prob, range(len(probs))))
        pairs.sort()
        for prob, i in pairs:
            p = prob * len(probs)
            while p > 1 and ir < il:
                values[ir * 2 + 1] = i
                p -= 1.0 - threshold[ir]
                ir += 1
            threshold[il] = p
            values[il * 2] = i
            il += 1
        # fill the rest
        for i in range(ir, len(probs)):
            values[i * 2 + 1] = 0
        assert ((values < len(threshold)).all())
        self.threshold = threshold
        self.values = values

    def sample(self, shape):
        ps = np.random.uniform(0, 1, shape)
        pb = ps * len(self.threshold)
        index = pb.astype(np.int32)
        left_right = (self.threshold[index] < pb - index).astype(np.int32)
        return self.values[index * 2 + left_right]
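

def _walker_alias_example():
    """Illustrative sketch only (not part of the original module): draws samples from
    a toy frequency distribution with the Walker alias method. The counts are
    assumptions for the example."""
    sampler = WalkerAlias([5, 1, 1, 1])     # token 0 is five times more frequent
    samples = sampler.sample(shape=(2, 3))  # int32 array of token ids, shape (2, 3)
    return samples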


class Context2vec(nn.Module):
    """Bidirectional LSTM context encoder (context2vec) trained with the negative-sampling loss."""

    def __init__(
            self,
            vocab_size,
            counter,
            word_embed_size,
            hidden_size,
            n_layers,
            use_mlp,
            dropout,
            pad_index,
            device,
            is_inference
    ):
        super(Context2vec, self).__init__()
        self.vocab_size = vocab_size
        self.word_embed_size = word_embed_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.use_mlp = use_mlp
        self.device = device
        self.is_inference = is_inference
        self.rnn_output_size = hidden_size
        self.drop = nn.Dropout(dropout)
        self.l2r_emb = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=word_embed_size,
            padding_idx=pad_index
        )
        self.l2r_rnn = nn.LSTM(
            input_size=word_embed_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True
        )
        self.r2l_emb = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=word_embed_size,
            padding_idx=pad_index
        )
        self.r2l_rnn = nn.LSTM(
            input_size=word_embed_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True
        )
        self.criterion = NegativeSampling(
            hidden_size,
            counter,
            ignore_index=pad_index,
            n_negatives=10,
            power=0.75,
            device=device
        )
        if use_mlp:
            self.MLP = MLP(
                input_size=hidden_size * 2,
                mid_size=hidden_size * 2,
                output_size=hidden_size,
                dropout=dropout
            )
        else:
            self.weights = nn.Parameter(torch.zeros(2, hidden_size))
            self.gamma = nn.Parameter(torch.ones(1))
        self.init_weights()

    def init_weights(self):
        std = math.sqrt(1. / self.word_embed_size)
        self.r2l_emb.weight.data.normal_(0, std)
        self.l2r_emb.weight.data.normal_(0, std)

    def forward(self, sentences, target, target_pos=None):
        # input: <BOS> a b c <EOS>
        # reversed_sentences: <EOS> c b a
        # sentences: <BOS> a b c
        reversed_sentences = sentences.flip(1)[:, :-1]
        sentences = sentences[:, :-1]
        l2r_emb = self.l2r_emb(sentences)
        r2l_emb = self.r2l_emb(reversed_sentences)
        output_l2r, _ = self.l2r_rnn(l2r_emb)
        output_r2l, _ = self.r2l_rnn(r2l_emb)
        # output_l2r: h(<BOS>) h(a) h(b)
        # output_r2l: h(b) h(c) h(<EOS>)
        output_l2r = output_l2r[:, :-1, :]
        output_r2l = output_r2l[:, :-1, :].flip(1)
        if self.is_inference:
            # user_input: I like [] .
            # target_pos: 2 (starts from zero)
            # output_l2r: h(<BOS>) h(I) h(like) h([])
            # output_r2l: h(like) h([]) h(.) h(<EOS>)
            # output_l2r[target_pos]: h(like)
            # output_r2l[target_pos]: h(.)
            if self.use_mlp:
                output_l2r = output_l2r[0, target_pos]
                output_r2l = output_r2l[0, target_pos]
                c_i = self.MLP(torch.cat((output_l2r, output_r2l), dim=0))
                return c_i
        else:
            # on a training phase
            if self.use_mlp:
                c_i = self.MLP(torch.cat((output_l2r, output_r2l), dim=2))
            else:
                s_task = torch.nn.functional.softmax(self.weights, dim=0)
                c_i = torch.stack((output_l2r, output_r2l), dim=2) * s_task
                c_i = self.gamma * c_i.sum(2)
            loss = self.criterion(target, c_i)
            return loss

    def init_hidden(self, batch_size):
        weight = next(self.parameters())
        return (weight.new_zeros(self.n_layers, batch_size, self.hidden_size),
                weight.new_zeros(self.n_layers, batch_size, self.hidden_size))

    def run_inference(self, input_tokens, target, target_pos, topk=10):
        context_vector = self.forward(input_tokens, target=None, target_pos=target_pos)
        if target is None:
            topv, topi = ((self.criterion.W.weight * context_vector).sum(dim=1)).data.topk(topk)
            return topv, topi
        else:
            context_vector /= torch.norm(context_vector, p=2)
            target_vector = self.criterion.W.weight[target]
            target_vector /= torch.norm(target_vector, p=2)
            similarity = (target_vector * context_vector).sum()
            return similarity.item()

    def norm_embedding_weight(self, embedding_module):
        embedding_module.weight.data /= torch.norm(embedding_module.weight.data, p=2, dim=1, keepdim=True)
        # replace NaN with zero
        embedding_module.weight.data[embedding_module.weight.data != embedding_module.weight.data] = 0
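

def _context2vec_example():
    """Illustrative sketch only (not part of the original module): runs one training-style
    forward pass of Context2vec on a random toy batch. All sizes and the toy counter are
    assumptions for the example."""
    vocab_size, word_embed_size, hidden_size = 30, 16, 16
    counter = [1] * vocab_size  # toy unigram counts, one entry per vocab id
    model = Context2vec(vocab_size=vocab_size, counter=counter, word_embed_size=word_embed_size,
                        hidden_size=hidden_size, n_layers=1, use_mlp=True, dropout=0.0,
                        pad_index=0, device='cpu', is_inference=False)
    # batch of 2 sentences, each: <sos> x1 x2 x3 <eos>  ->  shape (2, 5)
    sentences = torch.randint(4, vocab_size, (2, 5))
    target = sentences[:, 1:-1]      # predict the inner tokens from their bidirectional context
    loss = model(sentences, target)  # scalar negative-sampling loss
    return loss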


class MLP(nn.Module):
    def __init__(
            self,
            input_size,
            mid_size,
            output_size,
            n_layers=2,
            dropout=0.3,
            activation_function='relu'
    ):
        super(MLP, self).__init__()
        self.input_size = input_size
        self.mid_size = mid_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.drop = nn.Dropout(dropout)
        self.MLP = nn.ModuleList()
        if n_layers == 1:
            self.MLP.append(nn.Linear(input_size, output_size))
        else:
            self.MLP.append(nn.Linear(input_size, mid_size))
            for _ in range(n_layers - 2):
                self.MLP.append(nn.Linear(mid_size, mid_size))
            self.MLP.append(nn.Linear(mid_size, output_size))
        if activation_function == 'tanh':
            self.activation_function = nn.Tanh()
        elif activation_function == 'relu':
            self.activation_function = nn.ReLU()
        else:
            raise NotImplementedError

    def forward(self, x):
        out = x
        for i in range(self.n_layers - 1):
            out = self.MLP[i](self.drop(out))
            out = self.activation_function(out)
        return self.MLP[-1](self.drop(out))
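

def _mlp_example():
    """Illustrative sketch only (not part of the original module): a 2-layer MLP mapping
    8-dim inputs to 4-dim outputs on a random batch. Sizes are assumptions for the example."""
    mlp = MLP(input_size=8, mid_size=8, output_size=4, n_layers=2, dropout=0.0)
    out = mlp(torch.randn(3, 8))  # shape (3, 4)
    return out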


def save_word_dict(dict_data, save_path):
    with open(save_path, 'w', encoding='utf-8') as f:
        for k, v in dict_data.items():
            f.write("%s\t%d\n" % (k, v))


def load_word_dict(save_path):
    dict_data = dict()
    with open(save_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip('\n')
            items = line.split('\t')
            try:
                dict_data[items[0]] = int(items[1])
            except Exception as e:
                logger.warning(f"Exception: {e}, {line}")
    return dict_data
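

def _word_dict_roundtrip_example():
    """Illustrative sketch only (not part of the original module): saves a toy word dict
    to a temporary file and loads it back. The temporary path is an assumption for the example."""
    import os
    import tempfile
    path = os.path.join(tempfile.mkdtemp(), 'word_dict.txt')
    save_word_dict({'<pad>': 0, 'a': 1, 'b': 2}, path)
    restored = load_word_dict(path)  # {'<pad>': 0, 'a': 1, 'b': 2}
    return restored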


def read_vocab(input_texts, max_size=None, min_count=0):
    token_counts = Counter()
    special_tokens = [PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN]
    for texts in input_texts:
        for token in texts:
            # Skip special tokens here (they are inserted explicitly below) and count each
            # token once; Counter.update(str) would count individual characters instead.
            if token in special_tokens:
                continue
            token_counts[token] += 1
    # Sort tokens by count, descending
    count_pairs = token_counts.most_common()
    vocab = [k for k, v in count_pairs if v >= min_count]
    word_freq = {k: v for k, v in count_pairs if v >= min_count}
    # Insert the special tokens at the beginning
    vocab = special_tokens + vocab
    if max_size is not None:
        vocab = vocab[:max_size]
    vocab2id = dict(zip(vocab, range(len(vocab))))
    special_tokens_dict = {k: 0 for k in special_tokens}
    word_freq.update(special_tokens_dict)
    return vocab2id, word_freq
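

def _read_vocab_example():
    """Illustrative sketch only (not part of the original module): builds a char-level
    vocabulary from two toy sentences; the special tokens always occupy the first four ids."""
    sentences = [list("abca"), list("abd")]
    vocab2id, word_freq = read_vocab(sentences)
    # vocab2id: {'<pad>': 0, '<unk>': 1, '<sos>': 2, '<eos>': 3, 'a': 4, 'b': 5, ...}
    # word_freq: {'a': 3, 'b': 2, 'c': 1, 'd': 1, '<pad>': 0, ...}
    return vocab2id, word_freq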


def get_minibatches(n, minibatch_size, shuffle=True):
    idx_list = np.arange(0, n, minibatch_size)  # batch start indices: [0, minibatch_size, 2 * minibatch_size, ...]
    if shuffle:
        np.random.shuffle(idx_list)
    minibatches = []
    for idx in idx_list:
        minibatches.append(np.arange(idx, min(idx + minibatch_size, n)))
    return minibatches


def prepare_data(seqs, max_length=512):
    seqs = [seq[:max_length] for seq in seqs]
    lengths = [len(seq) for seq in seqs]
    n_samples = len(seqs)
    x = np.zeros((n_samples, max_length)).astype('int32')
    x_lengths = np.array(lengths).astype("int32")
    for idx, seq in enumerate(seqs):
        x[idx, :lengths[idx]] = seq
    return x, x_lengths  # x_mask


def gen_examples(src_sentences, batch_size, max_length):
    minibatches = get_minibatches(len(src_sentences), batch_size)
    examples = []
    for minibatch in minibatches:
        mb_src_sentences = [src_sentences[t] for t in minibatch]
        mb_x, mb_x_len = prepare_data(mb_src_sentences, max_length)
        examples.append((mb_x, mb_x_len))
    return examples


def one_hot(src_sentences, src_dict, sort_by_len=False):
    """Map token sequences to id sequences; OOV tokens fall back to the <unk> id."""
    unk_id = src_dict.get(UNK_TOKEN, 0)
    out_src_sentences = [[src_dict.get(w, unk_id) for w in sent] for sent in src_sentences]

    # optionally sort sentences by length
    def len_argsort(seq):
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    if sort_by_len:
        sorted_index = len_argsort(out_src_sentences)
        out_src_sentences = [out_src_sentences[i] for i in sorted_index]
    return out_src_sentences
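

def _batching_example():
    """Illustrative sketch only (not part of the original module): shows the batching
    pipeline used for training data, i.e. one_hot -> gen_examples -> padded id matrices.
    The toy vocabulary and sentences are assumptions for the example."""
    src_dict = {PAD_TOKEN: 0, UNK_TOKEN: 1, SOS_TOKEN: 2, EOS_TOKEN: 3, 'a': 4, 'b': 5}
    sentences = [[SOS_TOKEN, 'a', 'b', EOS_TOKEN], [SOS_TOKEN, 'b', EOS_TOKEN]]
    id_sentences = one_hot(sentences, src_dict)  # [[2, 4, 5, 3], [2, 5, 3]]
    batches = gen_examples(id_sentences, batch_size=2, max_length=8)
    mb_x, mb_x_len = batches[0]  # padded (2, 8) int32 matrix and the lengths [4, 3]
    return mb_x, mb_x_len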


def write_embedding(id2word, nn_embedding, use_cuda, filename):
    with open(filename, mode='w', encoding='utf-8') as f:
        f.write('{} {}\n'.format(nn_embedding.num_embeddings, nn_embedding.embedding_dim))
        if use_cuda:
            embeddings = nn_embedding.weight.data.cpu().numpy()
        else:
            embeddings = nn_embedding.weight.data.numpy()
        for word_id, vec in enumerate(embeddings):
            word = id2word[word_id]
            vec = ' '.join(list(map(str, vec)))
            f.write('{} {}\n'.format(word, vec))


def write_config(filename, **kwargs):
    with open(filename, mode='w', encoding='utf-8') as f:
        json.dump(kwargs, f)


def read_config(filename):
    with open(filename, mode='r', encoding='utf-8') as f:
        return json.load(f)


class ContextDataset:
    """Char-level training dataset: reads one sentence per line, builds the vocabulary and padded id batches."""

    def __init__(
            self,
            train_path,
            batch_size=64,
            max_length=512,
            min_freq=0,
            device='cuda',
            vocab_path='vocab.txt',
            vocab_max_size=50000,
            pad_token=PAD_TOKEN,
            unk_token=UNK_TOKEN,
            sos_token=SOS_TOKEN,
            eos_token=EOS_TOKEN,
    ):
        sentences = []
        with open(train_path, 'r', encoding='utf-8') as f:
            for line in f:
                tokens = list(line.strip().lower())
                if len(tokens) > 0:
                    sentences.append([sos_token] + tokens + [eos_token])
        self.sent_dict = self._gathered_by_lengths(sentences)
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.sos_token = sos_token
        self.eos_token = eos_token
        self.device = device
        self.vocab_2_ids, self.word_freqs = read_vocab(sentences, max_size=vocab_max_size, min_count=min_freq)
        logger.debug(f"vocab_2_ids size: {len(self.vocab_2_ids)}, word_freqs: {len(self.word_freqs)}, "
                     f"vocab_2_ids head: {list(self.vocab_2_ids.items())[:10]}, "
                     f"word_freqs head: {list(self.word_freqs.items())[:10]}")
        save_word_dict(self.vocab_2_ids, vocab_path)
        self.id_2_vocabs = {v: k for k, v in self.vocab_2_ids.items()}
        self.train_data = gen_examples(one_hot(sentences, self.vocab_2_ids), batch_size, max_length)
        self.pad_index = self.vocab_2_ids.get(self.pad_token, 0)

    def _gathered_by_lengths(self, sentences):
        lengths = [(index, len(sent)) for index, sent in enumerate(sentences)]
        lengths = sorted(lengths, key=lambda x: x[1], reverse=True)
        sent_dict = dict()
        current_length = -1
        for (index, length) in lengths:
            if current_length == length:
                sent_dict[length].append(index)
            else:
                sent_dict[length] = [index]
                current_length = length
        return sent_dict
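

def _dataset_example():
    """Illustrative sketch only (not part of the original module): builds a ContextDataset
    from a tiny temporary corpus and a matching Context2vec model on top of it.
    The temporary paths, toy corpus, and model sizes are assumptions for the example."""
    import os
    import tempfile
    tmp_dir = tempfile.mkdtemp()
    train_path = os.path.join(tmp_dir, 'train.txt')
    with open(train_path, 'w', encoding='utf-8') as f:
        f.write('abc\nabd\nbcd\n')  # three toy char-level sentences, one per line
    dataset = ContextDataset(train_path, batch_size=2, max_length=16, device='cpu',
                             vocab_path=os.path.join(tmp_dir, 'vocab.txt'))
    # one unigram count per vocab id (special tokens keep their recorded frequency of 0)
    counter = [dataset.word_freqs.get(w, 1) for w in dataset.vocab_2_ids]
    model = Context2vec(vocab_size=len(dataset.vocab_2_ids), counter=counter,
                        word_embed_size=16, hidden_size=16, n_layers=1, use_mlp=True,
                        dropout=0.0, pad_index=dataset.pad_index, device='cpu', is_inference=False)
    return dataset, model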