mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Add support of Cohere Embed v3 (#12940)
Cohere released the new embedding API (Embed v3: https://txt.cohere.com/introducing-embed-v3/) that treats document and query embeddings differently. This PR updated the `CohereEmbeddings` to use them appropriately. It also works with the old models.
This commit is contained in:
parent
8e0dcb37d2
commit
52d0055a91
@ -87,7 +87,10 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = self.client.embed(
|
||||
model=self.model, texts=texts, truncate=self.truncate
|
||||
model=self.model,
|
||||
texts=texts,
|
||||
input_type="search_document",
|
||||
truncate=self.truncate,
|
||||
).embeddings
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
@ -101,7 +104,10 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = await self.async_client.embed(
|
||||
model=self.model, texts=texts, truncate=self.truncate
|
||||
model=self.model,
|
||||
texts=texts,
|
||||
input_type="search_document",
|
||||
truncate=self.truncate,
|
||||
)
|
||||
return [list(map(float, e)) for e in embeddings.embeddings]
|
||||
|
||||
@ -114,7 +120,13 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
embeddings = self.client.embed(
|
||||
model=self.model,
|
||||
texts=[text],
|
||||
input_type="search_query",
|
||||
truncate=self.truncate,
|
||||
).embeddings
|
||||
return [list(map(float, e)) for e in embeddings][0]
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""Async call out to Cohere's embedding endpoint.
|
||||
@ -125,5 +137,10 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
embeddings = await self.aembed_documents([text])
|
||||
return embeddings[0]
|
||||
embeddings = await self.async_client.embed(
|
||||
model=self.model,
|
||||
texts=[text],
|
||||
input_type="search_query",
|
||||
truncate=self.truncate,
|
||||
)
|
||||
return [list(map(float, e)) for e in embeddings.embeddings][0]
|
||||
|
Loading…
Reference in New Issue
Block a user