|
|
|
@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
|
|
|
|
|
|
import aiohttp
|
|
|
|
|
import requests
|
|
|
|
|
from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential
|
|
|
|
|
|
|
|
|
|
from manifest.request import DEFAULT_REQUEST_KEYS, NOT_CACHE_KEYS, Request
|
|
|
|
|
from manifest.response import RESPONSE_CONSTRUCTORS, Response
|
|
|
|
@ -15,6 +16,14 @@ from manifest.response import RESPONSE_CONSTRUCTORS, Response
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def retry_if_ratelimit(retry_base: RetryCallState) -> bool:
    """Return whether to retry if ratelimited.

    Used as the ``retry=`` predicate of the tenacity ``@retry`` decorators
    on the sync and async completion methods.

    Args:
        retry_base: tenacity call state for the attempt that just finished.

    Returns:
        True only when the attempt failed with an HTTP 429 (Too Many
        Requests) error; False for success or any other failure.
    """
    outcome = retry_base.outcome
    if outcome is None:
        # No attempt has completed yet, so there is nothing to inspect.
        return False

    exc = outcome.exception()
    # The sync completion path re-raises a fresh ``HTTPError`` built from the
    # JSON body, which carries no ``response``; the original error (with the
    # status code) is then only reachable via the implicit exception chain.
    # Walk that chain, guarding against cycles.
    visited = set()
    while exc is not None and id(exc) not in visited:
        visited.add(id(exc))
        if isinstance(exc, requests.exceptions.HTTPError):
            response = getattr(exc, "response", None)
            # ``response`` is None when the error was raised without one;
            # dereferencing it blindly would crash the retry predicate.
            if response is not None and response.status_code == 429:
                return True
        elif isinstance(exc, aiohttp.ClientResponseError):
            # aiohttp's ``raise_for_status`` reports the code as ``status``.
            if exc.status == 429:
                return True
        exc = exc.__cause__ or exc.__context__
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Client(ABC):
|
|
|
|
|
"""Client class."""
|
|
|
|
|
|
|
|
|
@ -194,6 +203,12 @@ class Client(ABC):
|
|
|
|
|
request_params_list.append(params)
|
|
|
|
|
return request_params_list
|
|
|
|
|
|
|
|
|
|
@retry(
|
|
|
|
|
reraise=True,
|
|
|
|
|
retry=retry_if_ratelimit,
|
|
|
|
|
wait=wait_random_exponential(min=1, max=60),
|
|
|
|
|
stop=stop_after_attempt(10),
|
|
|
|
|
)
|
|
|
|
|
def _run_completion(
|
|
|
|
|
self, request_params: Dict[str, Any], retry_timeout: int
|
|
|
|
|
) -> Dict:
|
|
|
|
@ -207,25 +222,25 @@ class Client(ABC):
|
|
|
|
|
response as dict.
|
|
|
|
|
"""
|
|
|
|
|
post_str = self.get_generation_url()
|
|
|
|
|
res = requests.post(
|
|
|
|
|
post_str,
|
|
|
|
|
headers=self.get_generation_header(),
|
|
|
|
|
json=request_params,
|
|
|
|
|
timeout=retry_timeout,
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
res = requests.post(
|
|
|
|
|
post_str,
|
|
|
|
|
headers=self.get_generation_header(),
|
|
|
|
|
json=request_params,
|
|
|
|
|
timeout=retry_timeout,
|
|
|
|
|
)
|
|
|
|
|
res.raise_for_status()
|
|
|
|
|
except requests.Timeout as e:
|
|
|
|
|
logger.error(
|
|
|
|
|
f"{self.__class__.__name__} request timed out."
|
|
|
|
|
" Increase client_timeout."
|
|
|
|
|
)
|
|
|
|
|
raise e
|
|
|
|
|
except requests.exceptions.HTTPError:
|
|
|
|
|
logger.error(res.json())
|
|
|
|
|
raise requests.exceptions.HTTPError(res.json())
|
|
|
|
|
return self.format_response(res.json(), request_params)
|
|
|
|
|
|
|
|
|
|
@retry(
|
|
|
|
|
reraise=True,
|
|
|
|
|
retry=retry_if_ratelimit,
|
|
|
|
|
wait=wait_random_exponential(min=1, max=60),
|
|
|
|
|
stop=stop_after_attempt(10),
|
|
|
|
|
)
|
|
|
|
|
async def _arun_completion(
|
|
|
|
|
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
|
|
|
|
|
) -> Dict:
|
|
|
|
@ -240,20 +255,16 @@ class Client(ABC):
|
|
|
|
|
response as dict.
|
|
|
|
|
"""
|
|
|
|
|
post_str = self.get_generation_url()
|
|
|
|
|
try:
|
|
|
|
|
async with aiohttp.ClientSession(timeout=retry_timeout) as session:
|
|
|
|
|
async with session.post(
|
|
|
|
|
post_str,
|
|
|
|
|
headers=self.get_generation_header(),
|
|
|
|
|
json=request_params,
|
|
|
|
|
timeout=retry_timeout,
|
|
|
|
|
) as res:
|
|
|
|
|
res.raise_for_status()
|
|
|
|
|
res_json = await res.json(content_type=None)
|
|
|
|
|
return self.format_response(res_json, request_params)
|
|
|
|
|
except aiohttp.ClientError as e:
|
|
|
|
|
logger.error(f"{self.__class__.__name__} request error {e}")
|
|
|
|
|
raise e
|
|
|
|
|
async with aiohttp.ClientSession(timeout=retry_timeout) as session:
|
|
|
|
|
async with session.post(
|
|
|
|
|
post_str,
|
|
|
|
|
headers=self.get_generation_header(),
|
|
|
|
|
json=request_params,
|
|
|
|
|
timeout=retry_timeout,
|
|
|
|
|
) as res:
|
|
|
|
|
res.raise_for_status()
|
|
|
|
|
res_json = await res.json(content_type=None)
|
|
|
|
|
return self.format_response(res_json, request_params)
|
|
|
|
|
|
|
|
|
|
def run_request(self, request: Request) -> Response:
|
|
|
|
|
"""
|
|
|
|
|