mirror of
https://github.com/xtekky/gpt4free.git
synced 2024-11-12 19:10:32 +00:00
74 lines
2.6 KiB
Python
74 lines
2.6 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
|
|
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
|
|
from ...typing import AsyncResult, Messages
|
|
from ...requests.raise_for_status import raise_for_status
|
|
from ...requests import StreamSession
|
|
from ...errors import MissingAuthError
|
|
|
|
class Openai(AsyncGeneratorProvider, ProviderModelMixin):
|
|
url = "https://openai.com"
|
|
working = True
|
|
needs_auth = True
|
|
supports_message_history = True
|
|
supports_system_message = True
|
|
|
|
@classmethod
|
|
async def create_async_generator(
|
|
cls,
|
|
model: str,
|
|
messages: Messages,
|
|
proxy: str = None,
|
|
timeout: int = 120,
|
|
api_key: str = None,
|
|
api_base: str = "https://api.openai.com/v1",
|
|
temperature: float = None,
|
|
max_tokens: int = None,
|
|
top_p: float = None,
|
|
stop: str = None,
|
|
stream: bool = False,
|
|
**kwargs
|
|
) -> AsyncResult:
|
|
if api_key is None:
|
|
raise MissingAuthError('Add a "api_key"')
|
|
async with StreamSession(
|
|
proxies={"all": proxy},
|
|
headers=cls.get_headers(api_key),
|
|
timeout=timeout
|
|
) as session:
|
|
data = {
|
|
"messages": messages,
|
|
"model": cls.get_model(model),
|
|
"temperature": temperature,
|
|
"max_tokens": max_tokens,
|
|
"top_p": top_p,
|
|
"stop": stop,
|
|
"stream": stream,
|
|
}
|
|
async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
|
|
await raise_for_status(response)
|
|
async for line in response.iter_lines():
|
|
if line.startswith(b"data: ") or not stream:
|
|
async for chunk in cls.read_line(line[6:] if stream else line, stream):
|
|
yield chunk
|
|
|
|
@staticmethod
|
|
async def read_line(line: str, stream: bool):
|
|
if line == b"[DONE]":
|
|
return
|
|
choice = json.loads(line)["choices"][0]
|
|
if stream and "content" in choice["delta"] and choice["delta"]["content"]:
|
|
yield choice["delta"]["content"]
|
|
elif not stream and "content" in choice["message"]:
|
|
yield choice["message"]["content"]
|
|
if "finish_reason" in choice and choice["finish_reason"] is not None:
|
|
yield FinishReason(choice["finish_reason"])
|
|
|
|
@staticmethod
|
|
def get_headers(api_key: str) -> dict:
|
|
return {
|
|
"Authorization": f"Bearer {api_key}",
|
|
"Content-Type": "application/json",
|
|
} |