  1. """
  2. Chat with a model with command line interface.
  3. Usage:
  4. python3 gorilla_cli.py --model-path path/to/gorilla-7b-hf-v0
  5. Thanks to LMSYS for the template of this code.
  6. """
  7. import argparse
  8. import gc
  9. import os
  10. import re
  11. import sys
  12. import abc
  13. import torch
  14. from transformers import (
  15. AutoConfig,
  16. AutoModel,
  17. AutoModelForCausalLM,
  18. AutoModelForSeq2SeqLM,
  19. AutoTokenizer,
  20. LlamaTokenizer,
  21. LlamaForCausalLM,
  22. T5Tokenizer,
  23. )
  24. from transformers.generation.logits_process import (
  25. LogitsProcessorList,
  26. RepetitionPenaltyLogitsProcessor,
  27. TemperatureLogitsWarper,
  28. TopKLogitsWarper,
  29. TopPLogitsWarper,
  30. )
  31. from prompt_toolkit import PromptSession
  32. from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
  33. from prompt_toolkit.completion import WordCompleter
  34. from prompt_toolkit.history import InMemoryHistory
  35. from conv_template import get_conv_template
  36. import warnings
  37. warnings.filterwarnings('ignore')
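
# NOTE: load_model() below relies on get_gpu_memory() and load_compress_model(),
# which are not defined in this file; in FastChat (which this script is adapted
# from) they are provided as utilities. A minimal sketch of get_gpu_memory(),
# assuming it is not importable here, returning free memory per GPU in GiB:
def get_gpu_memory(num_gpus):
    """Return the free memory of the first `num_gpus` visible GPUs, in GiB."""
    gpu_memory = []
    for gpu_id in range(num_gpus):
        with torch.cuda.device(gpu_id):
            free_bytes, _total_bytes = torch.cuda.mem_get_info()
            gpu_memory.append(free_bytes / (1024**3))
    return gpu_memory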

# Load Gorilla Model from HF
def load_model(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: str = None,
    load_8bit: bool = False,
    cpu_offloading: bool = False,
):
    if device == "cpu":
        kwargs = {"torch_dtype": torch.float32}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        if num_gpus != 1:
            kwargs["device_map"] = "auto"
            if max_gpu_memory is None:
                # "sequential" matters when the GPUs have different VRAM sizes.
                kwargs["device_map"] = "sequential"
                available_gpu_memory = get_gpu_memory(num_gpus)
                kwargs["max_memory"] = {
                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                    for i in range(num_gpus)
                }
            else:
                kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
    elif device == "mps":
        kwargs = {"torch_dtype": torch.float16}
    else:
        raise ValueError(f"Invalid device: {device}")

    if cpu_offloading:
        # Raises an error on incompatible platforms.
        from transformers import BitsAndBytesConfig

        if "max_memory" in kwargs:
            kwargs["max_memory"]["cpu"] = (
                str(math.floor(psutil.virtual_memory().available / 2**20)) + "MiB"
            )
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit_fp32_cpu_offload=cpu_offloading
        )
        kwargs["load_in_8bit"] = load_8bit
    elif load_8bit:
        if num_gpus != 1:
            warnings.warn(
                "8-bit quantization is not supported for multi-gpu inference."
            )
        else:
            # Assumes load_compress_model() is available (see the note above load_model).
            return load_compress_model(
                model_path=model_path, device=device, torch_dtype=kwargs["torch_dtype"]
            )

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        **kwargs,
    )
    return model, tokenizer
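
# Example (illustrative path, mirroring the usage string at the top of the file):
#   model, tokenizer = load_model("path/to/gorilla-7b-hf-v0", device="cuda", num_gpus=1)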

def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
):
    processor_list = LogitsProcessorList()
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        processor_list.append(TopKLogitsWarper(top_k))
    return processor_list
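
# With the settings used by get_response() below (temperature=0.1,
# repetition_penalty=0.0, top_p=1.0, top_k=-1), only TemperatureLogitsWarper is
# active, so sampling reduces to a temperature-sharpened softmax over the
# last-token logits.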

@torch.inference_mode()
def get_response(prompt, model, tokenizer, device):
    # Sampling settings: temperature=0.1, repetition_penalty=0.0, top_p=1.0, top_k=-1.
    logits_processor = prepare_logits_processor(0.1, 0.0, 1.0, -1)
    context_len = 2048
    max_new_tokens = 1024
    stream_interval = 2

    input_ids = tokenizer(prompt).input_ids
    input_echo_len = len(input_ids)
    output_ids = list(input_ids)

    # Keep room in the context window for the generated tokens.
    max_src_len = context_len - max_new_tokens - 8
    input_ids = input_ids[-max_src_len:]
    stop_token_ids = [tokenizer.eos_token_id]

    past_key_values = out = None
    for i in range(max_new_tokens):
        if i == 0:
            # Prefill: run the full prompt once and cache the key/value states.
            out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
            logits = out.logits
            past_key_values = out.past_key_values
        else:
            # Decode: feed only the last sampled token, reusing the KV cache.
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                past_key_values=past_key_values,
            )
            logits = out.logits
            past_key_values = out.past_key_values

        tmp_output_ids = None
        last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        probs = torch.softmax(last_token_logits, dim=-1)
        token = int(torch.multinomial(probs, num_samples=1))
        output_ids.append(token)

        if token in stop_token_ids:
            stopped = True
        else:
            stopped = False

        # Stream a partial decode every few steps, and always on the final step.
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            tmp_output_ids = output_ids[input_echo_len:]
            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
            )
            yield {"text": output}
        if stopped:
            break

    yield {"text": output}

    # Clean up the KV cache.
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()
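
# get_response() is a generator that yields partial decodes every `stream_interval`
# steps; SimpleChatIO.stream_output() below consumes it, but it can also be read
# directly, e.g.:
#   for chunk in get_response(prompt, model, tokenizer, "cuda"):
#       print(chunk["text"])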

class SimpleChatIO(abc.ABC):
    def prompt_for_input(self, role) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream):
        pre = 0
        for outputs in output_stream:
            output_text = outputs["text"]
            output_text = output_text.strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)

def chat_loop(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: str,
    load_8bit: bool,
    cpu_offloading: bool,
    chatio: abc.ABC,
):
    # Model
    model, tokenizer = load_model(
        model_path, device, num_gpus, max_gpu_memory, load_8bit, cpu_offloading
    )
    if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device == "mps":
        model.to(device)

    while True:
        # Chat
        if "mpt" in model_path:
            conv = get_conv_template("mpt")
        elif "gorilla" in model_path:
            conv = get_conv_template("gorilla_v0")
        else:
            # Fail clearly instead of leaving `conv` unbound.
            raise ValueError(f"No conversation template for model path: {model_path}")

        try:
            inp = chatio.prompt_for_input(conv.roles[0])
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        chatio.prompt_for_output(conv.roles[1])
        output_stream = get_response(prompt, model, tokenizer, device)
        outputs = chatio.stream_output(output_stream)
        conv.update_last_message(outputs.strip())

def main(args):
    if args.gpus:
        if len(args.gpus.split(",")) < args.num_gpus:
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    chatio = SimpleChatIO()
    try:
        chat_loop(
            args.model_path,
            args.device,
            args.num_gpus,
            args.max_gpu_memory,
            args.load_8bit,
            args.cpu_offloading,
            chatio,
        )
    except KeyboardInterrupt:
        print("exit...")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path", type=str, default=None,
        help="Path to the pretrained model."
    )
    parser.add_argument(
        "--gpus", type=str, default=None,
        help="A single GPU like 1 or multiple GPUs like 0,2."
    )
    parser.add_argument(
        "--num-gpus", type=int, default=1,
        help="Number of GPUs to use."
    )
    parser.add_argument(
        "--device", type=str, default="cuda",
        help="Which device to use."
    )
    parser.add_argument(
        "--max-gpu-memory", type=str,
        help="The maximum memory per GPU. Use a string like '13GiB'.",
    )
    parser.add_argument(
        "--load-8bit", action="store_true", help="Use 8-bit quantization."
    )
    parser.add_argument(
        "--cpu-offloading",
        action="store_true",
        help="Only when using 8-bit quantization: offload to the CPU the weights that don't fit on the GPU.",
    )
    args = parser.parse_args()

    main(args)
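
# Example invocation (path is illustrative, as in the module docstring):
#   python3 gorilla_cli.py --model-path path/to/gorilla-7b-hf-v0 --device cuda --num-gpus 1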