tokenization_gptpangu.py

import torch
import sentencepiece
import jieba
import numpy as np

from transformers.tokenization_utils import PreTrainedTokenizer

# Register the special tokens with jieba so word segmentation keeps each of
# them as a single, unsplit token.
jieba.add_word('<s>')
jieba.add_word('</s>')
jieba.add_word('<eot>')
jieba.add_word('<unk>')
jieba.add_word('<sep>')
jieba.add_word('<pad>')


class GPTPanguTokenizer(PreTrainedTokenizer):
    # Ref: https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha/src/branch/master/tokenization_jieba.py
    vocab_files_names = {
        "model_file": "vocab.model"
    }

    def __init__(
            self,
            model_file,
            **kwargs
    ):
        # Load the SentencePiece model before calling the parent constructor,
        # which may already need the vocabulary.
        self.sp = sentencepiece.SentencePieceProcessor()
        self.sp.Load(model_file=model_file)
        # Map spaces and newlines to placeholder characters so they survive
        # jieba segmentation and SentencePiece encoding.
        self.translator = str.maketrans(" \n", "\u2582\u2583")

        super().__init__(**kwargs)

        # special token ids
        # self.eos_token_id = self.sp.piece_to_id("<eot>")

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences by concatenating and adding special tokens.
        A GPT-Pangu sequence has the following format:

        - single sequence: `[BOS] X [EOS]` (or `X [EOS]` if no BOS token is configured)
        - pair of sequences: `[BOS] A [SEP] B [EOS]` (or `A [SEP] B [EOS]`)

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if self.bos_token_id is not None:
            if token_ids_1 is None:
                return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]

            bos = [self.bos_token_id]
            sep = [self.sep_token_id]
            eos = [self.eos_token_id]
            return bos + token_ids_0 + sep + token_ids_1 + eos
        else:
            if token_ids_1 is None:
                return token_ids_0 + [self.eos_token_id]

            sep = [self.sep_token_id]
            eos = [self.eos_token_id]
            return token_ids_0 + sep + token_ids_1 + eos

    def tokenize(self, text, **kwargs):
        """Tokenize a string: pre-segment with jieba, replacing spaces/newlines with placeholder characters."""
        seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]
        return seg_list

    def convert_tokens_to_ids(self, tokens):
        if tokens is None:
            return None

        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)

        # Special tokens are looked up directly; the spans of ordinary tokens
        # between them are joined and encoded with SentencePiece.
        special_tokens_index = [i for i, token in enumerate(tokens) if token in self.all_special_tokens]

        ids = []
        i = 0
        for j in special_tokens_index:
            new_seg = " ".join(tokens[i:j])
            ids.extend(self.sp.encode(new_seg))
            ids.append(self._convert_token_to_id(tokens[j]))
            i = j + 1

        new_seg = " ".join(tokens[i:])
        ids.extend(self.sp.encode(new_seg))

        return ids

        # new_seg = " ".join(tokens)
        # return self.sp.encode(new_seg)
        # # return tokens

    def _convert_token_to_id(self, token):
        return self.sp.piece_to_id(token)

    def _convert_id_to_token(self, index):
        return self.sp.id_to_piece(index)

    def convert_ids_to_tokens(self, ids):
        return self.decode(ids)

    def decode(self, ids, **kwargs):
        if isinstance(ids, (torch.Tensor, np.ndarray)):
            ids = ids.tolist()

        if kwargs.get('skip_special_tokens', None) is True:
            ids = [token_id for token_id in ids if token_id not in self.all_special_ids]

        text = self.sp.decode(ids)
        if isinstance(text, list):
            text = text[0]
        # Undo the space/newline placeholder mapping applied in `tokenize`.
        text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')  # .replace('⁇', self.unk_token)
        return text

    def get_vocab(self):
        vocab = {self.sp.IdToPiece(i): i for i in range(self.sp.GetPieceSize())}
        return vocab

    @property
    def vocab_size(self) -> int:
        """
        `int`: Size of the base vocabulary (without the added tokens).
        """
        return len(self.sp)
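

# Minimal usage sketch. It assumes a local SentencePiece model at
# "vocab.model" (a hypothetical path) and special-token strings matching the
# ones registered with jieba above; the real checkpoint's tokenizer_config
# may use different values.
if __name__ == "__main__":
    tokenizer = GPTPanguTokenizer(
        model_file="vocab.model",  # hypothetical local path
        bos_token="<s>",
        eos_token="<eot>",
        sep_token="<sep>",
        pad_token="<pad>",
        unk_token="<unk>",
    )
    tokens = tokenizer.tokenize("今天天气不错")          # jieba pre-segmentation
    ids = tokenizer.convert_tokens_to_ids(tokens)        # SentencePiece ids
    print(tokens, ids, tokenizer.decode(ids))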