From 98dd6d068a67c2ac1c14785ea189c2e4c8882bf5 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 5 Jun 2023 16:28:58 -0700 Subject: [PATCH] cohere retries (#5757) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …719) A minor update to retry Cohere API call in case of errors using tenacity as it is done for OpenAI LLMs. #### Who can review? @hwchase17, @agola11 Fixes # (issue) #### Before submitting #### Who can review? Tag maintainers/contributors who might be interested: --------- Co-authored-by: Sagar Sapkota <22609549+sagar-spkt@users.noreply.github.com> --- langchain/llms/cohere.py | 45 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py index 6da6cf9e..08043720 100644 --- a/langchain/llms/cohere.py +++ b/langchain/llms/cohere.py @@ -1,8 +1,17 @@ """Wrapper around Cohere APIs.""" +from __future__ import annotations + import logging -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional from pydantic import Extra, root_validator +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM @@ -12,6 +21,33 @@ from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) +def _create_retry_decorator(llm: Cohere) -> Callable[[Any], Any]: + import cohere + + min_seconds = 4 + max_seconds = 10 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(llm.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=(retry_if_exception_type(cohere.error.CohereError)), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def
completion_with_retry(llm: Cohere, **kwargs: Any) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(llm) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + return llm.client.generate(**kwargs) + + return _completion_with_retry(**kwargs) + + class Cohere(LLM): """Wrapper around Cohere large language models. @@ -52,6 +88,9 @@ class Cohere(LLM): """Specify how the client handles inputs longer than the maximum token length: Truncate from START, END or NONE""" + max_retries: int = 10 + """Maximum number of retries to make when generating.""" + cohere_api_key: Optional[str] = None stop: Optional[List[str]] = None @@ -129,7 +168,9 @@ class Cohere(LLM): else: params["stop_sequences"] = stop - response = self.client.generate(model=self.model, prompt=prompt, **params) + response = completion_with_retry( + self, model=self.model, prompt=prompt, **params + ) text = response.generations[0].text # If stop tokens are provided, Cohere's endpoint returns them. # In order to make this consistent with other endpoints, we strip them.