123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- import torch
- import sentencepiece
- import jieba
- import numpy as np
- from transformers.tokenization_utils import PreTrainedTokenizer
# Register the special-token literals in jieba's dictionary at import time
# (presumably so jieba.cut keeps each of them as one segment instead of
# splitting on the angle brackets -- confirm against jieba's behaviour).
for _special_piece in ('<s>', '</s>', '<eot>', '<unk>', '<sep>', '<pad>'):
    jieba.add_word(_special_piece)
del _special_piece  # do not leave the loop variable in the module namespace
class GPTPanguTokenizer(PreTrainedTokenizer):
    """Tokenizer that combines jieba word segmentation with SentencePiece.

    Text is first cut into segments with ``jieba.cut``; inside every segment
    a space is replaced by U+2582 and a newline by U+2583. Segments are then
    joined with spaces and encoded by a SentencePiece model. ``decode``
    reverses the process: separator spaces are removed and the two
    placeholder characters are turned back into space / newline.

    Ref: https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha/src/branch/master/tokenization_jieba.py
    """

    vocab_files_names = {
        "model_file": "vocab.model"
    }

    def __init__(
        self,
        model_file,
        **kwargs
    ):
        """Load the SentencePiece model and initialise the base tokenizer.

        Args:
            model_file: Path to the SentencePiece model file.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        """
        # NOTE(review): the model is loaded *before* calling the base
        # constructor, as in the original code -- presumably because the base
        # class may query the vocabulary during its own initialisation.
        self.sp = sentencepiece.SentencePieceProcessor()
        self.sp.Load(model_file=model_file)
        # Maps " " -> U+2582 and "\n" -> U+2583 so whitespace survives the
        # space-joined SentencePiece encoding; `decode` undoes this mapping.
        self.translator = str.maketrans(" \n", "\u2582\u2583")
        super().__init__(**kwargs)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Add the special tokens around one sequence or a pair of sequences.

        The produced layout is:

        - single sequence: ``[bos] X eos``
        - pair of sequences: ``[bos] A sep B eos``

        where ``bos`` is emitted only when ``self.bos_token_id`` is not None.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: A new list of input IDs with the special tokens added.
        """
        prefix = [] if self.bos_token_id is None else [self.bos_token_id]
        eos = [self.eos_token_id]
        if token_ids_1 is None:
            return prefix + token_ids_0 + eos
        sep = [self.sep_token_id]
        return prefix + token_ids_0 + sep + token_ids_1 + eos

    def tokenize(self, text, **kwargs):
        """Segment ``text`` with jieba and protect whitespace in each segment.

        Args:
            text: The string to segment.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            List of segments in which " " / "\\n" have been replaced by
            U+2582 / U+2583.
        """
        return [
            segment.translate(self.translator)
            for segment in jieba.cut(text, cut_all=False)
        ]

    def convert_tokens_to_ids(self, tokens):
        """Convert a token or a sequence of tokens into vocabulary ids.

        For a sequence, runs of ordinary tokens are joined with a single
        space and encoded together by SentencePiece, while every special
        token is converted on its own so it is never split into pieces.

        Args:
            tokens: ``None``, a single token string, or a sequence of tokens.

        Returns:
            ``None`` for ``None`` input, a single id for a string, otherwise
            a flat list of ids.
        """
        if tokens is None:
            return None
        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)

        # Look the special tokens up once instead of once per token.
        special_tokens = set(self.all_special_tokens)
        ids = []
        pending = []  # ordinary tokens collected since the last special token
        for token in tokens:
            if token in special_tokens:
                ids.extend(self.sp.encode(" ".join(pending)))
                # Use the same lookup as the single-string path above so a
                # special token gets the same id whichever way it is passed
                # (falls back to `_convert_token_to_id` for non-added tokens).
                ids.append(self._convert_token_to_id_with_added_voc(token))
                pending = []
            else:
                pending.append(token)
        ids.extend(self.sp.encode(" ".join(pending)))
        return ids

    def _convert_token_to_id(self, token):
        """Return the SentencePiece id of ``token``."""
        return self.sp.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Return the SentencePiece piece for id ``index``."""
        return self.sp.id_to_piece(index)

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """Decode ``ids`` to text.

        Note that this returns the decoded *string* produced by ``decode``,
        not a list of individual tokens.

        Args:
            ids: Ids accepted by ``decode``.
            skip_special_tokens: When true, special-token ids are dropped
                before decoding.

        Returns:
            The decoded text.
        """
        return self.decode(ids, skip_special_tokens=skip_special_tokens)

    def decode(self, ids, skip_special_tokens=False, **kwargs):
        """Decode ids back into text.

        Args:
            ids: A single id, a list of ids, a ``torch.Tensor`` or a
                ``numpy.ndarray``.
            skip_special_tokens: When true, ids found in
                ``self.all_special_ids`` are removed before decoding.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The decoded string. If SentencePiece returns a list of strings
            only the first one is kept.
        """
        if isinstance(ids, (torch.Tensor, np.ndarray)):
            ids = ids.tolist()
        if isinstance(ids, int):
            # A bare id is treated as a one-element sequence so the
            # special-token filter below can iterate over it.
            ids = [ids]
        if skip_special_tokens:
            # Read the property once instead of once per id. It is kept as
            # returned (not turned into a set) so non-hashable items such as
            # nested lists still compare instead of raising.
            special_ids = self.all_special_ids
            ids = [token_id for token_id in ids if token_id not in special_ids]
        text = self.sp.decode(ids)
        if isinstance(text, list):
            text = text[0]
        # Drop the spaces that separated jieba segments, then restore the
        # whitespace that `tokenize` replaced with placeholder characters.
        return text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')

    def get_vocab(self):
        """Return a ``{piece: id}`` dict built from the SentencePiece model."""
        return {
            self.sp.IdToPiece(piece_id): piece_id
            for piece_id in range(self.sp.GetPieceSize())
        }

    @property
    def vocab_size(self) -> int:
        """
        `int`: Size of the base vocabulary (without the added tokens).
        """
        return len(self.sp)
|