deepcontext_model.py

# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: context to vector network
"""
import os
import time
from typing import List, Optional

import numpy as np
import pandas as pd
import torch
from loguru import logger
from torch import optim

from pycorrector.deepcontext.deepcontext_utils import (
    Context2vec,
    read_config,
    load_word_dict,
    write_config,
    ContextDataset
)


class DeepContextModel:
    def __init__(self, model_dir: str, max_length: int = 1024, use_cuda: Optional[bool] = True):
        if use_cuda:
            if torch.cuda.is_available():
                self.device = torch.device("cuda:0")
            else:
                raise ValueError(
                    "'use_cuda' set to True when cuda is unavailable. "
                    "Make sure CUDA is available or set `use_cuda=False`."
                )
        else:
            # without CUDA, prefer Apple Silicon MPS if available, otherwise fall back to CPU
            if torch.backends.mps.is_available():
                self.device = torch.device("mps")
            else:
                self.device = "cpu"
        logger.debug(f"Device: {self.device}")
        self.config_file = os.path.join(model_dir, 'config.json')
        self.checkpoint_file = os.path.join(model_dir, "pytorch_model.bin")
        self.optimizer_file = os.path.join(model_dir, 'optimizer.pt')
        self.vocab_file = os.path.join(model_dir, 'vocab.txt')
        self.model_dir = model_dir
        self.max_length = max_length
        self.mask = "[]"
        self.model = None
        self.optimizer = None
        self.config_dict = None
        self.stoi = None
        self.itos = None

    def load_model(self):
        if not os.path.exists(self.config_file):
            raise ValueError('config file does not exist.')
        if not os.path.exists(self.checkpoint_file):
            raise ValueError('checkpoint file does not exist.')
        if not os.path.exists(self.vocab_file):
            raise ValueError('vocab file does not exist.')
        config_dict = read_config(self.config_file)
        self.model = Context2vec(
            vocab_size=config_dict['vocab_size'],
            counter=[1] * config_dict['vocab_size'],
            word_embed_size=config_dict['word_embed_size'],
            hidden_size=config_dict['hidden_size'],
            n_layers=config_dict['n_layers'],
            use_mlp=config_dict['use_mlp'],
            dropout=config_dict['dropout'],
            pad_index=config_dict['pad_index'],
            device=self.device,
            is_inference=True
        ).to(self.device)
        self.model.load_state_dict(torch.load(self.checkpoint_file, map_location=self.device))
        self.optimizer = optim.Adam(self.model.parameters(), lr=config_dict['learning_rate'])
        if os.path.exists(self.optimizer_file):
            self.optimizer.load_state_dict(torch.load(self.optimizer_file, map_location=self.device))
        self.config_dict = config_dict
        # read vocab
        self.stoi = load_word_dict(self.vocab_file)
        self.itos = {v: k for k, v in self.stoi.items()}

    def train_model(
            self,
            train_path,
            batch_size=64,
            num_epochs=3,
            word_embed_size=200,
            hidden_size=200,
            learning_rate=1e-3,
            n_layers=2,
            min_freq=1,
            vocab_max_size=50000,
            dropout=0.0
    ):
        if not os.path.isfile(train_path):
            raise FileNotFoundError(train_path)
        os.makedirs(self.model_dir, exist_ok=True)
        logger.info('Loading data')
        dataset = ContextDataset(
            train_path,
            batch_size=batch_size,
            max_length=self.max_length,
            min_freq=min_freq,
            device=self.device,
            vocab_path=self.vocab_file,
            vocab_max_size=vocab_max_size,
        )
        counter = np.array([dataset.word_freqs[word] for word in dataset.vocab_2_ids])
        model = Context2vec(
            vocab_size=len(dataset.vocab_2_ids),
            counter=counter,
            word_embed_size=word_embed_size,
            hidden_size=hidden_size,
            n_layers=n_layers,
            use_mlp=True,
            dropout=dropout,
            pad_index=dataset.pad_index,
            device=self.device,
            is_inference=False
        ).to(self.device)
        if self.model is None:
            # norm weight
            model.norm_embedding_weight(model.criterion.W)
        if self.optimizer is None:
            optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        else:
            optimizer = self.optimizer
        logger.info(
            f'model: {model}, batch_size: {batch_size}, epochs: {num_epochs}, '
            f'word_embed_size: {word_embed_size}, hidden_size: {hidden_size}, learning_rate: {learning_rate}'
        )
        # save config
        write_config(
            self.config_file,
            vocab_size=len(dataset.vocab_2_ids),
            word_embed_size=word_embed_size,
            hidden_size=hidden_size,
            n_layers=n_layers,
            use_mlp=True,
            dropout=dropout,
            pad_index=dataset.pad_index,
            pad_token=dataset.pad_token,
            unk_token=dataset.unk_token,
            sos_token=dataset.sos_token,
            eos_token=dataset.eos_token,
            learning_rate=learning_rate
        )
        interval = 1e5
        best_loss = 1e3
        global_step = 0
        training_progress_scores = {
            "epoch": [],
            "global_step": [],
            "train_loss": [],
        }
        logger.info("train start...")
        for epoch in range(num_epochs):
            begin_time = time.time()
            cur_at = begin_time
            total_loss = 0.0
            word_count = 0
            next_count = interval
            last_accum_loss = 0.0
            last_word_count = 0
            cur_loss = 0
            for it, (mb_x, mb_x_len) in enumerate(dataset.train_data):
                sentence = torch.from_numpy(mb_x).to(self.device).long()
                target = sentence[:, 1:-1]
                if target.size(0) == 0:
                    continue
                optimizer.zero_grad()
                loss = model(sentence, target)
                loss.backward()
                optimizer.step()
                global_step += 1
                total_loss += loss.data.mean()
                minibatch_size, sentence_length = target.size()
                word_count += minibatch_size * sentence_length
                accum_mean_loss = float(total_loss) / word_count if total_loss > 0.0 else 0.0
                cur_mean_loss = (float(total_loss) - last_accum_loss) / (word_count - last_word_count)
                cur_loss = cur_mean_loss
                if word_count >= next_count:
                    now = time.time()
                    duration = now - cur_at
                    throughput = float(word_count - last_word_count) / (now - cur_at)
                    logger.info('{} words, {:.2f} sec, {:.2f} words/sec, {:.4f} accum_loss/word, {:.4f} cur_loss/word'
                                .format(word_count, duration, throughput, accum_mean_loss, cur_mean_loss))
                    next_count += interval
                    cur_at = now
                    last_accum_loss = float(total_loss)
                    last_word_count = word_count
            # find best model
            is_best = cur_loss < best_loss
            best_loss = min(cur_loss, best_loss)
            logger.info('epoch: {}/{}, global_step: {}, loss: {}, best_loss: {}'.format(
                epoch + 1, num_epochs, global_step, cur_loss, best_loss))
            training_progress_scores["epoch"].append(epoch + 1)
            training_progress_scores["global_step"].append(global_step)
            training_progress_scores["train_loss"].append(cur_loss)
            report = pd.DataFrame(training_progress_scores)
            report.to_csv(os.path.join(self.model_dir, "training_progress_scores.csv"), index=False)
            if is_best:
                self.save_model(model_dir=self.model_dir, model=model, optimizer=optimizer)
                logger.info('epoch: {}, save new model: {}'.format(epoch + 1, self.model_dir))

    def save_model(self, model_dir=None, model=None, optimizer=None):
        """Save the model and the optimizer."""
        if not model_dir:
            model_dir = self.model_dir
        os.makedirs(model_dir, exist_ok=True)
        if model:
            # Take care of distributed/parallel training
            torch.save(model.state_dict(), self.checkpoint_file)
        if optimizer:
            torch.save(optimizer.state_dict(), self.optimizer_file)

    def predict_mask_token(self, tokens: List[str], mask_index: int = 0, topk: int = 10):
        if not self.model:
            self.load_model()
        unk_token = self.config_dict['unk_token']
        sos_token = self.config_dict['sos_token']
        eos_token = self.config_dict['eos_token']
        pad_token = self.config_dict['pad_token']
        pred_words = []
        tokens[mask_index] = unk_token
        tokens = [sos_token] + tokens + [eos_token]
        indexed_sentence = [self.stoi[token] if token in self.stoi else self.stoi[unk_token] for token in tokens]
        input_tokens = torch.tensor(indexed_sentence, dtype=torch.long, device=self.device).unsqueeze(0)
        topv, topi = self.model.run_inference(input_tokens, target=None, target_pos=mask_index, topk=topk)
        for value, key in zip(topv, topi):
            score = value.item()
            word = self.itos[key.item()]
            if word in [unk_token, sos_token, eos_token, pad_token]:
                continue
            pred_words.append((word, score))
        return pred_words
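

# Usage sketch (not part of the original module): a minimal example of driving
# DeepContextModel end to end, training a Context2vec model on a plain-text corpus
# and then querying candidate tokens for a masked position. The model directory,
# corpus path, and example sentence below are hypothetical placeholders; only the
# method signatures defined above are assumed.
if __name__ == "__main__":
    # use_cuda=False selects MPS if available, otherwise CPU
    m = DeepContextModel(model_dir="./output/deepcontext", use_cuda=False)
    # hypothetical corpus file: one whitespace-tokenized sentence per line
    m.train_model(train_path="./data/train.txt", batch_size=64, num_epochs=3)
    # predict replacement candidates for the token at mask_index in a tokenized sentence
    tokens = ["你", "好", "世", "界"]
    print(m.predict_mask_token(tokens, mask_index=1, topk=5))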