diff --git a/langchain/embeddings/cohere.py b/langchain/embeddings/cohere.py index b55e4aaa..de3648a9 100644 --- a/langchain/embeddings/cohere.py +++ b/langchain/embeddings/cohere.py @@ -25,6 +25,9 @@ class CohereEmbeddings(BaseModel, Embeddings): model: str = "large" """Model name to use.""" + truncate: str = "NONE" + """Truncate inputs that are too long from start or end ("NONE"|"START"|"END")""" + cohere_api_key: Optional[str] = None class Config: @@ -58,7 +61,9 @@ class CohereEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ - embeddings = self.client.embed(model=self.model, texts=texts).embeddings + embeddings = self.client.embed( + model=self.model, texts=texts, truncate=self.truncate + ).embeddings return embeddings def embed_query(self, text: str) -> List[float]: @@ -70,5 +75,7 @@ class CohereEmbeddings(BaseModel, Embeddings): Returns: Embeddings for the text. """ - embedding = self.client.embed(model=self.model, texts=[text]).embeddings[0] + embedding = self.client.embed( + model=self.model, texts=[text], truncate=self.truncate + ).embeddings[0] return embedding