Add RetryProvider

pull/924/head
Heiner Lohaus 1 year ago
parent 951a1332a7
commit e9f96ced9c

@ -38,10 +38,14 @@ from .FastGpt import FastGpt
from .V50 import V50
from .Wuguokai import Wuguokai
from .base_provider import BaseProvider, AsyncProvider, AsyncGeneratorProvider
from .base_provider import BaseProvider, AsyncProvider, AsyncGeneratorProvider
from .retry_provider import RetryProvider
__all__ = [
'BaseProvider',
'AsyncProvider',
'AsyncGeneratorProvider',
'RetryProvider',
'Acytoo',
'Aichat',
'Ails',

@ -0,0 +1,81 @@
from __future__ import annotations
import random
from ..typing import CreateResult
from .base_provider import BaseProvider, AsyncProvider
class RetryProvider(AsyncProvider):
__name__ = "RetryProvider"
working = True
needs_auth = False
supports_stream = True
supports_gpt_35_turbo = False
supports_gpt_4 = False
def __init__(
self,
providers: list[type[BaseProvider]],
shuffle: bool = True
) -> None:
self.providers = providers
self.shuffle = shuffle
def create_completion(
self,
model: str,
messages: list[dict[str, str]],
stream: bool = False,
**kwargs
) -> CreateResult:
if stream:
providers = [provider for provider in self.providers if provider.supports_stream]
else:
providers = self.providers
if self.shuffle:
random.shuffle(providers)
self.exceptions = {}
started = False
for provider in providers:
try:
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
started = True
if started:
return
except Exception as e:
self.exceptions[provider.__name__] = e
if started:
break
self.raise_exceptions()
async def create_async(
self,
model: str,
messages: list[dict[str, str]],
**kwargs
) -> str:
providers = [provider for provider in self.providers if issubclass(provider, AsyncProvider)]
if self.shuffle:
random.shuffle(providers)
self.exceptions = {}
for provider in providers:
try:
return await provider.create_async(model, messages, **kwargs)
except Exception as e:
self.exceptions[provider.__name__] = e
self.raise_exceptions()
def raise_exceptions(self):
if self.exceptions:
raise RuntimeError("\n".join(["All providers failed:"] + [
f"{p}: {self.exceptions[p].__class__.__name__}: {self.exceptions[p]}" for p in self.exceptions
]))
raise RuntimeError("No provider found")

@ -14,13 +14,7 @@ def get_model_and_provider(model: Union[models.Model, str], provider: type[BaseP
raise Exception(f'The model: {model} does not exist')
if not provider:
if isinstance(model.best_provider, list):
if stream:
provider = random.choice([p for p in model.best_provider if p.supports_stream])
else:
provider = random.choice(model.best_provider)
else:
provider = model.best_provider
provider = model.best_provider
if not provider:
raise Exception(f'No provider found for model: {model}')
@ -70,7 +64,7 @@ class ChatCompletion:
model, provider = get_model_and_provider(model, provider, False)
if not issubclass(provider, AsyncProvider):
if not issubclass(type(provider), AsyncProvider):
raise Exception(f"Provider: {provider.__name__} doesn't support create_async")
return await provider.create_async(model.name, messages, **kwargs)

@ -1,17 +1,23 @@
from __future__ import annotations
from dataclasses import dataclass
from .typing import Union
from .Provider import BaseProvider
from .Provider import BaseProvider, RetryProvider
from .Provider import (
ChatgptLogin,
CodeLinkAva,
ChatgptAi,
ChatBase,
Vercel,
DeepAi,
Aivvm,
Bard,
H2o
H2o,
GptGo,
Bing,
PerplexityAi,
Wewordle,
Yqcloud,
AItianhu,
Aichat,
)
@dataclass(unsafe_hash=True)
@ -24,15 +30,24 @@ class Model:
# Works for Liaobots, H2o, OpenaiChat, Yqcloud, You
default = Model(
name = "",
base_provider = "huggingface")
base_provider = "",
best_provider = RetryProvider([
Bing, # Not fully GPT 3 or 4
PerplexityAi, # Adds references to sources
Wewordle, # Responds with markdown
Yqcloud, # Answers short questions in chinese
ChatBase, # Don't want to answer creatively
DeepAi, ChatgptLogin, ChatgptAi, Aivvm, GptGo, AItianhu, Aichat,
])
)
# GPT-3.5 / GPT-4
gpt_35_turbo = Model(
name = 'gpt-3.5-turbo',
base_provider = 'openai',
best_provider = [
DeepAi, CodeLinkAva, ChatgptLogin, ChatgptAi, ChatBase, Aivvm
]
best_provider = RetryProvider([
DeepAi, ChatgptLogin, ChatgptAi, Aivvm, GptGo, AItianhu, Aichat,
])
)
gpt_4 = Model(

Loading…
Cancel
Save