# provider.py — local GPT4All model provider
from __future__ import annotations
import os
from gpt4all import GPT4All
from .models import get_models
from ..typing import Messages
# Lazily-populated registry of available local models, keyed by model name.
# Filled on the first call to LocalProvider.create_completion.
MODEL_LIST: dict[str, dict] | None = None
  7. def find_model_dir(model_file: str) -> str:
  8. local_dir = os.path.dirname(os.path.abspath(__file__))
  9. project_dir = os.path.dirname(os.path.dirname(local_dir))
  10. new_model_dir = os.path.join(project_dir, "models")
  11. new_model_file = os.path.join(new_model_dir, model_file)
  12. if os.path.isfile(new_model_file):
  13. return new_model_dir
  14. old_model_dir = os.path.join(local_dir, "models")
  15. old_model_file = os.path.join(old_model_dir, model_file)
  16. if os.path.isfile(old_model_file):
  17. return old_model_dir
  18. working_dir = "./"
  19. for root, dirs, files in os.walk(working_dir):
  20. if model_file in files:
  21. return root
  22. return new_model_dir
  23. class LocalProvider:
  24. @staticmethod
  25. def create_completion(model: str, messages: Messages, stream: bool = False, **kwargs):
  26. global MODEL_LIST
  27. if MODEL_LIST is None:
  28. MODEL_LIST = get_models()
  29. if model not in MODEL_LIST:
  30. raise ValueError(f'Model "{model}" not found / not yet implemented')
  31. model = MODEL_LIST[model]
  32. model_file = model["path"]
  33. model_dir = find_model_dir(model_file)
  34. if not os.path.isfile(os.path.join(model_dir, model_file)):
  35. print(f'Model file "models/{model_file}" not found.')
  36. download = input(f"Do you want to download {model_file}? [y/n]: ")
  37. if download in ["y", "Y"]:
  38. GPT4All.download_model(model_file, model_dir)
  39. else:
  40. raise ValueError(f'Model "{model_file}" not found.')
  41. model = GPT4All(model_name=model_file,
  42. #n_threads=8,
  43. verbose=False,
  44. allow_download=False,
  45. model_path=model_dir)
  46. system_message = "\n".join(message["content"] for message in messages if message["role"] == "system")
  47. if system_message:
  48. system_message = "A chat between a curious user and an artificial intelligence assistant."
  49. prompt_template = "USER: {0}\nASSISTANT: "
  50. conversation = "\n" . join(
  51. f"{message['role'].upper()}: {message['content']}"
  52. for message in messages
  53. if message["role"] != "system"
  54. ) + "\nASSISTANT: "
  55. def should_not_stop(token_id: int, token: str):
  56. return "USER" not in token
  57. with model.chat_session(system_message, prompt_template):
  58. if stream:
  59. for token in model.generate(conversation, streaming=True, callback=should_not_stop):
  60. yield token
  61. else:
  62. yield model.generate(conversation, callback=should_not_stop)