diff --git a/libs/community/langchain_community/embeddings/cohere.py b/libs/community/langchain_community/embeddings/cohere.py index fd95b58ea6..2d4676d125 100644 --- a/libs/community/langchain_community/embeddings/cohere.py +++ b/libs/community/langchain_community/embeddings/cohere.py @@ -34,7 +34,7 @@ class CohereEmbeddings(BaseModel, Embeddings): cohere_api_key: Optional[str] = None - max_retries: Optional[int] = None + max_retries: Optional[int] = 3 """Maximum number of retries to make when generating.""" request_timeout: Optional[float] = None """Timeout in seconds for the Cohere API request.""" @@ -92,11 +92,13 @@ class CohereEmbeddings(BaseModel, Embeddings): async def aembed( self, texts: List[str], *, input_type: Optional[str] = None ) -> List[List[float]]: - embeddings = await self.async_client.embed( - model=self.model, - texts=texts, - input_type=input_type, - truncate=self.truncate, + embeddings = ( + await self.async_client.embed( + model=self.model, + texts=texts, + input_type=input_type, + truncate=self.truncate, + ) ).embeddings return [list(map(float, e)) for e in embeddings]