# -*- coding: utf-8 -*-
"""
@author: XuMing(xuming624@qq.com)
@description: context-to-vector (Context2vec) network
"""
import os
import time
from typing import List, Optional

import numpy as np
import pandas as pd
import torch
from loguru import logger
from torch import optim

from pycorrector.deepcontext.deepcontext_utils import (
    Context2vec,
    read_config,
    load_word_dict,
    write_config,
    ContextDataset
)


class DeepContextModel:
    def __init__(self, model_dir: str, max_length: int = 1024, use_cuda: Optional[bool] = True):
        if use_cuda:
            if torch.cuda.is_available():
                self.device = torch.device("cuda:0")
            else:
                raise ValueError(
                    "`use_cuda` is set to True, but CUDA is unavailable. "
                    "Make sure CUDA is available or set `use_cuda=False`."
                )
        else:
            if torch.backends.mps.is_available():
                self.device = torch.device("mps")
            else:
                self.device = torch.device("cpu")
        logger.debug(f"Device: {self.device}")
        self.config_file = os.path.join(model_dir, 'config.json')
        self.checkpoint_file = os.path.join(model_dir, 'pytorch_model.bin')
        self.optimizer_file = os.path.join(model_dir, 'optimizer.pt')
        self.vocab_file = os.path.join(model_dir, 'vocab.txt')
        self.model_dir = model_dir
        self.max_length = max_length
        self.mask = "[]"
        self.model = None
        self.optimizer = None
        self.config_dict = None
        self.stoi = None  # token -> id mapping
        self.itos = None  # id -> token mapping

    def load_model(self):
        """Load the model checkpoint, optimizer state and vocab from model_dir."""
        if not os.path.exists(self.config_file):
            raise ValueError(f'config file not found: {self.config_file}')
        if not os.path.exists(self.checkpoint_file):
            raise ValueError(f'checkpoint file not found: {self.checkpoint_file}')
        if not os.path.exists(self.vocab_file):
            raise ValueError(f'vocab file not found: {self.vocab_file}')
        config_dict = read_config(self.config_file)
        self.model = Context2vec(
            vocab_size=config_dict['vocab_size'],
            counter=[1] * config_dict['vocab_size'],
            word_embed_size=config_dict['word_embed_size'],
            hidden_size=config_dict['hidden_size'],
            n_layers=config_dict['n_layers'],
            use_mlp=config_dict['use_mlp'],
            dropout=config_dict['dropout'],
            pad_index=config_dict['pad_index'],
            device=self.device,
            is_inference=True
        ).to(self.device)
        self.model.load_state_dict(torch.load(self.checkpoint_file, map_location=self.device))
        self.optimizer = optim.Adam(self.model.parameters(), lr=config_dict['learning_rate'])
        if os.path.exists(self.optimizer_file):
            self.optimizer.load_state_dict(torch.load(self.optimizer_file, map_location=self.device))
        self.config_dict = config_dict
        # read vocab
        self.stoi = load_word_dict(self.vocab_file)
        self.itos = {v: k for k, v in self.stoi.items()}

    def train_model(
            self,
            train_path,
            batch_size=64,
            num_epochs=3,
            word_embed_size=200,
            hidden_size=200,
            learning_rate=1e-3,
            n_layers=2,
            min_freq=1,
            vocab_max_size=50000,
            dropout=0.0
    ):
        """Train a Context2vec model on train_path and save the best checkpoint to model_dir."""
        if not os.path.isfile(train_path):
            raise FileNotFoundError(f'train file not found: {train_path}')
        os.makedirs(self.model_dir, exist_ok=True)
        logger.info('Loading data')
        dataset = ContextDataset(
            train_path,
            batch_size=batch_size,
            max_length=self.max_length,
            min_freq=min_freq,
            device=self.device,
            vocab_path=self.vocab_file,
            vocab_max_size=vocab_max_size,
        )
        counter = np.array([dataset.word_freqs[word] for word in dataset.vocab_2_ids])
        model = Context2vec(
            vocab_size=len(dataset.vocab_2_ids),
            counter=counter,
            word_embed_size=word_embed_size,
            hidden_size=hidden_size,
            n_layers=n_layers,
            use_mlp=True,
            dropout=dropout,
            pad_index=dataset.pad_index,
            device=self.device,
            is_inference=False
        ).to(self.device)
        if self.model is None:
            # training from scratch: normalize the target embedding weights
            model.norm_embedding_weight(model.criterion.W)
        if self.optimizer is None:
            optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        else:
            optimizer = self.optimizer
        logger.info(
            f'model: {model}, batch_size: {batch_size}, epochs: {num_epochs}, '
            f'word_embed_size: {word_embed_size}, hidden_size: {hidden_size}, learning_rate: {learning_rate}'
        )
        # save config
        write_config(
            self.config_file,
            vocab_size=len(dataset.vocab_2_ids),
            word_embed_size=word_embed_size,
            hidden_size=hidden_size,
            n_layers=n_layers,
            use_mlp=True,
            dropout=dropout,
            pad_index=dataset.pad_index,
            pad_token=dataset.pad_token,
            unk_token=dataset.unk_token,
            sos_token=dataset.sos_token,
            eos_token=dataset.eos_token,
            learning_rate=learning_rate
        )
        interval = 1e5  # log progress every 1e5 target words
        best_loss = 1e3
        global_step = 0
        training_progress_scores = {
            "epoch": [],
            "global_step": [],
            "train_loss": [],
        }
- logger.info("train start...")
- for epoch in range(num_epochs):
- begin_time = time.time()
- cur_at = begin_time
- total_loss = 0.0
- word_count = 0
- next_count = interval
- last_accum_loss = 0.0
- last_word_count = 0
- cur_loss = 0
- for it, (mb_x, mb_x_len) in enumerate(dataset.train_data):
- sentence = torch.from_numpy(mb_x).to(self.device).long()
- target = sentence[:, 1:-1]
- if target.size(0) == 0:
- continue
- optimizer.zero_grad()
- loss = model(sentence, target)
- loss.backward()
- optimizer.step()
- global_step += 1
- total_loss += loss.data.mean()
- minibatch_size, sentence_length = target.size()
- word_count += minibatch_size * sentence_length
- accum_mean_loss = float(total_loss) / word_count if total_loss > 0.0 else 0.0
- cur_mean_loss = (float(total_loss) - last_accum_loss) / (word_count - last_word_count)
- cur_loss = cur_mean_loss
- if word_count >= next_count:
- now = time.time()
- duration = now - cur_at
- throuput = float((word_count - last_word_count)) / (now - cur_at)
- logger.info('{} words, {:.2f} sec, {:.2f} words/sec, {:.4f} accum_loss/word, {:.4f} cur_loss/word'
- .format(word_count, duration, throuput, accum_mean_loss, cur_mean_loss))
- next_count += interval
- cur_at = now
- last_accum_loss = float(total_loss)
- last_word_count = word_count
- # find best model
- is_best = cur_loss < best_loss
- best_loss = min(cur_loss, best_loss)
- logger.info('epoch: {}/{}, global_step: {}, loss: {}, best_loss: {}'.format(
- epoch + 1, num_epochs, global_step, cur_loss, best_loss))
- training_progress_scores["epoch"].append(epoch + 1)
- training_progress_scores["global_step"].append(global_step)
- training_progress_scores["train_loss"].append(cur_loss)
- report = pd.DataFrame(training_progress_scores)
- report.to_csv(os.path.join(self.model_dir, "training_progress_scores.csv"), index=False)
- if is_best:
- self.save_model(model_dir=self.model_dir, model=model, optimizer=optimizer)
- logger.info('save new model: {}'.format(epoch + 1, self.model_dir))

    def save_model(self, model_dir=None, model=None, optimizer=None):
        """Save the model and the optimizer."""
        if not model_dir:
            model_dir = self.model_dir
        os.makedirs(model_dir, exist_ok=True)
        if model:
            # Take care of distributed/parallel training: unwrap DataParallel if present
            model_to_save = model.module if hasattr(model, "module") else model
            torch.save(model_to_save.state_dict(), self.checkpoint_file)
        if optimizer:
            torch.save(optimizer.state_dict(), self.optimizer_file)

    def predict_mask_token(self, tokens: List[str], mask_index: int = 0, topk: int = 10):
        """Predict candidate tokens for the masked position; return a list of (word, score)."""
        if not self.model:
            self.load_model()
        unk_token = self.config_dict['unk_token']
        sos_token = self.config_dict['sos_token']
        eos_token = self.config_dict['eos_token']
        pad_token = self.config_dict['pad_token']
        pred_words = []
        tokens = tokens[:]  # copy, so the caller's list is not mutated
        tokens[mask_index] = unk_token
        tokens = [sos_token] + tokens + [eos_token]
        indexed_sentence = [self.stoi[token] if token in self.stoi else self.stoi[unk_token] for token in tokens]
        input_tokens = torch.tensor(indexed_sentence, dtype=torch.long, device=self.device).unsqueeze(0)
        topv, topi = self.model.run_inference(input_tokens, target=None, target_pos=mask_index, topk=topk)
        for value, key in zip(topv, topi):
            score = value.item()
            word = self.itos[key.item()]
            if word in [unk_token, sos_token, eos_token, pad_token]:
                continue
            pred_words.append((word, score))
        return pred_words
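

if __name__ == '__main__':
    # Minimal usage sketch. The corpus path and output dir below are
    # hypothetical placeholders; the train file is assumed to be plain text
    # that ContextDataset can read. Adjust them to your own data.
    m = DeepContextModel(model_dir='outputs/deepcontext_demo', use_cuda=False)
    m.train_model(train_path='data/train.txt', num_epochs=1, batch_size=32)
    # Rank candidates for the character at position 4 ('因' of the typo '因该').
    tokens = list('少先队员因该为老人让座')
    print(m.predict_mask_token(tokens, mask_index=4, topk=5))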