Merge pull request #1414 from hlohaus/lia

Patch event loop on win, Check event loop closed
This commit is contained in:
H Lohaus 2024-01-01 02:10:53 +01:00 committed by GitHub
commit e64a003323
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 28 deletions

View File

@ -13,6 +13,13 @@ if sys.version_info < (3, 10):
else: else:
from types import NoneType from types import NoneType
# Change event loop policy on windows for curl_cffi
if sys.platform == 'win32':
if isinstance(
asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy
):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
class BaseProvider(ABC): class BaseProvider(ABC):
url: str url: str
working: bool = False working: bool = False

View File

@ -7,9 +7,9 @@ import random
import string import string
import secrets import secrets
import os import os
from os import path from os import path
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from platformdirs import user_config_dir from platformdirs import user_config_dir
from browser_cookie3 import ( from browser_cookie3 import (
chrome, chrome,
chromium, chromium,
@ -25,37 +25,33 @@ from browser_cookie3 import (
from ..typing import Dict, Messages from ..typing import Dict, Messages
from .. import debug from .. import debug
# Change event loop policy on windows
if sys.platform == 'win32':
if isinstance(
asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy
):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# Local Cookie Storage # Local Cookie Storage
_cookies: Dict[str, Dict[str, str]] = {} _cookies: Dict[str, Dict[str, str]] = {}
# If event loop is already running, handle nested event loops # If loop closed or not set, create new event loop.
# If event loop is already running, handle nested event loops.
# If "nest_asyncio" is installed, patch the event loop. # If "nest_asyncio" is installed, patch the event loop.
def get_event_loop() -> AbstractEventLoop: def get_event_loop() -> AbstractEventLoop:
try: try:
asyncio.get_running_loop() loop = asyncio.get_event_loop()
loop._check_closed()
except RuntimeError: except RuntimeError:
try: loop = asyncio.new_event_loop()
return asyncio.get_event_loop() asyncio.set_event_loop(loop)
except RuntimeError:
asyncio.set_event_loop(asyncio.new_event_loop())
return asyncio.get_event_loop()
try: try:
event_loop = asyncio.get_event_loop() # Is running event loop
if not hasattr(event_loop.__class__, "_nest_patched"): asyncio.get_running_loop()
if not hasattr(loop.__class__, "_nest_patched"):
import nest_asyncio import nest_asyncio
nest_asyncio.apply(event_loop) nest_asyncio.apply(loop)
return event_loop except RuntimeError:
# No running event loop
pass
except ImportError: except ImportError:
raise RuntimeError( raise RuntimeError(
'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.' 'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.'
) )
return loop
def init_cookies(): def init_cookies():
urls = [ urls = [

View File

@ -7,9 +7,9 @@ from async_property import async_cached_property
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
from ..base_provider import AsyncGeneratorProvider from ..base_provider import AsyncGeneratorProvider
from ..helper import get_event_loop, format_prompt from ..helper import get_event_loop, format_prompt, get_cookies
from ...webdriver import get_browser from ...webdriver import get_browser
from ...typing import AsyncResult, Messages from ...typing import AsyncResult, Messages
from ...requests import StreamSession from ...requests import StreamSession
@ -27,7 +27,7 @@ class OpenaiChat(AsyncGeneratorProvider):
needs_auth = True needs_auth = True
supports_gpt_35_turbo = True supports_gpt_35_turbo = True
supports_gpt_4 = True supports_gpt_4 = True
_access_token: str = None _cookies: dict = {}
@classmethod @classmethod
async def create( async def create(
@ -72,6 +72,7 @@ class OpenaiChat(AsyncGeneratorProvider):
proxy: str = None, proxy: str = None,
timeout: int = 120, timeout: int = 120,
access_token: str = None, access_token: str = None,
cookies: dict = None,
auto_continue: bool = False, auto_continue: bool = False,
history_disabled: bool = True, history_disabled: bool = True,
action: str = "next", action: str = "next",
@ -86,13 +87,18 @@ class OpenaiChat(AsyncGeneratorProvider):
raise ValueError(f"Model are not supported: {model}") raise ValueError(f"Model are not supported: {model}")
if not parent_id: if not parent_id:
parent_id = str(uuid.uuid4()) parent_id = str(uuid.uuid4())
if not cookies:
cookies = cls._cookies
if not access_token: if not access_token:
access_token = cls._access_token if not cookies:
cls._cookies = cookies = get_cookies("chat.openai.com")
if "access_token" in cookies:
access_token = cookies["access_token"]
if not access_token: if not access_token:
login_url = os.environ.get("G4F_LOGIN_URL") login_url = os.environ.get("G4F_LOGIN_URL")
if login_url: if login_url:
yield f"Please login: [ChatGPT]({login_url})\n\n" yield f"Please login: [ChatGPT]({login_url})\n\n"
access_token = cls._access_token = await cls.browse_access_token(proxy) cls._cookies["access_token"] = access_token = await cls.browse_access_token(proxy)
headers = { headers = {
"Accept": "text/event-stream", "Accept": "text/event-stream",
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
@ -101,7 +107,8 @@ class OpenaiChat(AsyncGeneratorProvider):
proxies={"https": proxy}, proxies={"https": proxy},
impersonate="chrome110", impersonate="chrome110",
headers=headers, headers=headers,
timeout=timeout timeout=timeout,
cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"])
) as session: ) as session:
end_turn = EndTurn() end_turn = EndTurn()
while not end_turn.is_end: while not end_turn.is_end:
@ -170,7 +177,12 @@ class OpenaiChat(AsyncGeneratorProvider):
WebDriverWait(driver, 1200).until( WebDriverWait(driver, 1200).until(
EC.presence_of_element_located((By.ID, "prompt-textarea")) EC.presence_of_element_located((By.ID, "prompt-textarea"))
) )
javascript = "return (await (await fetch('/api/auth/session')).json())['accessToken']" javascript = """
access_token = (await (await fetch('/api/auth/session')).json())['accessToken'];
expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7); // One week
document.cookie = 'access_token=' + access_token + ';expires=' + expires.toUTCString() + ';path=/';
return access_token;
"""
return driver.execute_script(javascript) return driver.execute_script(javascript)
finally: finally:
driver.quit() driver.quit()