# chatglmoonx.py
# ------------------------------------------------------------------------------------------------------------------------
# 🔌💻 Source Code From https://huggingface.co/K024/ChatGLM-6b-onnx-u8s8/blob/main/model.py
# ------------------------------------------------------------------------------------------------------------------------
import re
import numpy as np
# import torch
from onnxruntime import InferenceSession, SessionOptions

# Currently `MatMulInteger` and `DynamicQuantizeLinear` are only supported on CPU,
# although they are documented as supported on CUDA.
providers = ["CPUExecutionProvider"]
# if torch.cuda.is_available():
#     providers = ["CUDAExecutionProvider"] + providers

# Default paths
tokenizer_path = "chatglm-6b-int8-onnx-merged/sentencepiece.model"
onnx_model_path = "chatglm-6b-int8-onnx-merged/chatglm-6b-int8.onnx"

# input & output names
past_names = [f"past_{name}_{i}" for i in range(28) for name in ["key", "value"]]
present_names = [f"present_{name}_{i}" for i in range(28) for name in ["key", "value"]]
output_names = ["logits"] + present_names

# default kv_cache for first inference
default_past_key_values = {
    k: np.zeros((1, 0, 32, 128), dtype=np.float32) for k in past_names
}
def chat_template(history: list[tuple[str, str]], current: str):
    prompt = ""
    chat_round = 0
    for question, answer in history:
        prompt += f"[Round {chat_round}]\n问：{question}\n答：{answer}\n"
        chat_round += 1
    prompt += f"[Round {chat_round}]\n问：{current}\n答："
    return prompt
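# For reference, chat_template(history=[("你好", "你好！有什么可以帮你？")], current="晚上睡不着怎么办")
# renders a prompt of the following form (the question/answer strings are illustrative only):
#
#   [Round 0]
#   问：你好
#   答：你好！有什么可以帮你？
#   [Round 1]
#   问：晚上睡不着怎么办
#   答：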
def process_response(response: str):
    response = response.strip()
    response = response.replace("[[训练时间]]", "2023年")
    punkts = [
        [",", "，"],
        ["!", "！"],
        [":", "："],
        [";", "；"],
        [r"\?", "？"],
    ]
    # Rewrite half-width punctuation as full-width when it is adjacent to a CJK character.
    for item in punkts:
        response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
        response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
    return response
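# Illustrative example (hypothetical input): process_response("你好,世界!") returns "你好，世界！",
# since the half-width "," and "!" adjacent to Chinese characters become their full-width forms.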
class ChatGLMModel():

    def __init__(self, onnx_model_path=onnx_model_path, tokenizer_path=tokenizer_path, profile=False) -> None:
        self.tokenizer = ChatGLMTokenizer(tokenizer_path)
        options = SessionOptions()
        options.enable_profiling = profile
        self.session = InferenceSession(onnx_model_path, options, providers=providers)
        self.eop_token_id = self.tokenizer["<eop>"]

    def prepare_input(self, prompt: str):
        input_ids, prefix_mask = self.tokenizer.encode(prompt)
        input_ids = np.array([input_ids], dtype=np.longlong)
        prefix_mask = np.array([prefix_mask], dtype=np.longlong)
        return input_ids, prefix_mask, default_past_key_values
    def sample_next_token(self, logits: np.ndarray, top_k=50, top_p=0.7, temperature=1):
        # softmax with temperature (shift by the max logit for numerical stability;
        # the constant factor cancels out after normalisation)
        logits = logits - np.max(logits)
        exp_logits = np.exp(logits / temperature)
        probs = exp_logits / np.sum(exp_logits)
        # top k
        top_k_idx = np.argsort(-probs)[:top_k]
        top_k_probs = probs[top_k_idx]
        # top p
        cumsum_probs = np.cumsum(top_k_probs)
        top_k_probs[(cumsum_probs - top_k_probs) > top_p] = 0.0
        top_k_probs = top_k_probs / np.sum(top_k_probs)
        # sample
        next_token = np.random.choice(top_k_idx, size=1, p=top_k_probs)
        return next_token[0].item()
    def generate_iterate(self, prompt: str, max_generated_tokens=100, top_k=50, top_p=0.7, temperature=1):
        input_ids, prefix_mask, past_key_values = self.prepare_input(prompt)
        output_tokens = []
        while True:
            inputs = {
                "input_ids": input_ids,
                "prefix_mask": prefix_mask,
                "use_past": np.array(len(output_tokens) > 0),
            }
            inputs.update(past_key_values)
            logits, *past_key_values = self.session.run(output_names, inputs)
            past_key_values = {k: v for k, v in zip(past_names, past_key_values)}
            next_token = self.sample_next_token(logits[0, -1], top_k=top_k, top_p=top_p, temperature=temperature)
            output_tokens += [next_token]
            if next_token == self.eop_token_id or len(output_tokens) > max_generated_tokens:
                break
            input_ids = np.array([[next_token]], dtype=np.longlong)
            prefix_mask = np.concatenate([prefix_mask, np.array([[0]], dtype=np.longlong)], axis=1)
            yield process_response(self.tokenizer.decode(output_tokens))
        return process_response(self.tokenizer.decode(output_tokens))
# ------------------------------------------------------------------------------------------------------------------------
# 🔌💻 Source Code From https://huggingface.co/K024/ChatGLM-6b-onnx-u8s8/blob/main/tokenizer.py
# ------------------------------------------------------------------------------------------------------------------------
import re
from sentencepiece import SentencePieceProcessor


def replace_spaces_with_blank(match: re.Match[str]):
    return f"<|blank_{len(match.group())}|>"


def replace_blank_with_spaces(match: re.Match[str]):
    return " " * int(match.group(1))
class ChatGLMTokenizer:
    def __init__(self, vocab_file):
        assert vocab_file is not None
        self.vocab_file = vocab_file
        self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
        self.text_tokenizer = SentencePieceProcessor(str(vocab_file))

    def __len__(self):
        return len(self.text_tokenizer)

    def __getitem__(self, key: str):
        return self.text_tokenizer[key]

    def preprocess(self, text: str, linebreak=True, whitespaces=True):
        if linebreak:
            text = text.replace("\n", "<n>")
        if whitespaces:
            text = text.replace("\t", "<|tab|>")
            text = re.sub(r" {2,80}", replace_spaces_with_blank, text)
        return text

    def encode(
        self, text: str, text_pair: str = None,
        linebreak=True, whitespaces=True,
        add_dummy_prefix=True, special_tokens=True,
    ) -> tuple[list[int], list[int]]:
        """
        text: Text to encode. Bidirectional part with a [gMASK] and an <sop> for causal LM.
        text_pair: causal LM part.
        linebreak: Whether to encode newline (\n) in text.
        whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
        special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
        add_dummy_prefix: Whether to add dummy blank space in the beginning.
        """
        text = self.preprocess(text, linebreak, whitespaces)
        if not add_dummy_prefix:
            text = "<n>" + text
        tokens = self.text_tokenizer.encode(text)
        prefix_mask = [1] * len(tokens)
        if special_tokens:
            tokens += [self.text_tokenizer["[gMASK]"], self.text_tokenizer["<sop>"]]
            prefix_mask += [1, 0]
        if text_pair is not None:
            text_pair = self.preprocess(text_pair, linebreak, whitespaces)
            pair_tokens = self.text_tokenizer.encode(text_pair)
            tokens += pair_tokens
            prefix_mask += [0] * len(pair_tokens)
            if special_tokens:
                tokens += [self.text_tokenizer["<eop>"]]
                prefix_mask += [0]
        return (tokens if add_dummy_prefix else tokens[2:]), prefix_mask

    def decode(self, text_ids: list[int]) -> str:
        text = self.text_tokenizer.decode(text_ids)
        text = text.replace("<n>", "\n")
        text = text.replace("<|tab|>", "\t")
        text = re.sub(r"<\|blank_(\d\d?)\|>", replace_blank_with_spaces, text)
        return text
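# ------------------------------------------------------------------------------------------------------------------------
# 💡 Example usage sketch: assumes the default `chatglm-6b-int8-onnx-merged/` model and
# sentencepiece files referenced above have been downloaded locally. Builds a prompt with
# chat_template() and consumes the partial responses yielded by generate_iterate().
# ------------------------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    model = ChatGLMModel()  # uses the onnx_model_path / tokenizer_path defaults
    history: list[tuple[str, str]] = []
    question = "你好"  # illustrative question
    prompt = chat_template(history, question)
    answer = ""
    for answer in model.generate_iterate(prompt, max_generated_tokens=128):
        pass  # each yield is the full response decoded so far
    print(answer)
    history.append((question, answer))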