You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
gpt4free/g4f/Provider/base_provider.py

170 lines
4.5 KiB
Python

from __future__ import annotations
10 months ago
import sys
from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
10 months ago
from abc import ABC, abstractmethod
from inspect import signature, Parameter
from .helper import get_event_loop, get_cookies, format_prompt
from ..typing import CreateResult, AsyncResult, Messages
10 months ago
# ``types.NoneType`` only exists on Python 3.10+; older interpreters derive
# the same class object from ``None`` itself.
if sys.version_info >= (3, 10):
    from types import NoneType
else:
    NoneType = type(None)
class _ClassProperty:
    """Read-only descriptor that exposes a getter as a class-level property.

    Replaces the ``@classmethod`` + ``@property`` stack, which was deprecated
    in Python 3.11 and no longer works from Python 3.13 on. Attribute access
    on the class (or on an instance) calls the getter with the class.
    """

    def __init__(self, fget):
        self.fget = fget
        self.__doc__ = getattr(fget, "__doc__", None)

    def __get__(self, instance, owner=None):
        # ``owner`` is omitted only when the descriptor is invoked manually
        # with an instance; fall back to that instance's class.
        if owner is None:
            owner = type(instance)
        return self.fget(owner)


class BaseProvider(ABC):
    """Abstract base class for providers.

    Subclasses must implement ``create_completion``. ``create_async`` wraps
    that synchronous call in an executor, and ``params`` renders a
    human-readable summary of the arguments the provider accepts.
    """

    url: str
    # Capability flags, all off by default; subclasses override them.
    # Within this class only ``supports_stream`` is read (by ``params``).
    working: bool = False
    needs_auth: bool = False
    supports_stream: bool = False
    supports_gpt_35_turbo: bool = False
    supports_gpt_4: bool = False
    supports_message_history: bool = False

    @staticmethod
    @abstractmethod
    def create_completion(
        model: str,
        messages: Messages,
        stream: bool,
        **kwargs
    ) -> CreateResult:
        """Produce the completion for ``messages``.

        Args:
            model: Model identifier.
            messages: Conversation messages to complete.
            stream: Streaming flag (``create_async`` always passes ``False``).
            **kwargs: Provider-specific options.

        Returns:
            An iterable of string chunks (``create_async`` joins them).

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    @classmethod
    async def create_async(
        cls,
        model: str,
        messages: Messages,
        *,
        loop: AbstractEventLoop = None,
        executor: ThreadPoolExecutor = None,
        **kwargs
    ) -> str:
        """Run ``create_completion`` in an executor and return the joined text.

        Args:
            model: Model identifier.
            messages: Conversation messages to complete.
            loop: Event loop used to schedule the call; obtained from
                ``get_event_loop()`` when not given.
            executor: Executor handed to ``loop.run_in_executor``.
            **kwargs: Forwarded unchanged to ``create_completion``.

        Returns:
            All chunks yielded by ``create_completion`` concatenated.
        """
        if not loop:
            loop = get_event_loop()

        def create_func() -> str:
            # Streaming is disabled: the chunks are joined into one string.
            return "".join(cls.create_completion(
                model,
                messages,
                False,
                **kwargs
            ))

        return await loop.run_in_executor(
            executor,
            create_func
        )

    @_ClassProperty
    def params(cls) -> str:
        """Describe the parameters supported by this provider.

        The signature inspected is ``create_async_generator`` for
        ``AsyncGeneratorProvider`` subclasses, ``create_async`` for
        ``AsyncProvider`` subclasses and ``create_completion`` otherwise.

        Returns:
            A multi-line string listing each parameter with its annotation
            and default value, where present.
        """
        if issubclass(cls, AsyncGeneratorProvider):
            sig = signature(cls.create_async_generator)
        elif issubclass(cls, AsyncProvider):
            sig = signature(cls.create_async)
        else:
            sig = signature(cls.create_completion)

        def get_type_name(annotation: type) -> str:
            # Prefer the short class name; anything else (including string
            # annotations) falls through to ``str()``.
            if hasattr(annotation, "__name__"):
                annotation = annotation.__name__
            elif isinstance(annotation, NoneType):
                annotation = "None"
            return str(annotation)

        args = ""
        for name, param in sig.parameters.items():
            if name in ("self", "kwargs"):
                continue
            # Hide ``stream`` for providers that cannot stream.
            if name == "stream" and not cls.supports_stream:
                continue
            if args:
                args += ", "
            args += "\n"
            args += " " + name
            # The annotation of ``model`` is never shown.
            if name != "model" and param.annotation is not Parameter.empty:
                args += f": {get_type_name(param.annotation)}"
            # An empty-string default would otherwise render as nothing.
            if param.default == "":
                args += ' = ""'
            elif param.default is not Parameter.empty:
                args += f" = {param.default}"

        return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
class AsyncProvider(BaseProvider):
    """Provider base whose primary entry point is the ``create_async`` coroutine."""

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: Messages,
        stream: bool = False,
        **kwargs
    ) -> CreateResult:
        """Synchronous adapter around ``create_async``.

        Drives the coroutine to completion on the loop returned by
        ``get_event_loop()`` and yields its result as a single item.
        ``stream`` is accepted but not forwarded.
        """
        event_loop = get_event_loop()
        pending = cls.create_async(model, messages, **kwargs)
        result = event_loop.run_until_complete(pending)
        yield result

    @staticmethod
    @abstractmethod
    async def create_async(
        model: str,
        messages: Messages,
        **kwargs
    ) -> str:
        """Return the full completion text; subclasses must override this.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError()
class AsyncGeneratorProvider(AsyncProvider):
    """Provider base whose primary entry point is ``create_async_generator``."""

    supports_stream = True

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: Messages,
        stream: bool = True,
        **kwargs
    ) -> CreateResult:
        """Synchronous adapter around ``create_async_generator``.

        Pulls items from the async iterator one at a time on the loop
        returned by ``get_event_loop()`` and yields each as it arrives.
        """
        event_loop = get_event_loop()
        chunk_iter = cls.create_async_generator(
            model,
            messages,
            stream=stream,
            **kwargs
        ).__aiter__()

        try:
            while True:
                yield event_loop.run_until_complete(chunk_iter.__anext__())
        except StopAsyncIteration:
            # The async iterator is exhausted; end this generator as well.
            return

    @classmethod
    async def create_async(
        cls,
        model: str,
        messages: Messages,
        **kwargs
    ) -> str:
        """Collect every chunk from ``create_async_generator`` into one string.

        The generator is called with ``stream=False``.
        """
        pieces = []
        async for piece in cls.create_async_generator(
            model,
            messages,
            stream=False,
            **kwargs
        ):
            pieces.append(piece)
        return "".join(pieces)

    @staticmethod
    @abstractmethod
    def create_async_generator(
        model: str,
        messages: Messages,
        **kwargs
    ) -> AsyncResult:
        """Return the async iterator of chunks; subclasses must override this.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError()