diff --git a/langchain/embeddings/cohere.py b/langchain/embeddings/cohere.py index de3648a9..c6d5055f 100644 --- a/langchain/embeddings/cohere.py +++ b/langchain/embeddings/cohere.py @@ -25,7 +25,7 @@ class CohereEmbeddings(BaseModel, Embeddings): model: str = "large" """Model name to use.""" - truncate: str = "NONE" + truncate: Optional[str] = None """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")""" cohere_api_key: Optional[str] = None diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py index adb50ad3..de246e22 100644 --- a/langchain/llms/cohere.py +++ b/langchain/llms/cohere.py @@ -47,6 +47,10 @@ class Cohere(LLM, BaseModel): presence_penalty: int = 0 """Penalizes repeated tokens.""" + truncate: Optional[str] = None + """Specify how the client handles inputs longer than the maximum token + length: Truncate from START, END or NONE""" + cohere_api_key: Optional[str] = None stop: Optional[List[str]] = None @@ -83,6 +87,7 @@ class Cohere(LLM, BaseModel): "p": self.p, "frequency_penalty": self.frequency_penalty, "presence_penalty": self.presence_penalty, + "truncate": self.truncate, } @property