HuggingFace.py

from __future__ import annotations

import json

from aiohttp import ClientSession, BaseConnector

from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import get_connector
from ..errors import RateLimitError, ModelNotFoundError
from ..requests.raise_for_status import raise_for_status
from .HuggingChat import HuggingChat

class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
    url = "https://huggingface.co/chat"
    working = True
    needs_auth = True
    supports_message_history = True
    # The model list is shared with the HuggingChat provider.
    default_model = HuggingChat.default_model
    models = HuggingChat.models
    model_aliases = HuggingChat.model_aliases

    @classmethod
    def get_model(cls, model: str) -> str:
        # Resolve a known model name or alias; fall back to the default model.
        if model in cls.models:
            return model
        elif model in cls.model_aliases:
            return cls.model_aliases[model]
        else:
            return cls.default_model

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = True,
        proxy: str = None,
        connector: BaseConnector = None,
        api_base: str = "https://api-inference.huggingface.co",
        api_key: str = None,
        max_new_tokens: int = 1024,
        temperature: float = 0.7,
        **kwargs
    ) -> AsyncResult:
        model = cls.get_model(model)
        # Browser-like headers mirroring what the HuggingChat web client sends.
        headers = {
            'accept': '*/*',
            'accept-language': 'en',
            'cache-control': 'no-cache',
            'origin': 'https://huggingface.co',
            'pragma': 'no-cache',
            'priority': 'u=1, i',
            'referer': 'https://huggingface.co/chat/',
            'sec-ch-ua': '"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"macOS"',
            'sec-fetch-dest': 'empty',
            'sec-fetch-mode': 'cors',
            'sec-fetch-site': 'same-origin',
            'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36',
        }
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"
        # Generation parameters for the text-generation task;
        # extra kwargs are passed through to the API unchanged.
        params = {
            "return_full_text": False,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            **kwargs
        }
        payload = {"inputs": format_prompt(messages), "parameters": params, "stream": stream}
        async with ClientSession(
            headers=headers,
            connector=get_connector(connector, proxy)
        ) as session:
            async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
                if response.status == 404:
                    raise ModelNotFoundError(f"Model is not supported: {model}")
                await raise_for_status(response)
                if stream:
                    first = True
                    async for line in response.content:
                        # Streamed responses arrive as server-sent events:
                        # each payload line starts with "data:" followed by JSON.
                        if line.startswith(b"data:"):
                            data = json.loads(line[5:])
                            # Skip special tokens such as end-of-sequence.
                            if not data["token"]["special"]:
                                chunk = data["token"]["text"]
                                if first:
                                    # Trim the leading whitespace models tend
                                    # to emit before the first token.
                                    first = False
                                    chunk = chunk.lstrip()
                                yield chunk
                else:
                    yield (await response.json())[0]["generated_text"].strip()
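
# Illustrative only (the shape is inferred from the parsing loop above, not a
# verbatim sample): a streamed line from the text-generation endpoint looks
# roughly like
#
#     data:{"token": {"id": 264, "text": " Hello", "logprob": -0.1, "special": false}}
#
# and the generator yields just the "text" field of each non-special token.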

def format_prompt(messages: Messages) -> str:
    # Render the conversation in Mistral-style instruction format. Each past
    # user/assistant exchange becomes "<s>[INST]question [/INST] answer</s>",
    # and the final user message (with any system messages appended) becomes
    # the trailing open instruction, e.g.:
    #   <s>[INST]hi [/INST] Hello!</s><s>[INST] how are you? [/INST]
    system_messages = [message["content"] for message in messages if message["role"] == "system"]
    question = " ".join([messages[-1]["content"], *system_messages])
    history = "".join([
        f"<s>[INST]{messages[idx-1]['content']} [/INST] {message['content']}</s>"
        for idx, message in enumerate(messages)
        if message["role"] == "assistant"
    ])
    return f"{history}<s>[INST] {question} [/INST]"
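
# A minimal usage sketch, not part of the original module. It assumes g4f's
# package layout is importable (the relative imports above require running it
# as part of the package), that HF_TOKEN is a hypothetical environment
# variable holding a valid Hugging Face API token, and that the default model
# is available on the Inference API.
if __name__ == "__main__":
    import asyncio
    import os

    async def main() -> None:
        # create_async_generator is an async generator function, so it can be
        # consumed directly with "async for".
        async for chunk in HuggingFace.create_async_generator(
            model=HuggingFace.default_model,
            messages=[{"role": "user", "content": "Say hello in one sentence."}],
            api_key=os.environ.get("HF_TOKEN"),
        ):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(main())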