Merge pull request #1269 from hlohaus/any

Add Response Handler to OpenaiChat
pull/1272/head
Tekky 8 months ago committed by GitHub
commit eeb26036ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,7 +5,7 @@ import random
from ..typing import CreateResult, Messages from ..typing import CreateResult, Messages
from .base_provider import BaseProvider from .base_provider import BaseProvider
from .helper import WebDriver, format_prompt, get_browser, get_random_string from .helper import WebDriver, WebDriverSession, format_prompt, get_random_string
from .. import debug from .. import debug
class AItianhuSpace(BaseProvider): class AItianhuSpace(BaseProvider):
@ -24,7 +24,7 @@ class AItianhuSpace(BaseProvider):
domain: str = None, domain: str = None,
proxy: str = None, proxy: str = None,
timeout: int = 120, timeout: int = 120,
browser: WebDriver = None, web_driver: WebDriver = None,
headless: bool = True, headless: bool = True,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
@ -38,8 +38,8 @@ class AItianhuSpace(BaseProvider):
print(f"AItianhuSpace | using domain: {domain}") print(f"AItianhuSpace | using domain: {domain}")
url = f"https://{domain}" url = f"https://{domain}"
prompt = format_prompt(messages) prompt = format_prompt(messages)
driver = browser if browser else get_browser("", headless, proxy)
with WebDriverSession(web_driver, "", headless=headless, proxy=proxy) as driver:
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
@ -67,7 +67,6 @@ document.getElementById('sheet').addEventListener('click', () => {{
# Wait for page load # Wait for page load
wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea.n-input__textarea-el"))) wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea.n-input__textarea-el")))
try:
# Register hook in XMLHttpRequest # Register hook in XMLHttpRequest
script = """ script = """
const _http_request_open = XMLHttpRequest.prototype.open; const _http_request_open = XMLHttpRequest.prototype.open;
@ -115,8 +114,3 @@ return "";
break break
else: else:
time.sleep(0.1) time.sleep(0.1)
finally:
if not browser:
driver.close()
time.sleep(0.1)
driver.quit()

@ -4,7 +4,7 @@ import time, json
from ..typing import CreateResult, Messages from ..typing import CreateResult, Messages
from .base_provider import BaseProvider from .base_provider import BaseProvider
from .helper import WebDriver, format_prompt, get_browser from .helper import WebDriver, WebDriverSession, format_prompt
class MyShell(BaseProvider): class MyShell(BaseProvider):
url = "https://app.myshell.ai/chat" url = "https://app.myshell.ai/chat"
@ -20,22 +20,27 @@ class MyShell(BaseProvider):
stream: bool, stream: bool,
proxy: str = None, proxy: str = None,
timeout: int = 120, timeout: int = 120,
browser: WebDriver = None, web_driver: WebDriver = None,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
driver = browser if browser else get_browser("", False, proxy) with WebDriverSession(web_driver, "", proxy=proxy) as driver:
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
driver.get(cls.url) driver.get(cls.url)
try:
# Wait for page load and cloudflare validation # Wait for page load and cloudflare validation
WebDriverWait(driver, timeout).until( WebDriverWait(driver, timeout).until(
EC.presence_of_element_located((By.CSS_SELECTOR, "body:not(.no-js)")) EC.presence_of_element_located((By.CSS_SELECTOR, "body:not(.no-js)"))
) )
# Send request with message # Send request with message
data = {
"botId": "4738",
"conversation_scenario": 3,
"message": format_prompt(messages),
"messageType": 1
}
script = """ script = """
response = await fetch("https://api.myshell.ai/v1/bot/chat/send_message", { response = await fetch("https://api.myshell.ai/v1/bot/chat/send_message", {
"headers": { "headers": {
@ -49,12 +54,6 @@ response = await fetch("https://api.myshell.ai/v1/bot/chat/send_message", {
}) })
window.reader = response.body.getReader(); window.reader = response.body.getReader();
""" """
data = {
"botId": "4738",
"conversation_scenario": 3,
"message": format_prompt(messages),
"messageType": 1
}
driver.execute_script(script.replace("{body}", json.dumps(data))) driver.execute_script(script.replace("{body}", json.dumps(data)))
script = """ script = """
chunk = await window.reader.read(); chunk = await window.reader.read();
@ -81,8 +80,3 @@ return content;
break break
else: else:
time.sleep(0.1) time.sleep(0.1)
finally:
if not browser:
driver.close()
time.sleep(0.1)
driver.quit()

@ -4,7 +4,7 @@ import time
from ..typing import CreateResult, Messages from ..typing import CreateResult, Messages
from .base_provider import BaseProvider from .base_provider import BaseProvider
from .helper import WebDriver, format_prompt, get_browser from .helper import WebDriver, WebDriverSession, format_prompt
class PerplexityAi(BaseProvider): class PerplexityAi(BaseProvider):
url = "https://www.perplexity.ai" url = "https://www.perplexity.ai"
@ -20,12 +20,12 @@ class PerplexityAi(BaseProvider):
stream: bool, stream: bool,
proxy: str = None, proxy: str = None,
timeout: int = 120, timeout: int = 120,
browser: WebDriver = None, web_driver: WebDriver = None,
virtual_display: bool = True,
copilot: bool = False, copilot: bool = False,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
driver = browser if browser else get_browser("", False, proxy) with WebDriverSession(web_driver, "", virtual_display=virtual_display, proxy=proxy) as driver:
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
@ -82,7 +82,6 @@ WebSocket.prototype.send = function(...args) {
driver.find_element(By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']").send_keys(prompt) driver.find_element(By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']").send_keys(prompt)
driver.find_element(By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']").send_keys(Keys.ENTER) driver.find_element(By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']").send_keys(Keys.ENTER)
try:
# Stream response # Stream response
script = """ script = """
if(window._message && window._message != window._last_message) { if(window._message && window._message != window._last_message) {
@ -105,8 +104,3 @@ if(window._message && window._message != window._last_message) {
break break
else: else:
time.sleep(0.1) time.sleep(0.1)
finally:
if not browser:
driver.close()
time.sleep(0.1)
driver.quit()

@ -5,7 +5,7 @@ from urllib.parse import quote
from ..typing import CreateResult, Messages from ..typing import CreateResult, Messages
from .base_provider import BaseProvider from .base_provider import BaseProvider
from .helper import WebDriver, format_prompt, get_browser from .helper import WebDriver, WebDriverSession, format_prompt
class Phind(BaseProvider): class Phind(BaseProvider):
url = "https://www.phind.com" url = "https://www.phind.com"
@ -21,13 +21,11 @@ class Phind(BaseProvider):
stream: bool, stream: bool,
proxy: str = None, proxy: str = None,
timeout: int = 120, timeout: int = 120,
browser: WebDriver = None, web_driver: WebDriver = None,
creative_mode: bool = None, creative_mode: bool = None,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
try: with WebDriverSession(web_driver, "", proxy=proxy) as driver:
driver = browser if browser else get_browser("", False, proxy)
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
@ -103,8 +101,3 @@ if(window._reader) {
break break
else: else:
time.sleep(0.1) time.sleep(0.1)
finally:
if not browser:
driver.close()
time.sleep(0.1)
driver.quit()

@ -4,7 +4,7 @@ import time, json, time
from ..typing import CreateResult, Messages from ..typing import CreateResult, Messages
from .base_provider import BaseProvider from .base_provider import BaseProvider
from .helper import WebDriver, get_browser from .helper import WebDriver, WebDriverSession
class TalkAi(BaseProvider): class TalkAi(BaseProvider):
url = "https://talkai.info" url = "https://talkai.info"
@ -19,16 +19,14 @@ class TalkAi(BaseProvider):
messages: Messages, messages: Messages,
stream: bool, stream: bool,
proxy: str = None, proxy: str = None,
browser: WebDriver = None, web_driver: WebDriver = None,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
driver = browser if browser else get_browser("", False, proxy) with WebDriverSession(web_driver, "", virtual_display=True, proxy=proxy) as driver:
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
try:
driver.get(f"{cls.url}/chat/") driver.get(f"{cls.url}/chat/")
# Wait for page load # Wait for page load
@ -87,8 +85,3 @@ return content;
break break
else: else:
time.sleep(0.1) time.sleep(0.1)
finally:
if not browser:
driver.close()
time.sleep(0.1)
driver.quit()

@ -6,6 +6,7 @@ import webbrowser
import random import random
import string import string
import secrets import secrets
import time
from os import path from os import path
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from platformdirs import user_config_dir from platformdirs import user_config_dir
@ -34,6 +35,10 @@ except ImportError:
class ChromeOptions(): class ChromeOptions():
def add_argument(): def add_argument():
pass pass
try:
from pyvirtualdisplay import Display
except ImportError:
pass
from ..typing import Dict, Messages, Union, Tuple from ..typing import Dict, Messages, Union, Tuple
from .. import debug from .. import debug
@ -144,6 +149,53 @@ def get_browser(
options.add_argument(f'--proxy-server={proxy}') options.add_argument(f'--proxy-server={proxy}')
return Chrome(options=options, user_data_dir=user_data_dir, headless=headless) return Chrome(options=options, user_data_dir=user_data_dir, headless=headless)
class WebDriverSession():
    """Context manager around a Selenium WebDriver.

    If an external ``web_driver`` is supplied it is used as-is and NOT torn
    down on exit; otherwise a browser is created on ``__enter__`` (optionally
    inside a pyvirtualdisplay ``Display``) and quit on ``__exit__``.
    """

    def __init__(
        self,
        web_driver: WebDriver = None,
        user_data_dir: str = None,
        headless: bool = False,
        virtual_display: bool = False,
        proxy: str = None,
        options: ChromeOptions = None
    ):
        self.web_driver = web_driver
        self.user_data_dir = user_data_dir
        self.headless = headless
        # Becomes a started Display instance in __enter__ when requested.
        self.virtual_display = virtual_display
        self.proxy = proxy
        self.options = options
        # BUGFIX: previously unset when an external web_driver was passed,
        # so __exit__/reopen raised AttributeError.
        self.default_driver = None

    def reopen(
        self,
        user_data_dir: str = None,
        headless: bool = False,
        virtual_display: bool = False
    ) -> WebDriver:
        """Quit the owned browser and start a fresh one (e.g. for a login)."""
        if user_data_dir is None:
            user_data_dir = self.user_data_dir
        if self.default_driver:
            self.default_driver.quit()
        if not virtual_display and self.virtual_display:
            self.virtual_display.stop()
            # BUGFIX: clear the stopped display so __exit__ does not stop
            # it a second time.
            self.virtual_display = False
        self.default_driver = get_browser(user_data_dir, headless, self.proxy)
        return self.default_driver

    def __enter__(self) -> WebDriver:
        if self.web_driver:
            return self.web_driver
        if self.virtual_display is True:
            self.virtual_display = Display(size=(1920, 1080))
            self.virtual_display.start()
        self.default_driver = get_browser(self.user_data_dir, self.headless, self.proxy, self.options)
        return self.default_driver

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.default_driver:
            try:
                # close() can fail if the window is already gone; quit()
                # below still releases the browser process.
                self.default_driver.close()
            except Exception:
                pass
            time.sleep(0.1)
            self.default_driver.quit()
        if self.virtual_display:
            self.virtual_display.stop()
def get_random_string(length: int = 10) -> str: def get_random_string(length: int = 10) -> str:
return ''.join( return ''.join(

@ -4,7 +4,7 @@ import time
from ...typing import CreateResult, Messages from ...typing import CreateResult, Messages
from ..base_provider import BaseProvider from ..base_provider import BaseProvider
from ..helper import WebDriver, format_prompt, get_browser from ..helper import WebDriver, WebDriverSession, format_prompt
class Bard(BaseProvider): class Bard(BaseProvider):
url = "https://bard.google.com" url = "https://bard.google.com"
@ -18,14 +18,14 @@ class Bard(BaseProvider):
messages: Messages, messages: Messages,
stream: bool, stream: bool,
proxy: str = None, proxy: str = None,
browser: WebDriver = None, web_driver: WebDriver = None,
user_data_dir: str = None, user_data_dir: str = None,
headless: bool = True, headless: bool = True,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
prompt = format_prompt(messages) prompt = format_prompt(messages)
driver = browser if browser else get_browser(user_data_dir, headless, proxy) session = WebDriverSession(web_driver, user_data_dir, headless, proxy=proxy)
with session as driver:
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
@ -36,16 +36,14 @@ class Bard(BaseProvider):
wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea"))) wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea")))
except: except:
# Reopen browser for login # Reopen browser for login
if not browser: if not web_driver:
driver.quit() driver = session.reopen(headless=False)
driver = get_browser(None, False, proxy)
driver.get(f"{cls.url}/chat") driver.get(f"{cls.url}/chat")
wait = WebDriverWait(driver, 240) wait = WebDriverWait(driver, 240)
wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea"))) wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea")))
else: else:
raise RuntimeError("Prompt textarea not found. You may not be logged in.") raise RuntimeError("Prompt textarea not found. You may not be logged in.")
try:
# Add hook in XMLHttpRequest # Add hook in XMLHttpRequest
script = """ script = """
const _http_request_open = XMLHttpRequest.prototype.open; const _http_request_open = XMLHttpRequest.prototype.open;
@ -73,8 +71,3 @@ XMLHttpRequest.prototype.open = function(method, url) {
return return
else: else:
time.sleep(0.1) time.sleep(0.1)
finally:
if not browser:
driver.close()
time.sleep(0.1)
driver.quit()

@ -1,20 +1,64 @@
from __future__ import annotations from __future__ import annotations
import uuid, json, time, asyncio import uuid, json, asyncio
from py_arkose_generator.arkose import get_values_for_request from py_arkose_generator.arkose import get_values_for_request
from asyncstdlib.itertools import tee
from async_property import async_cached_property
from ..base_provider import AsyncGeneratorProvider from ..base_provider import AsyncGeneratorProvider
from ..helper import get_browser, get_cookies, format_prompt, get_event_loop from ..helper import get_browser, get_event_loop
from ...typing import AsyncResult, Messages from ...typing import AsyncResult, Messages
from ...requests import StreamSession from ...requests import StreamSession
from ... import debug
# Maps user-facing model aliases to the internal model slugs accepted by
# the chat.openai.com backend-api/conversation endpoint.
# NOTE(review): slugs presumably change over time server-side — verify
# against the live endpoint when updating.
models = {
"gpt-3.5": "text-davinci-002-render-sha",
"gpt-3.5-turbo": "text-davinci-002-render-sha",
"gpt-4": "gpt-4",
"gpt-4-gizmo": "gpt-4-gizmo"
}
class OpenaiChat(AsyncGeneratorProvider): class OpenaiChat(AsyncGeneratorProvider):
url = "https://chat.openai.com" url = "https://chat.openai.com"
needs_auth = True
working = True working = True
needs_auth = True
supports_gpt_35_turbo = True supports_gpt_35_turbo = True
_access_token = None supports_gpt_4 = True
_access_token: str = None
@classmethod
async def create(
    cls,
    prompt: str = None,
    model: str = "",
    messages: Messages = None,
    history_disabled: bool = False,
    action: str = "next",
    conversation_id: str = None,
    parent_id: str = None,
    **kwargs
) -> Response:
    """Start a chat request and wrap the stream in a Response handler.

    Args:
        prompt: Optional user message appended to ``messages``.
        model: Model alias (key of ``models``); resolved downstream.
        messages: Conversation history; a fresh list is used when omitted.
        history_disabled: Forwarded to the generator.
        action: "next", "continue" or "variant".
        conversation_id: Existing conversation to attach to, if any.
        parent_id: Parent message id, if any.
        **kwargs: Extra options (e.g. access_token, proxy) stored on the
            Response for follow-up requests.

    Returns:
        A Response wrapping the content generator and its ResponseFields.
    """
    # BUGFIX: the previous default `messages=[]` was a shared mutable
    # default — prompts accumulated across calls. Use a None sentinel.
    messages = [] if messages is None else messages
    if prompt:
        messages.append({"role": "user", "content": prompt})
    generator = cls.create_async_generator(
        model,
        messages,
        history_disabled=history_disabled,
        action=action,
        conversation_id=conversation_id,
        parent_id=parent_id,
        response_fields=True,
        **kwargs
    )
    # With response_fields=True the first yielded item carries the
    # conversation/message ids rather than content.
    fields: ResponseFields = await anext(generator)
    if "access_token" not in kwargs:
        kwargs["access_token"] = cls._access_token
    return Response(
        generator,
        fields,
        action,
        messages,
        kwargs
    )
@classmethod @classmethod
async def create_async_generator( async def create_async_generator(
@ -25,50 +69,56 @@ class OpenaiChat(AsyncGeneratorProvider):
timeout: int = 120, timeout: int = 120,
access_token: str = None, access_token: str = None,
auto_continue: bool = False, auto_continue: bool = False,
cookies: dict = None, history_disabled: bool = True,
action: str = "next",
conversation_id: str = None,
parent_id: str = None,
response_fields: bool = False,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
proxies = {"https": proxy} if not model:
model = "gpt-3.5"
elif model not in models:
raise ValueError(f"Model are not supported: {model}")
if not parent_id:
parent_id = str(uuid.uuid4())
if not access_token: if not access_token:
access_token = await cls.get_access_token(cookies, proxies) access_token = await cls.get_access_token(proxy)
headers = { headers = {
"Accept": "text/event-stream", "Accept": "text/event-stream",
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
"Cookie": 'intercom-device-id-dgkjq2bp=0f047573-a750-46c8-be62-6d54b56e7bf0; ajs_user_id=user-iv3vxisaoNodwWpxmNpMfekH; ajs_anonymous_id=fd91be0b-0251-4222-ac1e-84b1071e9ec1; __Host-next-auth.csrf-token=d2b5f67d56f7dd6a0a42ae4becf2d1a6577b820a5edc88ab2018a59b9b506886%7Ce5c33eecc460988a137cbc72d90ee18f1b4e2f672104f368046df58e364376ac; _cfuvid=gt_mA.q6rue1.7d2.AR0KHpbVBS98i_ppfi.amj2._o-1700353424353-0-604800000; cf_clearance=GkHCfPSFU.NXGcHROoe4FantnqmnNcluhTNHz13Tk.M-1700353425-0-1-dfe77f81.816e9bc2.714615da-0.2.1700353425; __Secure-next-auth.callback-url=https%3A%2F%2Fchat.openai.com; intercom-session-dgkjq2bp=UWdrS1hHazk5VXN1c0V5Q1F0VXdCQmsyTU9pVjJMUkNpWnFnU3dKWmtIdGwxTC9wbjZuMk5hcEc0NWZDOGdndS0tSDNiaDNmMEdIL1RHU1dFWDBwOHFJUT09--f754361b91fddcd23a13b288dcb2bf8c7f509e91; _uasid="Z0FBQUFBQmxXVnV0a3dmVno4czRhcDc2ZVcwaUpSNUdZejlDR25YSk5NYTJQQkpyNmRvOGxjTHMyTlAxWmJhaURrMVhjLXZxQXdZeVpBbU1aczA5WUpHT2dwaS1MOWc4MnhyNWFnbGRzeGdJcGFKT0ZRdnBTMVJHcGV2MGNTSnVQY193c0hqUWIycHhQRVF4dENlZ3phcDdZeHgxdVhoalhrZmtZME9NbWhMQjdVR3Vzc3FRRk0ybjJjNWMwTWtIRjdPb19lUkFtRmV2MDVqd1kwWU11QTYtQkdZenEzVHhLMGplY1hZM3FlYUt1cVZaNWFTRldleEJETzJKQjk1VTJScy1GUnMxUVZWMnVxYklxMjdockVZbkZyd1R4U1RtMnA1ZzlSeXphdmVOVk9xeEdrRkVOSjhwTVd1QzFtQjhBcWdDaE92Q1VlM2pwcjFQTXRuLVJNRVlZSGpIdlZ0aGV3PT0="; _dd_s=rum=0&expire=1700356244884; 
__Secure-next-auth.session-token=eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..3aK6Fbdy2_8f07bf.8eT2xgonrCnz7ySY6qXFsg3kzL6UQfXKAYaw3tyn-6_X9657zy47k9qGvmi9mF0QKozj5jau3_Ca62AQQ7FmeC6Y2F1urtzqrXqwTTsQ2LuzFPIQkx6KKb2DXc8zW2-oyEzJ_EY5yxfLB2RlRkSh3M7bYNZh4_ltEcfkj38s_kIPGMxv34udtPWGWET99MCjkdwQWXylJag4s0fETA0orsBAKnGCyqAUNJbb_D7BYtGSV-MQ925kZMG6Di_QmfO0HQWURDYjmdRNcuy1PT_xJ1DJko8sjL42i4j3RhkNDkhqCIqyYImz2eHFWHW7rYKxTkrBhlCPMS5hRdcCswD7JYPcSBiwnVRYgyOocFGXoFvQgIZ2FX9NiZ3SMEVM1VwIGSE-qH0H2nMa8_iBvsOgOWJgKjVAvzzyzZvRVDUUHzJrikSFPNONVDU3h-04c1kVL4qIu9DfeTPN7n8AvNmYwMbro0L9-IUAeXNo4-pwF0Kt-AtTsamqWvMqnK4O_YOyLnDDlvkmnOvDC2d5uinwlQIxr6APO6qFfGLlHiLZemKoekxEE1Fx70dl-Ouhk1VIzbF3OC6XNNxeBm9BUYUiHdL0wj2H9rHgX4cz6ZmS_3VTgpD6UJh-evu5KJ2gIvjYmVbyzEN0aPNDxfvBaOm-Ezpy4bUJ2bUrOwNn-0knWkDiTvjYmNhCyefPCtCF6rpKNay8PCw_yh79C4SdEP6Q4V7LI0Tvdi5uz7kLCiBC4AT9L0ao1WDX03mkUOpjvzHDvPLmj8chW3lTVm_kA0eYGQY4wT0jzleWlfV0Q8rB2oYECNLWksA3F1zlGfcl4lQjprvTXRePkvAbMpoJEsZD3Ylq7-foLDLk4-M2LYAFZDs282AY04sFjAjQBxTELFCCuDgTIgTXSIskY_XCxpVXDbdLlbCJY7XVK45ybwtfqwlKRp8Mo0B131uQAFc-migHaUaoGujxJJk21bP8F0OmhNYHBo4FQqE1rQm2JH5bNM7txKeh5KXdJgVUVbRSr7OIp_OF5-Bx_v9eRBGAIDkue26E2-O8Rnrp5zQ5TnvecQLDaUzWavCLPwsZ0_gsOLBxNOmauNYZtF8IElCsQSFDdhoiMxXsYUm4ZYKEAy3GWq8HGTAvBhNkh1hvnI7y-d8-DOaZf_D_D98-olZfm-LUkeosLNpPB9rxYMqViCiW3KrXE9Yx0wlFm5ePKaVvR7Ym_EPhSOhJBKFPCvdTdMZSNPUcW0ZJBVByq0A9sxD51lYq3gaFyqh94S4s_ox182AQ3szGzHkdgLcnQmJG9OYvKxAVcd43eg6_gODAYhx02GjbMw-7JTAhyXSeCrlMteHyOXl8hai-3LilC3PmMzi7Vbu49dhF1s4LcVlUowen5ira44rQQaB26mdaOUoQfodgt66M3RTWGPXyK1Nb72AzSXsCKyaQPbzeb6cN0fdGSdG4ktwvR04eFNEkquo_3aKu2GmUKTD0XcRx9dYrfXjgY-X1DDTVs1YND2gRhdx7FFEeBVjtbj2UqmG3Rvd4IcHGe7OnYWw2MHDcol68SsR1KckXWwWREz7YTGUnDB2M1kx_H4W2mjclytnlHOnYU3RflegRPeSTbdzUZJvGKXCCz45luHkQWN_4DExE76D-9YqbFIz-RY5yL4h-Zs-i2xjm2K-4xCMM9nQIOqhLMqixIZQ2ldDAidKoYtbs5ppzbcBLyrZM96bq9DwRBY3aacqWdlRd-TfX0wv5KO4fo0sSh5FsuhuN0zcEV_NNXgqIEM_p14EcPqgbrAvCBQ8os70TRBQLXiF0EniSofGjxwF8kQvUk3C6Wfc8cTTeN-E6GxCVTn91HBwA1iSEZlRLMVb8_BcRJNqwbgnb_07jR6-eo42u88CR3KQdAWwbQRdMxsURFwZ0ujHXV
GG0Ll6qCFBcHXWyDO1x1yHdHnw8_8yF26pnA2iPzrFR-8glMgIA-639sLuGAxjO1_ZuvJ9CAB41Az9S_jaZwaWy215Hk4-BRYD-MKmHtonwo3rrxhE67WJgbbu14efsw5nT6ow961pffgwXov5VA1Rg7nv1E8RvQOx7umWW6o8R4W6L8f2COsmPTXfgwIjoJKkjhUqAQ8ceG7cM0ET-38yaC0ObU8EkXfdGGgxI28qTEZWczG66_iM4hw7QEGCY5Cz2kbO6LETAiw9OsSigtBvDS7f0Ou0bZ41pdK7G3FmvdZAnjWPjObnDF4k4uWfn7mzt0fgj3FyqK20JezRDyGuAbUUhOvtZpc9sJpzxR34eXEZTouuALrHcGuNij4z6rx51FrQsaMtiup8QVrhtZbXtKLMYnWYSbkhuTeN2wY-xV1ZUsQlakIZszzGF7kuIG87KKWMpuPMvbXjz6Pp_gWJiIC6aQuk8xl5g0iBPycf_6Q-MtpuYxzNE2TpI1RyR9mHeXmteoRzrFiWp7yEC-QGNFyAJgxTqxM3CjHh1Jt6IddOsmn89rUo1dZM2Smijv_fbIv3avXLkIPX1KZjILeJCtpU0wAdsihDaRiRgDdx8fG__F8zuP0n7ziHas73cwrfg-Ujr6DhC0gTNxyd9dDA_oho9N7CQcy6EFmfNF2te7zpLony0859jtRv2t1TnpzAa1VvMK4u6mXuJ2XDo04_6GzLO3aPHinMdl1BcIAWnqAqWAu3euGFLTHOhXlfijut9N1OCifd_zWjhVtzlR39uFeCQBU5DyQArzQurdoMx8U1ETsnWgElxGSStRW-YQoPsAJ87eg9trqKspFpTVlAVN3t1GtoEAEhcwhe81SDssLmKGLc.7PqS6jRGTIfgTPlO7Ognvg; __cf_bm=VMWoAKEB45hQSwxXtnYXcurPaGZDJS4dMi6dIMFLwdw-1700355394-0-ATVsbq97iCaTaJbtYr8vtg1Zlbs3nLrJLKVBHYa2Jn7hhkGclqAy8Gbyn5ePEhDRqj93MsQmtayfYLqY5n4WiLY=; __cflb=0H28vVfF4aAyg2hkHFH9CkdHRXPsfCUf6VpYf2kz3RX'
} }
messages = [
{
"id": str(uuid.uuid4()),
"author": {"role": "user"},
"content": {"content_type": "text", "parts": [format_prompt(messages)]},
},
]
message_id = str(uuid.uuid4())
data = {
"action": "next",
"arkose_token": await get_arkose_token(proxy),
"messages": messages,
"conversation_id": None,
"parent_message_id": message_id,
"model": "text-davinci-002-render-sha",
"history_and_training_disabled": not auto_continue,
}
conversation_id = None
end_turn = False
while not end_turn:
if not auto_continue:
end_turn = True
async with StreamSession( async with StreamSession(
proxies=proxies, proxies={"https": proxy},
impersonate="chrome110",
headers=headers, headers=headers,
impersonate="chrome107",
timeout=timeout timeout=timeout
) as session: ) as session:
data = {
"action": action,
"arkose_token": await get_arkose_token(proxy, timeout),
"conversation_id": conversation_id,
"parent_message_id": parent_id,
"model": models[model],
"history_and_training_disabled": history_disabled and not auto_continue,
}
if action != "continue":
data["messages"] = [{
"id": str(uuid.uuid4()),
"author": {"role": "user"},
"content": {"content_type": "text", "parts": [messages[-1]["content"]]},
}]
first = True
end_turn = EndTurn()
while first or auto_continue and not end_turn.is_end:
first = False
async with session.post(f"{cls.url}/backend-api/conversation", json=data) as response: async with session.post(f"{cls.url}/backend-api/conversation", json=data) as response:
try: try:
response.raise_for_status() response.raise_for_status()
except: except:
raise RuntimeError(f"Response: {await response.text()}") raise RuntimeError(f"Error {response.status_code}: {await response.text()}")
last_message = "" last_message = 0
async for line in response.iter_lines(): async for line in response.iter_lines():
if line.startswith(b"data: "): if line.startswith(b"data: "):
line = line[6:] line = line[6:]
@ -82,50 +132,52 @@ class OpenaiChat(AsyncGeneratorProvider):
continue continue
if "error" in line and line["error"]: if "error" in line and line["error"]:
raise RuntimeError(line["error"]) raise RuntimeError(line["error"])
end_turn = line["message"]["end_turn"]
message_id = line["message"]["id"]
if line["conversation_id"]:
conversation_id = line["conversation_id"]
if "message_type" not in line["message"]["metadata"]: if "message_type" not in line["message"]["metadata"]:
continue continue
if line["message"]["metadata"]["message_type"] in ("next", "continue"): if line["message"]["author"]["role"] != "assistant":
continue
if line["message"]["metadata"]["message_type"] in ("next", "continue", "variant"):
conversation_id = line["conversation_id"]
parent_id = line["message"]["id"]
if response_fields:
response_fields = False
yield ResponseFields(conversation_id, parent_id, end_turn)
new_message = line["message"]["content"]["parts"][0] new_message = line["message"]["content"]["parts"][0]
yield new_message[len(last_message):] yield new_message[last_message:]
last_message = new_message last_message = len(new_message)
if end_turn: if "finish_details" in line["message"]["metadata"]:
return if line["message"]["metadata"]["finish_details"]["type"] == "max_tokens":
end_turn.end()
data = { data = {
"action": "continue", "action": "continue",
"arkose_token": await get_arkose_token(proxy), "arkose_token": await get_arkose_token(proxy, timeout),
"conversation_id": conversation_id, "conversation_id": conversation_id,
"parent_message_id": message_id, "parent_message_id": parent_id,
"model": "text-davinci-002-render-sha", "model": models[model],
"history_and_training_disabled": False, "history_and_training_disabled": False,
} }
await asyncio.sleep(5) await asyncio.sleep(5)
@classmethod @classmethod
async def browse_access_token(cls) -> str: async def browse_access_token(cls, proxy: str = None) -> str:
def browse() -> str: def browse() -> str:
try: try:
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
driver = get_browser() driver = get_browser("~/openai", proxy=proxy)
except ImportError: except ImportError:
return return
driver.get(f"{cls.url}/")
try: try:
driver.get(f"{cls.url}/")
WebDriverWait(driver, 1200).until( WebDriverWait(driver, 1200).until(
EC.presence_of_element_located((By.ID, "prompt-textarea")) EC.presence_of_element_located((By.ID, "prompt-textarea"))
) )
javascript = "return (await (await fetch('/api/auth/session')).json())['accessToken']" javascript = "return (await (await fetch('/api/auth/session')).json())['accessToken']"
return driver.execute_script(javascript) return driver.execute_script(javascript)
finally: finally:
driver.close()
time.sleep(0.1)
driver.quit() driver.quit()
loop = get_event_loop() loop = get_event_loop()
return await loop.run_in_executor( return await loop.run_in_executor(
@ -134,22 +186,9 @@ class OpenaiChat(AsyncGeneratorProvider):
) )
@classmethod @classmethod
async def fetch_access_token(cls, cookies: dict, proxies: dict = None) -> str: async def get_access_token(cls, proxy: str = None) -> str:
async with StreamSession(proxies=proxies, cookies=cookies, impersonate="chrome107") as session:
async with session.get(f"{cls.url}/api/auth/session") as response:
response.raise_for_status()
auth = await response.json()
if "accessToken" in auth:
return auth["accessToken"]
@classmethod
async def get_access_token(cls, cookies: dict = None, proxies: dict = None) -> str:
if not cls._access_token: if not cls._access_token:
cookies = cookies if cookies else get_cookies("chat.openai.com") cls._access_token = await cls.browse_access_token(proxy)
if cookies:
cls._access_token = await cls.fetch_access_token(cookies, proxies)
if not cls._access_token:
cls._access_token = await cls.browse_access_token()
if not cls._access_token: if not cls._access_token:
raise RuntimeError("Read access token failed") raise RuntimeError("Read access token failed")
return cls._access_token return cls._access_token
@ -163,12 +202,11 @@ class OpenaiChat(AsyncGeneratorProvider):
("stream", "bool"), ("stream", "bool"),
("proxy", "str"), ("proxy", "str"),
("access_token", "str"), ("access_token", "str"),
("cookies", "dict[str, str]")
] ]
param = ", ".join([": ".join(p) for p in params]) param = ", ".join([": ".join(p) for p in params])
return f"g4f.provider.{cls.__name__} supports: ({param})" return f"g4f.provider.{cls.__name__} supports: ({param})"
async def get_arkose_token(proxy: str = None) -> str: async def get_arkose_token(proxy: str = None, timeout: int = None) -> str:
config = { config = {
"pkey": "3D86FBBA-9D22-402A-B512-3420086BA6CC", "pkey": "3D86FBBA-9D22-402A-B512-3420086BA6CC",
"surl": "https://tcr9i.chat.openai.com", "surl": "https://tcr9i.chat.openai.com",
@ -181,6 +219,7 @@ async def get_arkose_token(proxy: str = None) -> str:
async with StreamSession( async with StreamSession(
proxies={"https": proxy}, proxies={"https": proxy},
impersonate="chrome107", impersonate="chrome107",
timeout=timeout
) as session: ) as session:
async with session.post(**args_for_request) as response: async with session.post(**args_for_request) as response:
response.raise_for_status() response.raise_for_status()
@ -188,3 +227,90 @@ async def get_arkose_token(proxy: str = None) -> str:
if "token" in decoded_json: if "token" in decoded_json:
return decoded_json["token"] return decoded_json["token"]
raise RuntimeError(f"Response: {decoded_json}") raise RuntimeError(f"Response: {decoded_json}")
class EndTurn():
    """Mutable one-way flag shared between the streaming generator and its
    consumer, recording whether the assistant finished its turn."""

    def __init__(self):
        # Open until end() is called; read via Response.end_turn.
        self.is_end: bool = False

    def end(self):
        """Mark the turn as finished."""
        self.is_end = True
class ResponseFields():
    """Identifiers of a conversation turn, yielded ahead of the content
    stream when ``response_fields=True`` is requested."""

    def __init__(
        self,
        conversation_id: str,
        message_id: str,
        end_turn: EndTurn
    ):
        # Conversation this turn belongs to.
        self.conversation_id = conversation_id
        # Message id; used as parent_id for follow-up requests.
        self.message_id = message_id
        # Shared completion flag; surfaced through Response.end_turn.
        self._end_turn = end_turn
class Response():
    """Handle for one assistant reply: async-iterable stream of content
    chunks plus helpers to issue follow-up requests in the same conversation.
    """

    def __init__(
        self,
        generator: AsyncResult,
        fields: ResponseFields,
        action: str,
        messages: Messages,
        options: dict
    ):
        # tee() duplicates the stream: `aiter` is handed to the caller via
        # __aiter__, `copy` is drained lazily by the `message` property.
        self.aiter, self.copy = tee(generator)
        self.fields = fields
        self.action = action
        self._messages = messages
        self._options = options

    def __aiter__(self):
        # Stream content chunks as they arrive.
        return self.aiter

    @async_cached_property
    async def message(self) -> str:
        """Full reply text; drains the tee'd copy (cached after first await)."""
        return "".join([chunk async for chunk in self.copy])

    async def next(self, prompt: str, **kwargs) -> Response:
        """Send a follow-up user prompt in this conversation."""
        return await OpenaiChat.create(
            **self._options,
            prompt=prompt,
            messages=await self.messages,
            action="next",
            conversation_id=self.fields.conversation_id,
            parent_id=self.fields.message_id,
            **kwargs
        )

    async def do_continue(self, **kwargs) -> Response:
        """Ask the model to continue a reply that was cut off.

        Raises:
            RuntimeError: if the turn already ended (nothing to continue).
        """
        if self.end_turn:
            raise RuntimeError("Can't continue message. Message already finished.")
        return await OpenaiChat.create(
            **self._options,
            messages=await self.messages,
            action="continue",
            conversation_id=self.fields.conversation_id,
            parent_id=self.fields.message_id,
            **kwargs
        )

    async def variant(self, **kwargs) -> Response:
        """Request an alternative answer to the same prompt.

        Raises:
            RuntimeError: if this response itself came from a continue or
                variant request (only "next" responses can be varied).
        """
        if self.action != "next":
            raise RuntimeError("Can't create variant with continue or variant request.")
        return await OpenaiChat.create(
            **self._options,
            # Uses the original history WITHOUT the assistant reply, since a
            # variant re-answers the same prompt.
            messages=self._messages,
            action="variant",
            conversation_id=self.fields.conversation_id,
            parent_id=self.fields.message_id,
            **kwargs
        )

    @async_cached_property
    async def messages(self):
        """History including this reply as an assistant message.

        NOTE(review): appends to the shared ``self._messages`` list in place
        (cached, so only once) — callers holding the original list see the
        mutation.
        """
        messages = self._messages
        messages.append({
            "role": "assistant", "content": await self.message
        })
        return messages

    @property
    def end_turn(self):
        # True once the stream observed the turn finishing.
        return self.fields._end_turn.is_end

@ -4,7 +4,7 @@ import time
from ...typing import CreateResult, Messages from ...typing import CreateResult, Messages
from ..base_provider import BaseProvider from ..base_provider import BaseProvider
from ..helper import WebDriver, format_prompt, get_browser from ..helper import WebDriver, WebDriverSession, format_prompt
models = { models = {
"meta-llama/Llama-2-7b-chat-hf": {"name": "Llama-2-7b"}, "meta-llama/Llama-2-7b-chat-hf": {"name": "Llama-2-7b"},
@ -33,7 +33,7 @@ class Poe(BaseProvider):
messages: Messages, messages: Messages,
stream: bool, stream: bool,
proxy: str = None, proxy: str = None,
browser: WebDriver = None, web_driver: WebDriver = None,
user_data_dir: str = None, user_data_dir: str = None,
headless: bool = True, headless: bool = True,
**kwargs **kwargs
@ -43,9 +43,15 @@ class Poe(BaseProvider):
elif model not in models: elif model not in models:
raise ValueError(f"Model are not supported: {model}") raise ValueError(f"Model are not supported: {model}")
prompt = format_prompt(messages) prompt = format_prompt(messages)
driver = browser if browser else get_browser(user_data_dir, headless, proxy)
script = """ session = WebDriverSession(web_driver, user_data_dir, headless, proxy=proxy)
with session as driver:
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
"source": """
window._message = window._last_message = ""; window._message = window._last_message = "";
window._message_finished = false; window._message_finished = false;
class ProxiedWebSocket extends WebSocket { class ProxiedWebSocket extends WebSocket {
@ -66,23 +72,16 @@ class ProxiedWebSocket extends WebSocket {
} }
window.WebSocket = ProxiedWebSocket; window.WebSocket = ProxiedWebSocket;
""" """
driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
"source": script
}) })
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
try: try:
driver.get(f"{cls.url}/{models[model]['name']}") driver.get(f"{cls.url}/{models[model]['name']}")
wait = WebDriverWait(driver, 10 if headless else 240) wait = WebDriverWait(driver, 10 if headless else 240)
wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']"))) wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']")))
except: except:
# Reopen browser for login # Reopen browser for login
if not browser: if not web_driver:
driver.quit() driver = session.reopen(headless=False)
driver = get_browser(None, False, proxy)
driver.get(f"{cls.url}/{models[model]['name']}") driver.get(f"{cls.url}/{models[model]['name']}")
wait = WebDriverWait(driver, 240) wait = WebDriverWait(driver, 240)
wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']"))) wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']")))
@ -92,7 +91,6 @@ window.WebSocket = ProxiedWebSocket;
driver.find_element(By.CSS_SELECTOR, "footer textarea[class^='GrowingTextArea']").send_keys(prompt) driver.find_element(By.CSS_SELECTOR, "footer textarea[class^='GrowingTextArea']").send_keys(prompt)
driver.find_element(By.CSS_SELECTOR, "footer button[class*='ChatMessageSendButton']").click() driver.find_element(By.CSS_SELECTOR, "footer button[class*='ChatMessageSendButton']").click()
try:
script = """ script = """
if(window._message && window._message != window._last_message) { if(window._message && window._message != window._last_message) {
try { try {
@ -114,8 +112,3 @@ if(window._message && window._message != window._last_message) {
break break
else: else:
time.sleep(0.1) time.sleep(0.1)
finally:
if not browser:
driver.close()
time.sleep(0.1)
driver.quit()

@ -4,7 +4,7 @@ import time
from ...typing import CreateResult, Messages from ...typing import CreateResult, Messages
from ..base_provider import BaseProvider from ..base_provider import BaseProvider
from ..helper import WebDriver, format_prompt, get_browser from ..helper import WebDriver, WebDriverSession, format_prompt
models = { models = {
"theb-ai": "TheB.AI", "theb-ai": "TheB.AI",
@ -44,26 +44,60 @@ class Theb(BaseProvider):
messages: Messages, messages: Messages,
stream: bool, stream: bool,
proxy: str = None, proxy: str = None,
browser: WebDriver = None, web_driver: WebDriver = None,
headless: bool = True, virtual_display: bool = True,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
if model in models: if model in models:
model = models[model] model = models[model]
prompt = format_prompt(messages) prompt = format_prompt(messages)
driver = browser if browser else get_browser(None, headless, proxy) web_session = WebDriverSession(web_driver, virtual_display=virtual_display, proxy=proxy)
with web_session as driver:
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.common.keys import Keys from selenium.webdriver.common.keys import Keys
# Register fetch hook
script = """
window._fetch = window.fetch;
window.fetch = (url, options) => {
// Call parent fetch method
const result = window._fetch(url, options);
if (!url.startsWith("/api/conversation")) {
return result;
}
// Load response reader
result.then((response) => {
if (!response.body.locked) {
window._reader = response.body.getReader();
}
});
// Return dummy response
return new Promise((resolve, reject) => {
resolve(new Response(new ReadableStream()))
});
}
window._last_message = "";
"""
driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
"source": script
})
try: try:
driver.get(f"{cls.url}/home") driver.get(f"{cls.url}/home")
wait = WebDriverWait(driver, 10 if headless else 240) wait = WebDriverWait(driver, 5)
wait.until(EC.visibility_of_element_located((By.TAG_NAME, "body"))) wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize")))
time.sleep(0.1) except:
driver = web_session.reopen()
driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
"source": script
})
driver.get(f"{cls.url}/home")
wait = WebDriverWait(driver, 240)
wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize")))
time.sleep(200)
try: try:
driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click() driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click() driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
@ -87,29 +121,6 @@ class Theb(BaseProvider):
button = container.find_element(By.CSS_SELECTOR, "button.btn-blue.btn-small.border") button = container.find_element(By.CSS_SELECTOR, "button.btn-blue.btn-small.border")
button.click() button.click()
# Register fetch hook
script = """
window._fetch = window.fetch;
window.fetch = (url, options) => {
// Call parent fetch method
const result = window._fetch(url, options);
if (!url.startsWith("/api/conversation")) {
return result;
}
// Load response reader
result.then((response) => {
if (!response.body.locked) {
window._reader = response.body.getReader();
}
});
// Return dummy response
return new Promise((resolve, reject) => {
resolve(new Response(new ReadableStream()))
});
}
window._last_message = "";
"""
driver.execute_script(script)
# Submit prompt # Submit prompt
wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize"))) wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize")))
@ -151,8 +162,3 @@ return '';
break break
else: else:
time.sleep(0.1) time.sleep(0.1)
finally:
if not browser:
driver.close()
time.sleep(0.1)
driver.quit()

@ -1,24 +1,15 @@
from __future__ import annotations from __future__ import annotations
import warnings
import json import json
import asyncio from contextlib import asynccontextmanager
from functools import partialmethod from functools import partialmethod
from asyncio import Future, Queue from typing import AsyncGenerator
from typing import AsyncGenerator, Union, Optional
from curl_cffi.requests import AsyncSession, Response from curl_cffi.requests import AsyncSession, Response
import curl_cffi
is_newer_0_5_8: bool = hasattr(AsyncSession, "_set_cookies") or hasattr(curl_cffi.requests.Cookies, "get_cookies_for_curl")
is_newer_0_5_9: bool = hasattr(curl_cffi.AsyncCurl, "remove_handle")
is_newer_0_5_10: bool = hasattr(AsyncSession, "release_curl")
class StreamResponse: class StreamResponse:
def __init__(self, inner: Response, queue: Queue[bytes]) -> None: def __init__(self, inner: Response) -> None:
self.inner: Response = inner self.inner: Response = inner
self.queue: Queue[bytes] = queue
self.request = inner.request self.request = inner.request
self.status_code: int = inner.status_code self.status_code: int = inner.status_code
self.reason: str = inner.reason self.reason: str = inner.reason
@ -27,148 +18,32 @@ class StreamResponse:
self.cookies = inner.cookies self.cookies = inner.cookies
async def text(self) -> str: async def text(self) -> str:
content: bytes = await self.read() return await self.inner.atext()
return content.decode()
def raise_for_status(self) -> None: def raise_for_status(self) -> None:
if not self.ok: self.inner.raise_for_status()
raise RuntimeError(f"HTTP Error {self.status_code}: {self.reason}")
async def json(self, **kwargs) -> dict: async def json(self, **kwargs) -> dict:
return json.loads(await self.read(), **kwargs) return json.loads(await self.inner.acontent(), **kwargs)
async def iter_lines(
self, chunk_size: Optional[int] = None, decode_unicode: bool = False, delimiter: Optional[str] = None
) -> AsyncGenerator[bytes, None]:
"""
Copied from: https://requests.readthedocs.io/en/latest/_modules/requests/models/
which is under the License: Apache 2.0
"""
pending: bytes = None
async for chunk in self.iter_content(
chunk_size=chunk_size, decode_unicode=decode_unicode
):
if pending is not None:
chunk = pending + chunk
lines = chunk.split(delimiter) if delimiter else chunk.splitlines()
if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]:
pending = lines.pop()
else:
pending = None
for line in lines: async def iter_lines(self) -> AsyncGenerator[bytes, None]:
async for line in self.inner.aiter_lines():
yield line yield line
if pending is not None: async def iter_content(self) -> AsyncGenerator[bytes, None]:
yield pending async for chunk in self.inner.aiter_content():
async def iter_content(
self, chunk_size: Optional[int] = None, decode_unicode: bool = False
) -> AsyncGenerator[bytes, None]:
if chunk_size:
warnings.warn("chunk_size is ignored, there is no way to tell curl that.")
if decode_unicode:
raise NotImplementedError()
while True:
chunk = await self.queue.get()
if chunk is None:
return
yield chunk yield chunk
async def read(self) -> bytes:
return b"".join([chunk async for chunk in self.iter_content()])
class StreamRequest:
def __init__(self, session: AsyncSession, method: str, url: str, **kwargs: Union[bool, int, str]) -> None:
self.session: AsyncSession = session
self.loop: asyncio.AbstractEventLoop = session.loop if session.loop else asyncio.get_running_loop()
self.queue: Queue[bytes] = Queue()
self.method: str = method
self.url: str = url
self.options: dict = kwargs
self.handle: Optional[curl_cffi.AsyncCurl] = None
def _on_content(self, data: bytes) -> None:
if not self.enter.done():
self.enter.set_result(None)
self.queue.put_nowait(data)
def _on_done(self, task: Future) -> None:
if not self.enter.done():
self.enter.set_result(None)
self.queue.put_nowait(None)
self.loop.call_soon(self.release_curl)
async def fetch(self) -> StreamResponse:
if self.handle:
raise RuntimeError("Request already started")
self.curl: curl_cffi.AsyncCurl = await self.session.pop_curl()
self.enter: asyncio.Future = self.loop.create_future()
if is_newer_0_5_10:
request, _, header_buffer, _, _ = self.session._set_curl_options(
self.curl,
self.method,
self.url,
content_callback=self._on_content,
**self.options
)
else:
request, _, header_buffer = self.session._set_curl_options(
self.curl,
self.method,
self.url,
content_callback=self._on_content,
**self.options
)
if is_newer_0_5_9:
self.handle = self.session.acurl.add_handle(self.curl)
else:
await self.session.acurl.add_handle(self.curl, False)
self.handle = self.session.acurl._curl2future[self.curl]
self.handle.add_done_callback(self._on_done)
# Wait for headers
await self.enter
# Raise exceptions
if self.handle.done():
self.handle.result()
if is_newer_0_5_8:
response = self.session._parse_response(self.curl, _, header_buffer)
response.request = request
else:
response = self.session._parse_response(self.curl, request, _, header_buffer)
return StreamResponse(response, self.queue)
async def __aenter__(self) -> StreamResponse:
return await self.fetch()
async def __aexit__(self, *args) -> None:
self.release_curl()
def release_curl(self) -> None:
if is_newer_0_5_10:
self.session.release_curl(self.curl)
return
if not self.curl:
return
self.curl.clean_after_perform()
if is_newer_0_5_9:
self.session.acurl.remove_handle(self.curl)
elif not self.handle.done() and not self.handle.cancelled():
self.session.acurl.set_result(self.curl)
self.curl.reset()
self.session.push_curl(self.curl)
self.curl = None
class StreamSession(AsyncSession): class StreamSession(AsyncSession):
def request( @asynccontextmanager
async def request(
self, method: str, url: str, **kwargs self, method: str, url: str, **kwargs
) -> StreamRequest: ) -> AsyncGenerator[StreamResponse]:
return StreamRequest(self, method, url, **kwargs) response = await super().request(method, url, stream=True, **kwargs)
try:
yield StreamResponse(response)
finally:
await response.aclose()
head = partialmethod(request, "HEAD") head = partialmethod(request, "HEAD")
get = partialmethod(request, "GET") get = partialmethod(request, "GET")

@ -1,6 +1,6 @@
requests requests
pycryptodome pycryptodome
curl_cffi curl_cffi>=0.5.10b4
aiohttp aiohttp
certifi certifi
browser_cookie3 browser_cookie3
@ -22,3 +22,5 @@ fastapi
uvicorn uvicorn
flask flask
py-arkose-generator py-arkose-generator
asyncstdlib
async-property
Loading…
Cancel
Save