You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
gpt4free/g4f/Provider/retry_provider.py

92 lines
2.9 KiB
Python

12 months ago
from __future__ import annotations
11 months ago
import asyncio
12 months ago
import random
11 months ago
from typing import List, Type, Dict
11 months ago
from ..typing import CreateResult, Messages
12 months ago
from .base_provider import BaseProvider, AsyncProvider
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError
12 months ago
class RetryProvider(AsyncProvider):
    """Provider that delegates to a list of providers, trying each in turn.

    Every attempt that raises is recorded in ``self.exceptions`` (keyed by the
    provider's ``__name__``) and the next provider is tried.  When no provider
    produces a result, the collected failures are reported through
    ``raise_exceptions``.
    """

    __name__: str = "RetryProvider"
    supports_stream: bool = True

    def __init__(
        self,
        providers: List[Type[BaseProvider]],
        shuffle: bool = True
    ) -> None:
        """Store the candidate providers and the ordering policy.

        Args:
            providers: Providers to try.
            shuffle: When true, the providers are tried in a random order
                that is re-drawn on every call; otherwise in the given order.
        """
        self.providers: List[Type[BaseProvider]] = providers
        self.shuffle: bool = shuffle
        self.working = True
        # Initialised here so raise_exceptions() is usable on a fresh
        # instance; both create methods reset it at the start of each call.
        self.exceptions: Dict[str, Exception] = {}

    def create_completion(
        self,
        model: str,
        messages: Messages,
        stream: bool = False,
        **kwargs
    ) -> CreateResult:
        """Yield the tokens of the first provider that produces any output.

        Args:
            model: Forwarded unchanged to each provider.
            messages: Forwarded unchanged to each provider.
            stream: Forwarded to each provider.  When true, only providers
                whose ``supports_stream`` is truthy are considered.
            **kwargs: Forwarded unchanged to each provider.

        Yields:
            Tokens produced by the selected provider.

        Raises:
            Exception: Whatever the active provider raised, if it failed
                after at least one token had already been yielded.
            RetryProviderError: If at least one provider raised and none
                produced output.
            RetryNoProviderError: If no provider raised and none produced
                output (including an empty candidate list).
        """
        if stream:
            providers = [provider for provider in self.providers if provider.supports_stream]
        else:
            # Work on a copy: shuffling self.providers in place would
            # permanently reorder the list object the caller handed in.
            providers = list(self.providers)
        if self.shuffle:
            random.shuffle(providers)

        self.exceptions = {}
        started: bool = False

        for provider in providers:
            try:
                if debug.logging:
                    print(f"Using {provider.__name__} provider")

                for token in provider.create_completion(model, messages, stream, **kwargs):
                    yield token
                    started = True

                # A provider that yielded nothing is treated as "no answer"
                # and the next provider is tried.
                if started:
                    return

            except Exception as e:
                self.exceptions[provider.__name__] = e
                if debug.logging:
                    print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
                # Output has already reached the consumer, so switching to
                # another provider is no longer possible; propagate instead.
                if started:
                    raise

        self.raise_exceptions()

    async def create_async(
        self,
        model: str,
        messages: Messages,
        **kwargs
    ) -> str:
        """Return the result of the first provider whose call succeeds.

        Args:
            model: Forwarded unchanged to each provider.
            messages: Forwarded unchanged to each provider.
            **kwargs: Forwarded unchanged to each provider.  The ``timeout``
                entry (default 60) is additionally used as the
                ``asyncio.wait_for`` limit for every attempt.

        Returns:
            The value returned by the first provider that does not raise.

        Raises:
            RetryProviderError: If every attempted provider raised (a
                timeout counts as a failure of that provider).
            RetryNoProviderError: If there were no providers to try.
        """
        # Copy before shuffling so the caller's list is never reordered.
        providers = list(self.providers)
        if self.shuffle:
            random.shuffle(providers)

        self.exceptions = {}
        timeout = kwargs.get("timeout", 60)

        for provider in providers:
            try:
                return await asyncio.wait_for(
                    provider.create_async(model, messages, **kwargs),
                    timeout=timeout
                )
            except Exception as e:
                self.exceptions[provider.__name__] = e
                if debug.logging:
                    print(f"{provider.__name__}: {e.__class__.__name__}: {e}")

        self.raise_exceptions()

    def raise_exceptions(self) -> None:
        """Raise an error describing why no provider produced a result.

        This method never returns normally.

        Raises:
            RetryProviderError: If ``self.exceptions`` is non-empty; the
                message lists one line per failed provider.
            RetryNoProviderError: If no failures were recorded.
        """
        if self.exceptions:
            raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
                f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items()
            ]))

        raise RetryNoProviderError("No provider found")