NexraLLaMA31.py

from __future__ import annotations

import json

from aiohttp import ClientSession

from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt


class NexraLLaMA31(AsyncGeneratorProvider, ProviderModelMixin):
    label = "Nexra LLaMA 3.1"
    url = "https://nexra.aryahcr.cc/documentation/llama-3.1/en"
    api_endpoint = "https://nexra.aryahcr.cc/api/chat/complements"
    working = True
    supports_stream = True

    default_model = 'llama-3.1'
    models = [default_model]

    model_aliases = {
        "llama-3.1-8b": "llama-3.1",
    }

    @classmethod
    def get_model(cls, model: str) -> str:
        # Resolve aliases, e.g. "llama-3.1-8b" -> "llama-3.1";
        # unknown names fall back to the default model.
        if model in cls.models:
            return model
        return cls.model_aliases.get(model, cls.default_model)

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        stream: bool = False,
        markdown: bool = False,
        **kwargs
    ) -> AsyncResult:
        model = cls.get_model(model)

        headers = {
            "Content-Type": "application/json"
        }

        async with ClientSession(headers=headers) as session:
            prompt = format_prompt(messages)
            data = {
                "messages": [
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                "stream": stream,
                "markdown": markdown,
                "model": model
            }

            async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
                response.raise_for_status()

                if stream:
                    # Streamed responses arrive as JSON objects separated by
                    # the ASCII record separator (\x1e). Chunks from
                    # iter_any() can split an object in two, so buffer the
                    # raw text and only parse complete parts.
                    buffer = ""
                    collected_message = ""
                    async for chunk in response.content.iter_any():
                        if chunk:
                            buffer += chunk.decode()
                            *parts, buffer = buffer.split("\x1e")
                            for part in parts:
                                part = part.strip()
                                if not part:
                                    continue
                                message_data = json.loads(part)
                                # Each object carries the full message so far;
                                # keep the latest until 'finish' is true.
                                if message_data.get('message'):
                                    collected_message = message_data['message']
                                # When finish is true, yield the final
                                # collected message and stop.
                                if message_data.get('finish', False):
                                    yield collected_message
                                    return
                else:
                    # Non-streamed: the whole reply is a single JSON object.
                    response_data = await response.json(content_type=None)
                    if response_data.get('message'):
                        yield response_data['message']
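
A minimal usage sketch, not part of the original file: the relative imports above mean this module only runs inside its provider package (as in g4f), and the message shape below assumes the usual list of {"role", "content"} dicts behind the Messages type. The name demo is illustrative.

import asyncio

async def demo():
    # "llama-3.1-8b" resolves to "llama-3.1" via model_aliases
    messages = [{"role": "user", "content": "Hello!"}]
    async for part in NexraLLaMA31.create_async_generator(
        model="llama-3.1-8b", messages=messages, stream=True
    ):
        print(part)

asyncio.run(demo())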