gpt4free/g4f/Provider/base_provider.py

138 lines
3.3 KiB
Python
Raw Normal View History

from __future__ import annotations
2023-07-28 10:07:17 +00:00
from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from abc import ABC, abstractmethod
2023-07-28 10:07:17 +00:00
from .helper import get_event_loop, get_cookies, format_prompt
from ..typing import AsyncGenerator, CreateResult
2023-07-28 10:07:17 +00:00
class _ClassProperty:
    """Read-only descriptor whose value is computed from the owning class.

    Replaces the ``@classmethod`` + ``@property`` stack, which was deprecated
    in Python 3.11 and no longer works from Python 3.13 onwards.
    """

    def __init__(self, getter):
        self._getter = getter

    def __get__(self, instance, owner=None):
        # Access through an instance still resolves against the instance's
        # class, matching what the classmethod/property stack used to do.
        if owner is None:
            owner = type(instance)
        return self._getter(owner)


class BaseProvider(ABC):
    """Abstract base class for providers.

    Subclasses implement the synchronous ``create_completion``; the base
    class derives an awaitable ``create_async`` from it by running the
    synchronous call in an executor.
    """

    url: str

    # Status / capability flags. All default to False here; presumably
    # overridden by concrete providers (not shown in this module).
    working: bool = False
    needs_auth: bool = False
    supports_stream: bool = False
    supports_gpt_35_turbo: bool = False
    supports_gpt_4: bool = False

    @staticmethod
    @abstractmethod
    def create_completion(
        model: str,
        messages: list[dict[str, str]],
        stream: bool,
        **kwargs
    ) -> CreateResult:
        """Produce the completion for ``messages`` using ``model``.

        Must be overridden. ``create_async`` joins the returned items with
        ``"".join``, so the result has to be an iterable of strings.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    @classmethod
    async def create_async(
        cls,
        model: str,
        messages: list[dict[str, str]],
        *,
        loop: AbstractEventLoop | None = None,
        executor: ThreadPoolExecutor | None = None,
        **kwargs
    ) -> str:
        """Run ``create_completion`` in an executor and return the joined text.

        Args:
            model: Forwarded to ``create_completion``.
            messages: Forwarded to ``create_completion``.
            loop: Event loop whose ``run_in_executor`` is used; when omitted,
                the loop returned by ``get_event_loop()`` is used.
            executor: Passed straight to ``run_in_executor`` (``None`` lets
                the loop pick its default executor).
            **kwargs: Forwarded to ``create_completion``.

        Returns:
            All chunks yielded by ``create_completion`` concatenated.
        """
        if loop is None:
            loop = get_event_loop()

        def create_func() -> str:
            # Streaming is disabled: the chunks are joined into one string.
            return "".join(cls.create_completion(
                model,
                messages,
                False,
                **kwargs
            ))

        return await loop.run_in_executor(
            executor,
            create_func
        )

    @_ClassProperty
    def params(cls) -> str:
        """Human-readable summary of the parameters this provider accepts."""
        params = [
            ("model", "str"),
            ("messages", "list[dict[str, str]]"),
            ("stream", "bool"),
        ]
        param = ", ".join([": ".join(p) for p in params])
        return f"g4f.provider.{cls.__name__} supports: ({param})"
class AsyncProvider(BaseProvider):
    """Provider base whose primary entry point is the coroutine ``create_async``.

    The synchronous ``create_completion`` is derived from it by driving the
    coroutine to completion on an event loop.
    """

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: list[dict[str, str]],
        stream: bool = False,
        **kwargs
    ) -> CreateResult:
        """Yield the full result of ``create_async`` as a single item.

        ``stream`` is accepted for signature compatibility only; it is not
        forwarded to ``create_async``.
        """
        event_loop = get_event_loop()
        pending = cls.create_async(model, messages, **kwargs)
        # Blocks until the coroutine finishes, then hands back its result.
        yield event_loop.run_until_complete(pending)

    @staticmethod
    @abstractmethod
    async def create_async(
        model: str,
        messages: list[dict[str, str]],
        **kwargs
    ) -> str:
        """Return the complete response text; must be overridden.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()
class AsyncGeneratorProvider(AsyncProvider):
    """Provider base built around ``create_async_generator``.

    Both the synchronous ``create_completion`` and the awaitable
    ``create_async`` are derived from the async generator that subclasses
    supply.
    """

    supports_stream = True

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: list[dict[str, str]],
        stream: bool = True,
        **kwargs
    ) -> CreateResult:
        """Yield the chunks of ``create_async_generator`` one at a time.

        Each chunk is fetched by running a single ``__anext__`` step to
        completion on the event loop.
        """
        event_loop = get_event_loop()
        source = cls.create_async_generator(
            model,
            messages,
            stream=stream,
            **kwargs
        )
        iterator = source.__aiter__()

        exhausted = False
        while not exhausted:
            try:
                yield event_loop.run_until_complete(iterator.__anext__())
            except StopAsyncIteration:
                # The async generator has no more chunks.
                exhausted = True

    @classmethod
    async def create_async(
        cls,
        model: str,
        messages: list[dict[str, str]],
        **kwargs
    ) -> str:
        """Collect every chunk (with ``stream=False``) and return them joined."""
        pieces = []
        async for piece in cls.create_async_generator(
            model,
            messages,
            stream=False,
            **kwargs
        ):
            pieces.append(piece)
        return "".join(pieces)

    @staticmethod
    @abstractmethod
    def create_async_generator(
        model: str,
        messages: list[dict[str, str]],
        **kwargs
    ) -> AsyncGenerator:
        """Return an async iterator over response chunks; must be overridden.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()