@ -1,25 +1,24 @@
from __future__ import annotations
import json , sys , asyncio
from functools import partialmethod
import warnings , json , asyncio
from aiohttp import StreamReader
from aiohttp . base_protocol import BaseProtocol
from functools import partialmethod
from asyncio import Future , Queue
from typing import AsyncGenerator
from curl_cffi . requests import AsyncSession as BaseSession
from curl_cffi . requests import Response
from curl_cffi . requests import AsyncSession , Response
import curl_cffi
is_newer_0_5_8 = hasattr ( Base Session, " _set_cookies " ) or hasattr ( curl_cffi . requests . Cookies , " get_cookies_for_curl " )
is_newer_0_5_8 = hasattr ( Async Session, " _set_cookies " ) or hasattr ( curl_cffi . requests . Cookies , " get_cookies_for_curl " )
is_newer_0_5_9 = hasattr ( curl_cffi . AsyncCurl , " remove_handle " )
is_newer_0_5_10 = hasattr ( Base Session, " release_curl " )
is_newer_0_5_10 = hasattr ( Async Session, " release_curl " )
class StreamResponse :
def __init__ ( self , inner : Response , content: StreamReader , request ) :
def __init__ ( self , inner : Response , queue: Queue ) :
self . inner = inner
self . content = content
self . request = request
self . queue = queue
self . request = inner. request
self . status_code = inner . status_code
self . reason = inner . reason
self . ok = inner . ok
@ -27,7 +26,7 @@ class StreamResponse:
self . cookies = inner . cookies
async def text ( self ) - > str :
content = await self . content. read( )
content = await self . read( )
return content . decode ( )
def raise_for_status ( self ) :
@ -35,36 +34,74 @@ class StreamResponse:
raise RuntimeError ( f " HTTP Error { self . status_code } : { self . reason } " )
async def json ( self , * * kwargs ) :
return json . loads ( await self . content . read ( ) , * * kwargs )
return json . loads ( await self . read ( ) , * * kwargs )
async def iter_lines ( self , chunk_size = None , decode_unicode = False , delimiter = None ) - > AsyncGenerator [ bytes ] :
"""
Copied from : https : / / requests . readthedocs . io / en / latest / _modules / requests / models /
which is under the License : Apache 2.0
"""
pending = None
async for chunk in self . iter_content (
chunk_size = chunk_size , decode_unicode = decode_unicode
) :
if pending is not None :
chunk = pending + chunk
if delimiter :
lines = chunk . split ( delimiter )
else :
lines = chunk . splitlines ( )
if lines and lines [ - 1 ] and chunk and lines [ - 1 ] [ - 1 ] == chunk [ - 1 ] :
pending = lines . pop ( )
else :
pending = None
for line in lines :
yield line
if pending is not None :
yield pending
async def iter_content ( self , chunk_size = None , decode_unicode = False ) - > As :
if chunk_size :
warnings . warn ( " chunk_size is ignored, there is no way to tell curl that. " )
if decode_unicode :
raise NotImplementedError ( )
while True :
chunk = await self . queue . get ( )
if chunk is None :
return
yield chunk
async def read ( self ) - > bytes :
return b " " . join ( [ chunk async for chunk in self . iter_content ( ) ] )
class StreamRequest :
def __init__ ( self , session : AsyncSession , method : str , url : str , * * kwargs ) :
self . session = session
self . loop = session . loop if session . loop else asyncio . get_running_loop ( )
self . content = StreamReader (
BaseProtocol ( session . loop ) ,
sys . maxsize ,
loop = session . loop
)
self . queue = Queue ( )
self . method = method
self . url = url
if " proxy " in kwargs :
proxy = kwargs . pop ( " proxy " )
if proxy :
kwargs [ " proxies " ] = { " http " : proxy , " https " : proxy }
self . options = kwargs
self . handle = None
def on_content ( self , data ) :
def _on_content ( self , data ) :
if not self . enter . done ( ) :
self . enter . set_result ( None )
self . content . feed_data ( data )
self . queue. put_nowait ( data )
def on_done ( self , task ) :
def _ on_done( self , task : Future ) :
if not self . enter . done ( ) :
self . enter . set_result ( None )
self . content . feed_eof ( )
self . queue. put_nowait ( None )
async def __aenter__ ( self ) - > StreamResponse :
self . loop . call_soon ( self . session . release_curl , self . curl )
async def fetch ( self ) - > StreamResponse :
if self . handle :
raise RuntimeError ( " Request already started " )
self . curl = await self . session . pop_curl ( )
self . enter = self . loop . create_future ( )
if is_newer_0_5_10 :
@ -72,7 +109,7 @@ class StreamRequest:
self . curl ,
self . method ,
self . url ,
content_callback = self . on_content,
content_callback = self . _ on_content,
* * self . options
)
else :
@ -80,7 +117,7 @@ class StreamRequest:
self . curl ,
self . method ,
self . url ,
content_callback = self . on_content,
content_callback = self . _ on_content,
* * self . options
)
if is_newer_0_5_9 :
@ -88,8 +125,12 @@ class StreamRequest:
else :
await self . session . acurl . add_handle ( self . curl , False )
self . handle = self . session . acurl . _curl2future [ self . curl ]
self . handle . add_done_callback ( self . on_done )
self . handle . add_done_callback ( self . _on_done )
# Wait for headers
await self . enter
# Raise exceptions
if self . handle . done ( ) :
self . handle . result ( )
if is_newer_0_5_8 :
response = self . session . _parse_response ( self . curl , _ , header_buffer )
response . request = request
@ -97,18 +138,16 @@ class StreamRequest:
response = self . session . _parse_response ( self . curl , request , _ , header_buffer )
return StreamResponse (
response ,
self . content ,
request
self . queue
)
async def __aexit__ ( self , exc_type , exc , tb ) :
if not self . handle . done ( ) :
self . session . acurl . set_result ( self . curl )
self . curl . clean_after_perform ( )
self . curl . reset ( )
self . session . push_curl ( self . curl )
class AsyncSession ( BaseSession ) :
async def __aenter__ ( self ) - > StreamResponse :
return await self . fetch ( )
async def __aexit__ ( self , * args ) :
self . session . release_curl ( self . curl )
class StreamSession ( AsyncSession ) :
def request (
self ,
method : str ,