浏览代码

refactor(ReplicateHome): update model handling and API interaction

kqlio67 1 月之前
父节点
当前提交
d69372a962
共有 2 个文件被更改,包括 120 次插入110 次删除
  1. 101 110
      g4f/Provider/ReplicateHome.py
  2. 19 0
      g4f/models.py

+ 101 - 110
g4f/Provider/ReplicateHome.py

@@ -1,58 +1,60 @@
 from __future__ import annotations
-from typing import Generator, Optional, Dict, Any, Union, List
-import random
+
+import json
 import asyncio
-import base64
+from aiohttp import ClientSession, ContentTypeError
 
-from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..typing import AsyncResult, Messages
-from ..requests import StreamSession, raise_for_status
-from ..errors import ResponseError
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from .helper import format_prompt
 from ..image import ImageResponse
 
 class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://replicate.com"
-    parent = "Replicate"
+    api_endpoint = "https://homepage.replicate.com/api/prediction"
     working = True
+    supports_stream = True
+    supports_system_message = True
+    supports_message_history = True
+    
     default_model = 'meta/meta-llama-3-70b-instruct'
-    text_models = {"meta/meta-llama-3-70b-instruct", "mistralai/mixtral-8x7b-instruct-v0.1", "google-deepmind/gemma-2b-it"}
-    image_models = {"stability-ai/stable-diffusion-3", "bytedance/sdxl-lightning-4step", "playgroundai/playground-v2.5-1024px-aesthetic"}
-    models = [
-        *text_models,
-        *image_models
+    
+    text_models = [
+        'meta/meta-llama-3-70b-instruct',
+        'mistralai/mixtral-8x7b-instruct-v0.1',
+        'google-deepmind/gemma-2b-it',
+        'yorickvp/llava-13b',
     ]
 
-    versions = {
-        # Model versions for generating images
-        'stability-ai/stable-diffusion-3': [
-            "527d2a6296facb8e47ba1eaf17f142c240c19a30894f437feee9b91cc29d8e4f"
-        ],
-        'bytedance/sdxl-lightning-4step': [
-            "5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f"
-        ],
-        'playgroundai/playground-v2.5-1024px-aesthetic': [
-            "a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24"
-        ],
-        
-        # Model versions for text generation
-        'meta/meta-llama-3-70b-instruct': [
-            "dp-cf04fe09351e25db628e8b6181276547"
-        ],
-        'mistralai/mixtral-8x7b-instruct-v0.1': [
-            "dp-89e00f489d498885048e94f9809fbc76"
-        ],
-        'google-deepmind/gemma-2b-it': [
-            "dff94eaf770e1fc211e425a50b51baa8e4cac6c39ef074681f9e39d778773626"
-        ]
-    }
+    image_models = [
+        'black-forest-labs/flux-schnell',
+        'stability-ai/stable-diffusion-3',
+        'bytedance/sdxl-lightning-4step',
+        'playgroundai/playground-v2.5-1024px-aesthetic',
+    ]
 
+    models = text_models + image_models
+    
     model_aliases = {
+        "flux-schnell": "black-forest-labs/flux-schnell",
         "sd-3": "stability-ai/stable-diffusion-3",
         "sdxl": "bytedance/sdxl-lightning-4step",
         "playground-v2.5": "playgroundai/playground-v2.5-1024px-aesthetic",
         "llama-3-70b": "meta/meta-llama-3-70b-instruct",
         "mixtral-8x7b": "mistralai/mixtral-8x7b-instruct-v0.1",
         "gemma-2b": "google-deepmind/gemma-2b-it",
+        "llava-13b": "yorickvp/llava-13b",
+    }
+
+    model_versions = {
+        "meta/meta-llama-3-70b-instruct": "fbfb20b472b2f3bdd101412a9f70a0ed4fc0ced78a77ff00970ee7a2383c575d",
+        "mistralai/mixtral-8x7b-instruct-v0.1": "5d78bcd7a992c4b793465bcdcf551dc2ab9668d12bb7aa714557a21c1e77041c",
+        "google-deepmind/gemma-2b-it": "dff94eaf770e1fc211e425a50b51baa8e4cac6c39ef074681f9e39d778773626",
+        "yorickvp/llava-13b": "80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
+        'black-forest-labs/flux-schnell': "f2ab8a5bfe79f02f0789a146cf5e73d2a4ff2684a98c2b303d1e1ff3814271db",
+        'stability-ai/stable-diffusion-3': "527d2a6296facb8e47ba1eaf17f142c240c19a30894f437feee9b91cc29d8e4f",
+        'bytedance/sdxl-lightning-4step': "5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
+        'playgroundai/playground-v2.5-1024px-aesthetic': "a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24",
     }
 
     @classmethod
@@ -69,84 +71,73 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
         cls,
         model: str,
         messages: Messages,
-        **kwargs: Any
-    ) -> Generator[Union[str, ImageResponse], None, None]:
-        yield await cls.create_async(messages[-1]["content"], model, **kwargs)
-
-    @classmethod
-    async def create_async(
-        cls,
-        prompt: str,
-        model: str,
-        api_key: Optional[str] = None,
-        proxy: Optional[str] = None,
-        timeout: int = 180,
-        version: Optional[str] = None,
-        extra_data: Dict[str, Any] = {},
-        **kwargs: Any
-    ) -> Union[str, ImageResponse]:
-        model = cls.get_model(model)  # Use the get_model method to resolve model name
+        proxy: str = None,
+        **kwargs
+    ) -> AsyncResult:
+        model = cls.get_model(model)
+        
         headers = {
-            'Accept-Encoding': 'gzip, deflate, br',
-            'Accept-Language': 'en-US',
-            'Connection': 'keep-alive',
-            'Origin': cls.url,
-            'Referer': f'{cls.url}/',
-            'Sec-Fetch-Dest': 'empty',
-            'Sec-Fetch-Mode': 'cors',
-            'Sec-Fetch-Site': 'same-site',
-            'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36',
-            'sec-ch-ua': '"Google Chrome";v="119", "Chromium";v="119", "Not?A_Brand";v="24"',
-            'sec-ch-ua-mobile': '?0',
-            'sec-ch-ua-platform': '"macOS"',
+            "accept": "*/*",
+            "accept-language": "en-US,en;q=0.9",
+            "cache-control": "no-cache",
+            "content-type": "application/json",
+            "origin": "https://replicate.com",
+            "pragma": "no-cache",
+            "priority": "u=1, i",
+            "referer": "https://replicate.com/",
+            "sec-ch-ua": '"Not;A=Brand";v="24", "Chromium";v="128"',
+            "sec-ch-ua-mobile": "?0",
+            "sec-ch-ua-platform": '"Linux"',
+            "sec-fetch-dest": "empty",
+            "sec-fetch-mode": "cors",
+            "sec-fetch-site": "same-site",
+            "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36"
         }
-
-        if version is None:
-            version = random.choice(cls.versions.get(model, []))
-        if api_key is not None:
-            headers["Authorization"] = f"Bearer {api_key}"
-
-        async with StreamSession(
-            proxies={"all": proxy},
-            headers=headers,
-            timeout=timeout
-        ) as session:
+        
+        async with ClientSession(headers=headers) as session:
+            if model in cls.image_models:
+                prompt = messages[-1]['content'] if messages else ""
+            else:
+                prompt = format_prompt(messages)
+            
             data = {
-                "input": {
-                    "prompt": prompt,
-                    **extra_data
-                },
-                "version": version
+                "model": model,
+                "version": cls.model_versions[model],
+                "input": {"prompt": prompt},
             }
-            if api_key is None:
-                data["model"] = model
-                url = "https://homepage.replicate.com/api/prediction"
-            else:
-                url = "https://api.replicate.com/v1/predictions"
-            async with session.post(url, json=data) as response:
-                await raise_for_status(response)
+            
+            async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
+                response.raise_for_status()
                 result = await response.json()
-            if "id" not in result:
-                raise ResponseError(f"Invalid response: {result}")
+                prediction_id = result['id']
+            
+            poll_url = f"https://homepage.replicate.com/api/poll?id={prediction_id}"
+            max_attempts = 30
+            delay = 5
+            for _ in range(max_attempts):
+                async with session.get(poll_url, proxy=proxy) as response:
+                    response.raise_for_status()
+                    try:
+                        result = await response.json()
+                    except ContentTypeError:
+                        text = await response.text()
+                        try:
+                            result = json.loads(text)
+                        except json.JSONDecodeError:
+                            raise ValueError(f"Unexpected response format: {text}")
 
-            while True:
-                if api_key is None:
-                    url = f"https://homepage.replicate.com/api/poll?id={result['id']}"
-                else:
-                    url = f"https://api.replicate.com/v1/predictions/{result['id']}"
-                async with session.get(url) as response:
-                    await raise_for_status(response)
-                    result = await response.json()
-                    if "status" not in result:
-                        raise ResponseError(f"Invalid response: {result}")
-                    if result["status"] == "succeeded":
-                        output = result['output']
-                        if model in cls.text_models:
-                            return ''.join(output) if isinstance(output, list) else output
-                        elif model in cls.image_models:
-                            images: List[Any] = output
-                            images = images[0] if len(images) == 1 else images
-                            return ImageResponse(images, prompt)
-                    elif result["status"] == "failed":
-                        raise ResponseError(f"Prediction failed: {result}")
-                    await asyncio.sleep(0.5)
+                    if result['status'] == 'succeeded':
+                        if model in cls.image_models:
+                            image_url = result['output'][0]
+                            yield ImageResponse(image_url, "Generated image")
+                            return
+                        else:
+                            for chunk in result['output']:
+                                yield chunk
+                        break
+                    elif result['status'] == 'failed':
+                        raise Exception(f"Prediction failed: {result.get('error')}")
+                await asyncio.sleep(delay)
+            
+            if result['status'] != 'succeeded':
+                raise Exception("Prediction timed out")

+ 19 - 0
g4f/models.py

@@ -489,6 +489,13 @@ sh_n_7b = Model(
     best_provider = Airforce
 )
 
+### Yorickvp ###
+llava_13b = Model(
+    name = 'llava-13b',
+    base_provider = 'Yorickvp',
+    best_provider = ReplicateHome
+)
+
 #############
 ### Image ###
 #############
@@ -559,6 +566,13 @@ flux_pixel = Model(
     
 )
 
+flux_schnell = Model(
+    name = 'flux-schnell',
+    base_provider = 'Flux AI',
+    best_provider = IterListProvider([ReplicateHome])
+    
+)
+
 ### ###
 dalle = Model(
     name = 'dalle',
@@ -746,6 +760,10 @@ class ModelUtils:
 
 ### Together ###
 'sh-n-7b': sh_n_7b,
+      
+        
+### Yorickvp ###
+'llava-13b': llava_13b,
         
         
         
@@ -769,6 +787,7 @@ class ModelUtils:
 'flux-3d': flux_3d,
 'flux-disney': flux_disney,
 'flux-pixel': flux_pixel,
+'flux-schnell': flux_schnell,
 
 
 ###  ###