fix: add retry to client for ratelimit (#73)

pull/82/head v0.1.2
Laurel Orr 1 year ago committed by GitHub
parent ee9f16688e
commit c7906bead5

@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import aiohttp
import requests
from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential
from manifest.request import DEFAULT_REQUEST_KEYS, NOT_CACHE_KEYS, Request
from manifest.response import RESPONSE_CONSTRUCTORS, Response
@ -15,6 +16,14 @@ from manifest.response import RESPONSE_CONSTRUCTORS, Response
logger = logging.getLogger(__name__)
def retry_if_ratelimit(retry_base: RetryCallState) -> bool:
    """
    Return whether to retry the call because it was rate limited.

    Used as the ``retry=`` predicate for tenacity's ``@retry`` decorator on
    the completion methods: only HTTP 429 responses are retried; every other
    failure propagates immediately.

    Args:
        retry_base: tenacity state describing the last call attempt.

    Returns:
        True if the attempt failed with an HTTP 429 status.
    """
    # outcome can be None before the first attempt completes; exception()
    # is None when the attempt succeeded.
    exc = retry_base.outcome.exception() if retry_base.outcome else None
    if isinstance(exc, requests.exceptions.HTTPError):
        # response may be absent when the error was constructed manually.
        return exc.response is not None and exc.response.status_code == 429
    if isinstance(exc, aiohttp.ClientResponseError):
        # The async path raises aiohttp's error type, not requests'.
        return exc.status == 429
    return False
class Client(ABC):
"""Client class."""
@ -194,6 +203,12 @@ class Client(ABC):
request_params_list.append(params)
return request_params_list
@retry(
    reraise=True,
    retry=retry_if_ratelimit,
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(10),
)
def _run_completion(
    self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
    """
    Run the completion request against the client's generation endpoint.

    Retried by the decorator with random exponential backoff (1-60s) for up
    to 10 attempts, but only when the server answers HTTP 429 (rate limit).

    Args:
        request_params: request parameters sent as the JSON body.
        retry_timeout: per-request timeout in seconds.

    Returns:
        response as dict.

    Raises:
        requests.Timeout: if the request exceeds ``retry_timeout``.
        requests.exceptions.HTTPError: on any 4xx/5xx response.
    """
    post_str = self.get_generation_url()
    try:
        res = requests.post(
            post_str,
            headers=self.get_generation_header(),
            json=request_params,
            timeout=retry_timeout,
        )
        # Turn 4xx/5xx into HTTPError so the retry predicate can see 429s.
        res.raise_for_status()
    except requests.Timeout as e:
        logger.error(
            f"{self.__class__.__name__} request timed out."
            " Increase client_timeout."
        )
        raise e
    except requests.exceptions.HTTPError:
        logger.error(res.json())
        # Attach the original response: retry_if_ratelimit reads
        # exception().response.status_code, which would otherwise be lost.
        raise requests.exceptions.HTTPError(res.json(), response=res)
    return self.format_response(res.json(), request_params)
@retry(
    reraise=True,
    retry=retry_if_ratelimit,
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(10),
)
async def _arun_completion(
    self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
) -> Dict:
    """
    Async run of the completion request against the generation endpoint.

    Retried by the decorator with random exponential backoff (1-60s) for up
    to 10 attempts on rate-limit errors; ``raise_for_status`` surfaces
    aiohttp errors directly so the retry predicate can inspect them.

    Args:
        request_params: request parameters sent as the JSON body.
        retry_timeout: per-request timeout in seconds.
        batch_size: batch size of the request.

    Returns:
        response as dict.
    """
    post_str = self.get_generation_url()
    # NOTE(review): aiohttp documents a ClientTimeout object for the
    # session-level timeout; passing the raw int relies on the pinned
    # aiohttp version accepting it — confirm against aiohttp>=3.8.0.
    async with aiohttp.ClientSession(timeout=retry_timeout) as session:
        async with session.post(
            post_str,
            headers=self.get_generation_header(),
            json=request_params,
            timeout=retry_timeout,
        ) as res:
            res.raise_for_status()
            # content_type=None: decode JSON regardless of the
            # Content-Type header the server sets.
            res_json = await res.json(content_type=None)
            return self.format_response(res_json, request_params)
def run_request(self, request: Request) -> Response:
"""

@ -21,7 +21,7 @@ with open(ver_path) as ver_file:
NAME = "manifest-ml"
DESCRIPTION = "Manifest for Prompting Foundation Models."
URL = "https://github.com/HazyResearch/manifest"
EMAIL = "lorr1@cs.stanford.edu"
EMAIL = "laurel.orr@numbersstation.ai"
AUTHOR = "Laurel Orr"
REQUIRES_PYTHON = ">=3.8.0"
VERSION = main_ns["__version__"]
@ -34,8 +34,9 @@ REQUIRED = [
"requests>=2.27.1",
"aiohttp>=3.8.0",
"sqlitedict>=2.0.0",
"xxhash>=3.0.0",
"tenacity>=8.2.0",
"tiktoken>=0.3.0",
"xxhash>=3.0.0",
]
# What packages are optional?

@ -1,7 +1,7 @@
"""
Test client.
We just test the dummy client as we don't want to load a model or use OpenAI tokens.
We just test the dummy client.
"""
from manifest.clients.dummy import DummyClient

@ -2,9 +2,11 @@
import asyncio
import os
from typing import cast
from unittest.mock import MagicMock, Mock, patch
import pytest
import requests
from requests import HTTPError
from manifest import Manifest, Response
from manifest.caches.noop import NoopCache
@ -643,7 +645,7 @@ def test_openaichat(sqlite_cache: str) -> None:
assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
assert response.is_cached() is True
assert "usage" in response.get_json_response()
assert response.get_json_response()["usage"][0]["total_tokens"] == 22
assert response.get_json_response()["usage"][0]["total_tokens"] == 23
response = cast(Response, client.run("Why are there apples?", return_response=True))
assert response.is_cached() is True
@ -674,10 +676,92 @@ def test_openaichat(sqlite_cache: str) -> None:
"usage" in response.get_json_response()
and len(response.get_json_response()["usage"]) == 2
)
assert response.get_json_response()["usage"][0]["total_tokens"] == 24
assert response.get_json_response()["usage"][1]["total_tokens"] == 22
assert response.get_json_response()["usage"][0]["total_tokens"] == 25
assert response.get_json_response()["usage"][1]["total_tokens"] == 23
response = cast(
Response, client.run("Why are there oranges?", return_response=True)
)
assert response.is_cached() is True
def test_retry_handling() -> None:
    """Rate-limit (429) errors are retried; other HTTP errors are not."""
    # No real network traffic is needed — requests.post is patched below.
    client = Manifest(client_name="openai", client_connection="fake")

    def _http_error(status: int) -> HTTPError:
        # Build an HTTPError whose fake response carries the given status.
        return HTTPError(
            response=Mock(status_code=status, json=Mock(return_value={})),
            request=Mock(),
        )

    # A well-formed completion payload for the second (successful) attempt.
    success_payload = {
        "choices": [
            {
                "finish_reason": "length",
                "index": 0,
                "logprobs": None,
                "text": " WHATTT.",
            },
            {
                "finish_reason": "length",
                "index": 1,
                "logprobs": None,
                "text": " UH OH.",
            },
            {
                "finish_reason": "length",
                "index": 2,
                "logprobs": None,
                "text": " HARG",
            },
        ],
        "created": 1679469056,
        "id": "cmpl-6wmuWfmyuzi68B6gfeNC0h5ywxXL5",
        "model": "text-ada-001",
        "object": "text_completion",
        "usage": {
            "completion_tokens": 30,
            "prompt_tokens": 24,
            "total_tokens": 54,
        },
    }

    prompts = [
        "The sky is purple. This is because",
        "The sky is magnet. This is because",
        "The sky is fuzzy. This is because",
    ]

    # First attempt is rate limited, second succeeds.
    mocked_post = MagicMock(
        side_effect=[
            _http_error(429),
            Mock(status_code=200, json=Mock(return_value=success_payload)),
        ]
    )
    with patch("manifest.clients.client.requests.post", mocked_post):
        result = client.run(prompts, temperature=0, overwrite_cache=True)
        assert result == ["WHATTT.", "UH OH.", "HARG"]
        # One 429 followed by one retried success.
        assert mocked_post.call_count == 2

    # A non-429 error must propagate without any retry.
    mocked_post = MagicMock(side_effect=[_http_error(500)])
    with patch("manifest.clients.client.requests.post", mocked_post):
        with pytest.raises(HTTPError):
            client.run(prompts, temperature=0, overwrite_cache=True)
        assert mocked_post.call_count == 1

Loading…
Cancel
Save