CohereEmbeddings: Add max_retries and request_timeout (#12275)

Add max_retries and request_timeout to CohereEmbeddings, akin to how it
works in OpenAIEmbeddings.

Since the Cohere client already implements these parameters, we can
simply pass them down.

Uses parameters from these two cohere client objects:

https://github.com/cohere-ai/cohere-python/blob/main/cohere/client.py

https://github.com/cohere-ai/cohere-python/blob/main/cohere/client_async.py
pull/12286/head^2
Johanna Appel 12 months ago committed by GitHub
parent 7108084947
commit c26ec7789f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -33,6 +33,11 @@ class CohereEmbeddings(BaseModel, Embeddings):
cohere_api_key: Optional[str] = None
max_retries: Optional[int] = None
"""Maximum number of retries to make when generating."""
request_timeout: Optional[float] = None
"""Timeout in seconds for the Cohere API request."""
class Config:
"""Configuration for this pydantic object."""
@ -44,11 +49,18 @@ class CohereEmbeddings(BaseModel, Embeddings):
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
max_retries = values.get("max_retries")
request_timeout = values.get("request_timeout")
try:
import cohere
values["client"] = cohere.Client(cohere_api_key)
values["async_client"] = cohere.AsyncClient(cohere_api_key)
values["client"] = cohere.Client(
cohere_api_key, max_retries=max_retries, timeout=request_timeout
)
values["async_client"] = cohere.AsyncClient(
cohere_api_key, max_retries=max_retries, timeout=request_timeout
)
except ImportError:
raise ValueError(
"Could not import cohere python package. "

Loading…
Cancel
Save