Add 'truncate' parameter for CohereEmbeddings (#798)

Currently, the 'truncate' parameter of the cohere API is not supported.

This means that by default, if trying to generate and embedding that is
too big, the call will just fail with an error (which is frustrating if
using this embedding source e.g. with GPT-Index, because it's hard to
handle it properly when generating a lot of embeddings).
With the parameter, one can decide to either truncate the START or END
of the text to fit the max token length and still generate an embedding
without throwing the error.

In this PR, I added this parameter to the class.

_Arguably, there should be a better way to handle this error, e.g. by
optionally calling a function or so that gets triggered when the token
limit is reached and can split the document or some such. Especially in
the use case with GPT-Index, its often hard to estimate the token counts
for each document and I'd rather sort out the troublemakers or simply
split them than interrupting the whole execution.
Thoughts?_

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
harrison/image
Johanna Appel 1 year ago committed by GitHub
parent b9045f7e0d
commit ebea40ce86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,6 +25,9 @@ class CohereEmbeddings(BaseModel, Embeddings):
model: str = "large"
"""Model name to use."""
truncate: str = "NONE"
"""Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""
cohere_api_key: Optional[str] = None
class Config:
@ -58,7 +61,9 @@ class CohereEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
embeddings = self.client.embed(model=self.model, texts=texts).embeddings
embeddings = self.client.embed(
model=self.model, texts=texts, truncate=self.truncate
).embeddings
return embeddings
def embed_query(self, text: str) -> List[float]:
@ -70,5 +75,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
embedding = self.client.embed(model=self.model, texts=[text]).embeddings[0]
embedding = self.client.embed(
model=self.model, texts=[text], truncate=self.truncate
).embeddings[0]
return embedding

Loading…
Cancel
Save