diff --git a/libs/langchain/langchain/llms/huggingface_text_gen_inference.py b/libs/langchain/langchain/llms/huggingface_text_gen_inference.py index 8e23d935f8..284890579b 100644 --- a/libs/langchain/langchain/llms/huggingface_text_gen_inference.py +++ b/libs/langchain/langchain/llms/huggingface_text_gen_inference.py @@ -80,6 +80,7 @@ class HuggingFaceTextGenInference(LLM): typical_p: Optional[float] = 0.95 temperature: float = 0.8 repetition_penalty: Optional[float] = None + return_full_text: bool = False truncate: Optional[int] = None stop_sequences: List[str] = Field(default_factory=list) seed: Optional[int] = None @@ -87,6 +88,8 @@ class HuggingFaceTextGenInference(LLM): timeout: int = 120 server_kwargs: Dict[str, Any] = Field(default_factory=dict) streaming: bool = False + do_sample: bool = False + watermark: bool = False client: Any async_client: Any @@ -134,9 +137,12 @@ class HuggingFaceTextGenInference(LLM): "typical_p": self.typical_p, "temperature": self.temperature, "repetition_penalty": self.repetition_penalty, + "return_full_text": self.return_full_text, "truncate": self.truncate, "stop_sequences": self.stop_sequences, "seed": self.seed, + "do_sample": self.do_sample, + "watermark": self.watermark, } def _invocation_params(