|
|
|
@ -5,6 +5,7 @@ import uuid
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import base64
|
|
|
|
|
import time
|
|
|
|
|
from aiohttp import ClientWebSocketResponse
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
@ -47,7 +48,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
|
|
|
|
_api_key: str = None
|
|
|
|
|
_headers: dict = None
|
|
|
|
|
_cookies: Cookies = None
|
|
|
|
|
_last_message: int = 0
|
|
|
|
|
_expires: int = None
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
async def create(
|
|
|
|
@ -348,7 +349,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
|
|
|
|
timeout=timeout
|
|
|
|
|
) as session:
|
|
|
|
|
# Read api_key and cookies from cache / browser config
|
|
|
|
|
if cls._headers is None:
|
|
|
|
|
if cls._headers is None or time.time() > cls._expires:
|
|
|
|
|
if api_key is None:
|
|
|
|
|
# Read api_key from cookies
|
|
|
|
|
cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies
|
|
|
|
@ -437,17 +438,20 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
|
|
|
|
await cls.delete_conversation(session, cls._headers, fields.conversation_id)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
async def iter_messages_ws(ws: ClientWebSocketResponse) -> AsyncIterator:
|
|
|
|
|
async def iter_messages_ws(ws: ClientWebSocketResponse, conversation_id: str) -> AsyncIterator:
|
|
|
|
|
while True:
|
|
|
|
|
yield base64.b64decode((await ws.receive_json())["body"])
|
|
|
|
|
message = await ws.receive_json()
|
|
|
|
|
if message["conversation_id"] == conversation_id:
|
|
|
|
|
yield base64.b64decode(message["body"])
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
async def iter_messages_chunk(cls, messages: AsyncIterator, session: StreamSession, fields: ResponseFields) -> AsyncIterator:
|
|
|
|
|
last_message: int = 0
|
|
|
|
|
async for message in messages:
|
|
|
|
|
if message.startswith(b'{"wss_url":'):
|
|
|
|
|
async with session.ws_connect(json.loads(message)["wss_url"]) as ws:
|
|
|
|
|
async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws), session, fields):
|
|
|
|
|
message = json.loads(message)
|
|
|
|
|
async with session.ws_connect(message["wss_url"]) as ws:
|
|
|
|
|
async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws, message["conversation_id"]), session, fields):
|
|
|
|
|
yield chunk
|
|
|
|
|
break
|
|
|
|
|
async for chunk in cls.iter_messages_line(session, message, fields):
|
|
|
|
@ -589,6 +593,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
|
|
|
|
@classmethod
|
|
|
|
|
def _set_api_key(cls, api_key: str):
|
|
|
|
|
cls._api_key = api_key
|
|
|
|
|
cls._expires = int(time.time()) + 60 * 60 * 4
|
|
|
|
|
cls._headers["Authorization"] = f"Bearer {api_key}"
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|