|
|
@ -10,22 +10,15 @@ from selenium.webdriver.common.by import By
|
|
|
|
from selenium.webdriver.support.ui import WebDriverWait
|
|
|
|
from selenium.webdriver.support.ui import WebDriverWait
|
|
|
|
from selenium.webdriver.support import expected_conditions as EC
|
|
|
|
from selenium.webdriver.support import expected_conditions as EC
|
|
|
|
|
|
|
|
|
|
|
|
from ..base_provider import AsyncGeneratorProvider
|
|
|
|
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
|
|
|
from ..helper import format_prompt, get_cookies
|
|
|
|
from ..helper import format_prompt, get_cookies
|
|
|
|
from ...webdriver import get_browser, get_driver_cookies
|
|
|
|
from ...webdriver import get_browser, get_driver_cookies
|
|
|
|
from ...typing import AsyncResult, Messages
|
|
|
|
from ...typing import AsyncResult, Messages
|
|
|
|
from ...requests import StreamSession
|
|
|
|
from ...requests import StreamSession
|
|
|
|
from ...image import to_image, to_bytes, ImageType, ImageResponse
|
|
|
|
from ...image import to_image, to_bytes, ImageType, ImageResponse
|
|
|
|
|
|
|
|
|
|
|
|
# Aliases for model names
|
|
|
|
|
|
|
|
MODELS = {
|
|
|
|
|
|
|
|
"gpt-3.5": "text-davinci-002-render-sha",
|
|
|
|
|
|
|
|
"gpt-3.5-turbo": "text-davinci-002-render-sha",
|
|
|
|
|
|
|
|
"gpt-4": "gpt-4",
|
|
|
|
|
|
|
|
"gpt-4-gizmo": "gpt-4-gizmo"
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpenaiChat(AsyncGeneratorProvider):
|
|
|
|
class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
|
|
|
"""A class for creating and managing conversations with OpenAI chat service"""
|
|
|
|
"""A class for creating and managing conversations with OpenAI chat service"""
|
|
|
|
|
|
|
|
|
|
|
|
url = "https://chat.openai.com"
|
|
|
|
url = "https://chat.openai.com"
|
|
|
@ -33,6 +26,11 @@ class OpenaiChat(AsyncGeneratorProvider):
|
|
|
|
needs_auth = True
|
|
|
|
needs_auth = True
|
|
|
|
supports_gpt_35_turbo = True
|
|
|
|
supports_gpt_35_turbo = True
|
|
|
|
supports_gpt_4 = True
|
|
|
|
supports_gpt_4 = True
|
|
|
|
|
|
|
|
default_model = None
|
|
|
|
|
|
|
|
models = ["text-davinci-002-render-sha", "gpt-4", "gpt-4-gizmo"]
|
|
|
|
|
|
|
|
model_aliases = {
|
|
|
|
|
|
|
|
"gpt-3.5-turbo": "text-davinci-002-render-sha",
|
|
|
|
|
|
|
|
}
|
|
|
|
_cookies: dict = {}
|
|
|
|
_cookies: dict = {}
|
|
|
|
_default_model: str = None
|
|
|
|
_default_model: str = None
|
|
|
|
|
|
|
|
|
|
|
@ -91,7 +89,7 @@ class OpenaiChat(AsyncGeneratorProvider):
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def _upload_image(
|
|
|
|
async def upload_image(
|
|
|
|
cls,
|
|
|
|
cls,
|
|
|
|
session: StreamSession,
|
|
|
|
session: StreamSession,
|
|
|
|
headers: dict,
|
|
|
|
headers: dict,
|
|
|
@ -150,7 +148,7 @@ class OpenaiChat(AsyncGeneratorProvider):
|
|
|
|
return ImageResponse(download_url, image_data["file_name"], image_data)
|
|
|
|
return ImageResponse(download_url, image_data["file_name"], image_data)
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def _get_default_model(cls, session: StreamSession, headers: dict):
|
|
|
|
async def get_default_model(cls, session: StreamSession, headers: dict):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Get the default model name from the service
|
|
|
|
Get the default model name from the service
|
|
|
|
|
|
|
|
|
|
|
@ -161,20 +159,17 @@ class OpenaiChat(AsyncGeneratorProvider):
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
The default model name as a string
|
|
|
|
The default model name as a string
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
# Check the cache for the default model
|
|
|
|
if not cls.default_model:
|
|
|
|
if cls._default_model:
|
|
|
|
|
|
|
|
return cls._default_model
|
|
|
|
|
|
|
|
# Get the models data from the service
|
|
|
|
|
|
|
|
async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
|
|
|
|
async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
|
|
|
|
data = await response.json()
|
|
|
|
data = await response.json()
|
|
|
|
if "categories" in data:
|
|
|
|
if "categories" in data:
|
|
|
|
cls._default_model = data["categories"][-1]["default_model"]
|
|
|
|
cls.default_model = data["categories"][-1]["default_model"]
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
raise RuntimeError(f"Response: {data}")
|
|
|
|
raise RuntimeError(f"Response: {data}")
|
|
|
|
return cls._default_model
|
|
|
|
return cls.default_model
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def _create_messages(cls, prompt: str, image_response: ImageResponse = None):
|
|
|
|
def create_messages(cls, prompt: str, image_response: ImageResponse = None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Create a list of messages for the user input
|
|
|
|
Create a list of messages for the user input
|
|
|
|
|
|
|
|
|
|
|
@ -222,7 +217,7 @@ class OpenaiChat(AsyncGeneratorProvider):
|
|
|
|
return messages
|
|
|
|
return messages
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def _get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse:
|
|
|
|
async def get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Retrieves the image response based on the message content.
|
|
|
|
Retrieves the image response based on the message content.
|
|
|
|
|
|
|
|
|
|
|
@ -257,7 +252,7 @@ class OpenaiChat(AsyncGeneratorProvider):
|
|
|
|
raise RuntimeError(f"Error in downloading image: {e}")
|
|
|
|
raise RuntimeError(f"Error in downloading image: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def _delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str):
|
|
|
|
async def delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Deletes a conversation by setting its visibility to False.
|
|
|
|
Deletes a conversation by setting its visibility to False.
|
|
|
|
|
|
|
|
|
|
|
@ -322,7 +317,6 @@ class OpenaiChat(AsyncGeneratorProvider):
|
|
|
|
Raises:
|
|
|
|
Raises:
|
|
|
|
RuntimeError: If an error occurs during processing.
|
|
|
|
RuntimeError: If an error occurs during processing.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model = MODELS.get(model, model)
|
|
|
|
|
|
|
|
if not parent_id:
|
|
|
|
if not parent_id:
|
|
|
|
parent_id = str(uuid.uuid4())
|
|
|
|
parent_id = str(uuid.uuid4())
|
|
|
|
if not cookies:
|
|
|
|
if not cookies:
|
|
|
@ -333,7 +327,7 @@ class OpenaiChat(AsyncGeneratorProvider):
|
|
|
|
login_url = os.environ.get("G4F_LOGIN_URL")
|
|
|
|
login_url = os.environ.get("G4F_LOGIN_URL")
|
|
|
|
if login_url:
|
|
|
|
if login_url:
|
|
|
|
yield f"Please login: [ChatGPT]({login_url})\n\n"
|
|
|
|
yield f"Please login: [ChatGPT]({login_url})\n\n"
|
|
|
|
access_token, cookies = cls._browse_access_token(proxy)
|
|
|
|
access_token, cookies = cls.browse_access_token(proxy)
|
|
|
|
cls._cookies = cookies
|
|
|
|
cls._cookies = cookies
|
|
|
|
|
|
|
|
|
|
|
|
headers = {"Authorization": f"Bearer {access_token}"}
|
|
|
|
headers = {"Authorization": f"Bearer {access_token}"}
|
|
|
@ -344,12 +338,10 @@ class OpenaiChat(AsyncGeneratorProvider):
|
|
|
|
timeout=timeout,
|
|
|
|
timeout=timeout,
|
|
|
|
cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"])
|
|
|
|
cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"])
|
|
|
|
) as session:
|
|
|
|
) as session:
|
|
|
|
if not model:
|
|
|
|
|
|
|
|
model = await cls._get_default_model(session, headers)
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
image_response = None
|
|
|
|
image_response = None
|
|
|
|
if image:
|
|
|
|
if image:
|
|
|
|
image_response = await cls._upload_image(session, headers, image)
|
|
|
|
image_response = await cls.upload_image(session, headers, image)
|
|
|
|
yield image_response
|
|
|
|
yield image_response
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
yield e
|
|
|
|
yield e
|
|
|
@ -357,15 +349,15 @@ class OpenaiChat(AsyncGeneratorProvider):
|
|
|
|
while not end_turn.is_end:
|
|
|
|
while not end_turn.is_end:
|
|
|
|
data = {
|
|
|
|
data = {
|
|
|
|
"action": action,
|
|
|
|
"action": action,
|
|
|
|
"arkose_token": await cls._get_arkose_token(session),
|
|
|
|
"arkose_token": await cls.get_arkose_token(session),
|
|
|
|
"conversation_id": conversation_id,
|
|
|
|
"conversation_id": conversation_id,
|
|
|
|
"parent_message_id": parent_id,
|
|
|
|
"parent_message_id": parent_id,
|
|
|
|
"model": model,
|
|
|
|
"model": cls.get_model(model or await cls.get_default_model(session, headers)),
|
|
|
|
"history_and_training_disabled": history_disabled and not auto_continue,
|
|
|
|
"history_and_training_disabled": history_disabled and not auto_continue,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if action != "continue":
|
|
|
|
if action != "continue":
|
|
|
|
prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"]
|
|
|
|
prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"]
|
|
|
|
data["messages"] = cls._create_messages(prompt, image_response)
|
|
|
|
data["messages"] = cls.create_messages(prompt, image_response)
|
|
|
|
async with session.post(
|
|
|
|
async with session.post(
|
|
|
|
f"{cls.url}/backend-api/conversation",
|
|
|
|
f"{cls.url}/backend-api/conversation",
|
|
|
|
json=data,
|
|
|
|
json=data,
|
|
|
@ -391,7 +383,7 @@ class OpenaiChat(AsyncGeneratorProvider):
|
|
|
|
if "message_type" not in line["message"]["metadata"]:
|
|
|
|
if "message_type" not in line["message"]["metadata"]:
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
image_response = await cls._get_generated_image(session, headers, line)
|
|
|
|
image_response = await cls.get_generated_image(session, headers, line)
|
|
|
|
if image_response:
|
|
|
|
if image_response:
|
|
|
|
yield image_response
|
|
|
|
yield image_response
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
@ -422,10 +414,10 @@ class OpenaiChat(AsyncGeneratorProvider):
|
|
|
|
action = "continue"
|
|
|
|
action = "continue"
|
|
|
|
await asyncio.sleep(5)
|
|
|
|
await asyncio.sleep(5)
|
|
|
|
if history_disabled and auto_continue:
|
|
|
|
if history_disabled and auto_continue:
|
|
|
|
await cls._delete_conversation(session, headers, conversation_id)
|
|
|
|
await cls.delete_conversation(session, headers, conversation_id)
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def _browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]:
|
|
|
|
def browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Browse to obtain an access token.
|
|
|
|
Browse to obtain an access token.
|
|
|
|
|
|
|
|
|
|
|
@ -452,7 +444,7 @@ class OpenaiChat(AsyncGeneratorProvider):
|
|
|
|
driver.quit()
|
|
|
|
driver.quit()
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def _get_arkose_token(cls, session: StreamSession) -> str:
|
|
|
|
async def get_arkose_token(cls, session: StreamSession) -> str:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Obtain an Arkose token for the session.
|
|
|
|
Obtain an Arkose token for the session.
|
|
|
|
|
|
|
|
|
|
|
|