cohere retries (#5757)

…719)

A minor update to retry Cohore API call in case of errors using tenacity
as it is done for OpenAI LLMs.

#### Who can review?

@hwchase17, @agola11 

<!-- For a quicker response, figure out the right person to tag with @

  @hwchase17 - project lead

  Tracing / Callbacks
  - @agola11

  Async
  - @agola11

  DataLoaders
  - @eyurtsev

  Models
  - @hwchase17
  - @agola11

  Agents / Tools / Toolkits
  - @vowelparrot

  VectorStores / Retrievers / Memory
  - @dev2049

 -->

<!--
Thank you for contributing to LangChain! Your PR will appear in our
release under the title you set. Please make sure it highlights your
valuable contribution.

Replace this with a description of the change, the issue it fixes (if
applicable), and relevant context. List any dependencies required for
this change.

After you're done, someone will review your PR. They may suggest
improvements. If no one reviews your PR within a few days, feel free to
@-mention the same people again, as notifications can get lost.

Finally, we'd love to show appreciation for your contribution - if you'd
like us to shout you out on Twitter, please also include your handle!
-->

<!-- Remove if not applicable -->

Fixes # (issue)

#### Before submitting

<!-- If you're adding a new integration, please include:

1. a test for the integration - favor unit tests that does not rely on
network access.
2. an example notebook showing its use


See contribution guidelines for more information on how to write tests,
lint
etc:


https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
-->

#### Who can review?

Tag maintainers/contributors who might be interested:

<!-- For a quicker response, figure out the right person to tag with @

  @hwchase17 - project lead

  Tracing / Callbacks
  - @agola11

  Async
  - @agola11

  DataLoaders
  - @eyurtsev

  Models
  - @hwchase17
  - @agola11

  Agents / Tools / Toolkits
  - @vowelparrot

  VectorStores / Retrievers / Memory
  - @dev2049

 -->

---------

Co-authored-by: Sagar Sapkota <22609549+sagar-spkt@users.noreply.github.com>
searx_updates
Harrison Chase 12 months ago committed by GitHub
parent 5124c1e0d9
commit 98dd6d068a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,8 +1,17 @@
"""Wrapper around Cohere APIs."""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional
from pydantic import Extra, root_validator
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
@ -12,6 +21,33 @@ from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
def _create_retry_decorator(llm: Cohere) -> Callable[[Any], Any]:
import cohere
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(llm.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(retry_if_exception_type(cohere.error.CohereError)),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return llm.client.generate(**kwargs)
return _completion_with_retry(**kwargs)
class Cohere(LLM):
"""Wrapper around Cohere large language models.
@ -52,6 +88,9 @@ class Cohere(LLM):
"""Specify how the client handles inputs longer than the maximum token
length: Truncate from START, END or NONE"""
max_retries: int = 10
"""Maximum number of retries to make when generating."""
cohere_api_key: Optional[str] = None
stop: Optional[List[str]] = None
@ -129,7 +168,9 @@ class Cohere(LLM):
else:
params["stop_sequences"] = stop
response = self.client.generate(model=self.model, prompt=prompt, **params)
response = completion_with_retry(
self, model=self.model, prompt=prompt, **params
)
text = response.generations[0].text
# If stop tokens are provided, Cohere's endpoint returns them.
# In order to make this consistent with other endpoints, we strip them.

Loading…
Cancel
Save