From 9800c6051c4962647a75a1bd045a53b298c923e1 Mon Sep 17 00:00:00 2001 From: Bearnardd <43574448+Bearnardd@users.noreply.github.com> Date: Fri, 14 Jul 2023 22:23:56 +0200 Subject: [PATCH] add support for truncate arg for HuggingFaceTextGenInference class (#7728) Fixes https://github.com/hwchase17/langchain/issues/7650 * add support for `truncate` argument of `HugginFaceTextGenInference` @baskaryan --- langchain/llms/huggingface_text_gen_inference.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/langchain/llms/huggingface_text_gen_inference.py b/langchain/llms/huggingface_text_gen_inference.py index 10bc73613c..5f33668790 100644 --- a/langchain/llms/huggingface_text_gen_inference.py +++ b/langchain/llms/huggingface_text_gen_inference.py @@ -25,6 +25,7 @@ class HuggingFaceTextGenInference(LLM): - typical_p: The typical probability threshold for generating text. - temperature: The temperature to use when generating text. - repetition_penalty: The repetition penalty to use when generating text. + - truncate: truncate inputs tokens to the given size - stop_sequences: A list of stop sequences to use when generating text. - seed: The seed to use when generating text. - inference_server_url: The URL of the inference server to use. @@ -80,6 +81,7 @@ class HuggingFaceTextGenInference(LLM): typical_p: Optional[float] = 0.95 temperature: float = 0.8 repetition_penalty: Optional[float] = None + truncate: Optional[int] = None stop_sequences: List[str] = Field(default_factory=list) seed: Optional[int] = None inference_server_url: str = "" @@ -145,6 +147,7 @@ class HuggingFaceTextGenInference(LLM): typical_p=self.typical_p, temperature=self.temperature, repetition_penalty=self.repetition_penalty, + truncate=self.truncate, seed=self.seed, **kwargs, ) @@ -169,6 +172,7 @@ class HuggingFaceTextGenInference(LLM): "typical_p": self.typical_p, "temperature": self.temperature, "repetition_penalty": self.repetition_penalty, + "truncate": self.truncate, "seed": self.seed, } text = "" @@ -209,6 +213,7 @@ class HuggingFaceTextGenInference(LLM): typical_p=self.typical_p, temperature=self.temperature, repetition_penalty=self.repetition_penalty, + truncate=self.truncate, seed=self.seed, **kwargs, ) @@ -234,6 +239,7 @@ class HuggingFaceTextGenInference(LLM): "typical_p": self.typical_p, "temperature": self.temperature, "repetition_penalty": self.repetition_penalty, + "truncate": self.truncate, "seed": self.seed, }, **kwargs,