gorilla_falcon_cli.py

  1. """
  2. Chat with a model with command line interface.
  3. Usage:
  4. python3 -m gorilla_cli --model path/to/gorilla-7b-hf-v0
  5. Thanks to LMSYS for the template of this code.
  6. """
import argparse
import gc
import math  # used by the CPU-offloading path in load_model
import os
import re
import sys
import abc
import warnings

import psutil  # used by the CPU-offloading path in load_model
import torch
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    LlamaTokenizer,
    LlamaForCausalLM,
    T5Tokenizer,
)
from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import InMemoryHistory

from conv_template import get_conv_template

warnings.filterwarnings("ignore")
# Load Gorilla Model from HF
# NOTE: get_gpu_memory() and load_compress_model() are referenced below but not
# defined or imported in this file; they are assumed to be provided by the
# surrounding package (FastChat, which this script credits, provides helpers
# with these names). They are only reached on the multi-GPU and 8-bit paths.
def load_model(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: str = None,
    load_8bit: bool = False,
    cpu_offloading: bool = False,
):
    if device == "cpu":
        kwargs = {"torch_dtype": torch.float32}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        if num_gpus != 1:
            kwargs["device_map"] = "auto"
            if max_gpu_memory is None:
                kwargs[
                    "device_map"
                ] = "sequential"  # This is important when the GPUs do not all have the same VRAM size
                available_gpu_memory = get_gpu_memory(num_gpus)
                kwargs["max_memory"] = {
                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                    for i in range(num_gpus)
                }
            else:
                kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
    else:
        raise ValueError(f"Invalid device: {device}")

    if cpu_offloading:
        # raises an error on incompatible platforms
        from transformers import BitsAndBytesConfig

        if "max_memory" in kwargs:
            kwargs["max_memory"]["cpu"] = (
                str(math.floor(psutil.virtual_memory().available / 2**20)) + "MiB"
            )
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit_fp32_cpu_offload=cpu_offloading
        )
        kwargs["load_in_8bit"] = load_8bit
    elif load_8bit:
        if num_gpus != 1:
            warnings.warn(
                "8-bit quantization is not supported for multi-gpu inference."
            )
        else:
            return load_compress_model(
                model_path=model_path, device=device, torch_dtype=kwargs["torch_dtype"]
            )

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = 11
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        **kwargs,
    )
    return model, tokenizer
@torch.inference_mode()
def get_response(prompt, model, tokenizer, device):
    input_ids = tokenizer([prompt]).input_ids
    output_ids = model.generate(
        torch.as_tensor(input_ids).to(device),
        do_sample=True,
        temperature=0.7,
        max_new_tokens=1024,
        pad_token_id=tokenizer.eos_token_id,
    )
    output_ids = output_ids[0][len(input_ids[0]) :]
    outputs = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
    yield {"text": outputs}

    # clean
    gc.collect()
    torch.cuda.empty_cache()
class SimpleChatIO(abc.ABC):
    def prompt_for_input(self, role) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream):
        pre = 0
        for outputs in output_stream:
            output_text = outputs["text"]
            output_text = output_text.strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)
def chat_loop(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: str,
    load_8bit: bool,
    cpu_offloading: bool,
    chatio: abc.ABC,
):
    # Model
    model, tokenizer = load_model(
        model_path, device, num_gpus, max_gpu_memory, load_8bit, cpu_offloading
    )
    # Use the function parameters here (not the global `args`) so chat_loop
    # also works when imported and called directly.
    if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device == "mps":
        model.to(device)

    while True:
        # Chat
        if "falcon" in model_path:
            conv = get_conv_template("falcon")
        elif "mpt" in model_path:
            conv = get_conv_template("mpt")
        else:
            conv = get_conv_template("gorilla_v0")

        try:
            inp = chatio.prompt_for_input(conv.roles[0])
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        chatio.prompt_for_output(conv.roles[1])
        output_stream = get_response(prompt, model, tokenizer, device)
        outputs = chatio.stream_output(output_stream)
        conv.update_last_message(outputs.strip())
def main(args):
    if args.gpus:
        if len(args.gpus.split(",")) < args.num_gpus:
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    chatio = SimpleChatIO()
    try:
        chat_loop(
            args.model_path,
            args.device,
            args.num_gpus,
            args.max_gpu_memory,
            args.load_8bit,
            args.cpu_offloading,
            chatio,
        )
    except KeyboardInterrupt:
        print("exit...")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path", type=str, default=None,
        help="Path to the pretrained model.",
    )
    parser.add_argument(
        "--gpus", type=str, default=None,
        help="A single GPU like 1 or multiple GPUs like 0,2.",
    )
    parser.add_argument(
        "--num-gpus",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--device", type=str, default="cuda",
        help="Which device to use.",
    )
    parser.add_argument(
        "--max-gpu-memory",
        type=str,
        help="The maximum memory per GPU. Use a string like '13GiB'.",
    )
    parser.add_argument(
        "--load-8bit", action="store_true", help="Use 8-bit quantization."
    )
    parser.add_argument(
        "--cpu-offloading",
        action="store_true",
        help="Only when using 8-bit quantization: offload weights that do not fit on the GPU to the CPU.",
    )
    args = parser.parse_args()

    main(args)
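
# --- Sketch: driving the building blocks above without the REPL loop ---------
# A minimal, hypothetical example of calling load_model() and get_response()
# directly; the model path is a placeholder and this is not executed as part of
# the CLI:
#
#   model, tokenizer = load_model("path/to/gorilla-falcon-7b-hf-v0", "cuda", 1)
#   conv = get_conv_template("falcon")
#   conv.append_message(conv.roles[0], "I want to translate English to French.")
#   conv.append_message(conv.roles[1], None)
#   for chunk in get_response(conv.get_prompt(), model, tokenizer, "cuda"):
#       print(chunk["text"])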