community[patch]: Fixed the 'aembed' method of 'CohereEmbeddings'. (#16497)

**Description:**
- The existing code was trying to find a `.embeddings` property on the
`Coroutine` returned by calling `cohere.async_client.embed`.
- Instead, the `.embeddings` property is present on the value returned
by the `Coroutine`.
- Also, it seems that the original cohere client expects a value of
`max_retries` to not be `None`. Hence, setting the default value of
`max_retries` to `3`.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/17434/head
Abhishek Jain 8 months ago committed by GitHub
parent 9f1cbbc6ed
commit 37e1275f9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -34,7 +34,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
cohere_api_key: Optional[str] = None
max_retries: Optional[int] = None
max_retries: Optional[int] = 3
"""Maximum number of retries to make when generating."""
request_timeout: Optional[float] = None
"""Timeout in seconds for the Cohere API request."""
@ -92,11 +92,13 @@ class CohereEmbeddings(BaseModel, Embeddings):
async def aembed(
self, texts: List[str], *, input_type: Optional[str] = None
) -> List[List[float]]:
embeddings = await self.async_client.embed(
model=self.model,
texts=texts,
input_type=input_type,
truncate=self.truncate,
embeddings = (
await self.async_client.embed(
model=self.model,
texts=texts,
input_type=input_type,
truncate=self.truncate,
)
).embeddings
return [list(map(float, e)) for e in embeddings]

Loading…
Cancel
Save