Add support of Cohere Embed v3 (#12940)

Cohere released the new embedding API (Embed v3:
https://txt.cohere.com/introducing-embed-v3/) that treats document and
query embeddings differently. This PR updated the `CohereEmbeddings` to
use them appropriately. It also works with the old models.
This commit is contained in:
Kacper Łukawski 2023-11-06 21:06:58 +01:00 committed by GitHub
parent 8e0dcb37d2
commit 52d0055a91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -87,7 +87,10 @@ class CohereEmbeddings(BaseModel, Embeddings):
List of embeddings, one for each text.
"""
embeddings = self.client.embed(
model=self.model, texts=texts, truncate=self.truncate
model=self.model,
texts=texts,
input_type="search_document",
truncate=self.truncate,
).embeddings
return [list(map(float, e)) for e in embeddings]
@ -101,7 +104,10 @@ class CohereEmbeddings(BaseModel, Embeddings):
List of embeddings, one for each text.
"""
embeddings = await self.async_client.embed(
model=self.model, texts=texts, truncate=self.truncate
model=self.model,
texts=texts,
input_type="search_document",
truncate=self.truncate,
)
return [list(map(float, e)) for e in embeddings.embeddings]
@ -114,7 +120,13 @@ class CohereEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
embeddings = self.client.embed(
model=self.model,
texts=[text],
input_type="search_query",
truncate=self.truncate,
).embeddings
return [list(map(float, e)) for e in embeddings][0]
async def aembed_query(self, text: str) -> List[float]:
"""Async call out to Cohere's embedding endpoint.
@ -125,5 +137,10 @@ class CohereEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
embeddings = await self.aembed_documents([text])
return embeddings[0]
embeddings = await self.async_client.embed(
model=self.model,
texts=[text],
input_type="search_query",
truncate=self.truncate,
)
return [list(map(float, e)) for e in embeddings.embeddings][0]