Fix g4f/Provider/ReplicateHome.py

This commit is contained in:
kqlio67 2024-09-01 21:05:58 +03:00
parent aa265acf30
commit 33cc1cb1e0

View File

@ -16,7 +16,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
working = True
default_model = 'stability-ai/stable-diffusion-3'
models = [
# Models for image generation
# Models for image generation
'stability-ai/stable-diffusion-3',
'bytedance/sdxl-lightning-4step',
'playgroundai/playground-v2.5-1024px-aesthetic',
@ -28,7 +28,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
]
versions = {
# Model versions for generating images
# Model versions for generating images
'stability-ai/stable-diffusion-3': [
"527d2a6296facb8e47ba1eaf17f142c240c19a30894f437feee9b91cc29d8e4f"
],
@ -39,7 +39,6 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
"a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24"
],
# Model versions for text generation
'meta/meta-llama-3-70b-instruct': [
"dp-cf04fe09351e25db628e8b6181276547"
@ -55,6 +54,24 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
image_models = {"stability-ai/stable-diffusion-3", "bytedance/sdxl-lightning-4step", "playgroundai/playground-v2.5-1024px-aesthetic"}
text_models = {"meta/meta-llama-3-70b-instruct", "mistralai/mixtral-8x7b-instruct-v0.1", "google-deepmind/gemma-2b-it"}
model_aliases = {
"stable-diffusion-3": "stability-ai/stable-diffusion-3",
"sdxl-lightning-4step": "bytedance/sdxl-lightning-4step",
"playground-v2.5-aesthetic": "playgroundai/playground-v2.5-1024px-aesthetic",
"llama-3-70b": "meta/meta-llama-3-70b-instruct",
"mixtral-8x7b": "mistralai/mixtral-8x7b-instruct-v0.1",
"gemma-2b": "google-deepmind/gemma-2b-it",
}
@classmethod
def get_model(cls, model: str) -> str:
if model in cls.models:
return model
elif model in cls.model_aliases:
return cls.model_aliases[model]
else:
return cls.default_model
@classmethod
async def create_async_generator(
cls,
@ -76,6 +93,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
extra_data: Dict[str, Any] = {},
**kwargs: Any
) -> Union[str, ImageResponse]:
model = cls.get_model(model) # Use the get_model method to resolve model name
headers = {
'Accept-Encoding': 'gzip, deflate, br',
'Accept-Language': 'en-US',
@ -109,7 +127,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
"version": version
}
if api_key is None:
data["model"] = cls.get_model(model)
data["model"] = model
url = "https://homepage.replicate.com/api/prediction"
else:
url = "https://api.replicate.com/v1/predictions"