mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add 'truncate' parameter for CohereEmbeddings (#798)
Currently, the 'truncate' parameter of the cohere API is not supported. This means that by default, if trying to generate and embedding that is too big, the call will just fail with an error (which is frustrating if using this embedding source e.g. with GPT-Index, because it's hard to handle it properly when generating a lot of embeddings). With the parameter, one can decide to either truncate the START or END of the text to fit the max token length and still generate an embedding without throwing the error. In this PR, I added this parameter to the class. _Arguably, there should be a better way to handle this error, e.g. by optionally calling a function or so that gets triggered when the token limit is reached and can split the document or some such. Especially in the use case with GPT-Index, its often hard to estimate the token counts for each document and I'd rather sort out the troublemakers or simply split them than interrupting the whole execution. Thoughts?_ --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
b9045f7e0d
commit
ebea40ce86
@ -25,6 +25,9 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
model: str = "large"
|
||||
"""Model name to use."""
|
||||
|
||||
truncate: str = "NONE"
|
||||
"""Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""
|
||||
|
||||
cohere_api_key: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
@ -58,7 +61,9 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = self.client.embed(model=self.model, texts=texts).embeddings
|
||||
embeddings = self.client.embed(
|
||||
model=self.model, texts=texts, truncate=self.truncate
|
||||
).embeddings
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
@ -70,5 +75,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
embedding = self.client.embed(model=self.model, texts=[text]).embeddings[0]
|
||||
embedding = self.client.embed(
|
||||
model=self.model, texts=[text], truncate=self.truncate
|
||||
).embeddings[0]
|
||||
return embedding
|
||||
|
Loading…
Reference in New Issue
Block a user