diff --git a/libs/langchain/langchain/embeddings/cohere.py b/libs/langchain/langchain/embeddings/cohere.py index 1c908876f9..1a6b3cdca0 100644 --- a/libs/langchain/langchain/embeddings/cohere.py +++ b/libs/langchain/langchain/embeddings/cohere.py @@ -33,6 +33,11 @@ class CohereEmbeddings(BaseModel, Embeddings): cohere_api_key: Optional[str] = None + max_retries: Optional[int] = None + """Maximum number of retries to make when generating.""" + request_timeout: Optional[float] = None + """Timeout in seconds for the Cohere API request.""" + class Config: """Configuration for this pydantic object.""" @@ -44,11 +49,18 @@ class CohereEmbeddings(BaseModel, Embeddings): cohere_api_key = get_from_dict_or_env( values, "cohere_api_key", "COHERE_API_KEY" ) + max_retries = values.get("max_retries") + request_timeout = values.get("request_timeout") + try: import cohere - values["client"] = cohere.Client(cohere_api_key) - values["async_client"] = cohere.AsyncClient(cohere_api_key) + values["client"] = cohere.Client( + cohere_api_key, max_retries=max_retries, timeout=request_timeout + ) + values["async_client"] = cohere.AsyncClient( + cohere_api_key, max_retries=max_retries, timeout=request_timeout + ) except ImportError: raise ValueError( "Could not import cohere python package. "