add support for truncate arg for HuggingFaceTextGenInference class (#7728)

Fixes https://github.com/hwchase17/langchain/issues/7650

* add support for the `truncate` argument of `HuggingFaceTextGenInference`

@baskaryan
This commit is contained in:
Bearnardd 2023-07-14 22:23:56 +02:00 committed by GitHub
parent 77e6bbe6f0
commit 9800c6051c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -25,6 +25,7 @@ class HuggingFaceTextGenInference(LLM):
- typical_p: The typical probability threshold for generating text.
- temperature: The temperature to use when generating text.
- repetition_penalty: The repetition penalty to use when generating text.
- truncate: Truncate input tokens to the given size.
- stop_sequences: A list of stop sequences to use when generating text.
- seed: The seed to use when generating text.
- inference_server_url: The URL of the inference server to use.
@ -80,6 +81,7 @@ class HuggingFaceTextGenInference(LLM):
typical_p: Optional[float] = 0.95
temperature: float = 0.8
repetition_penalty: Optional[float] = None
truncate: Optional[int] = None
stop_sequences: List[str] = Field(default_factory=list)
seed: Optional[int] = None
inference_server_url: str = ""
@ -145,6 +147,7 @@ class HuggingFaceTextGenInference(LLM):
typical_p=self.typical_p,
temperature=self.temperature,
repetition_penalty=self.repetition_penalty,
truncate=self.truncate,
seed=self.seed,
**kwargs,
)
@ -169,6 +172,7 @@ class HuggingFaceTextGenInference(LLM):
"typical_p": self.typical_p,
"temperature": self.temperature,
"repetition_penalty": self.repetition_penalty,
"truncate": self.truncate,
"seed": self.seed,
}
text = ""
@ -209,6 +213,7 @@ class HuggingFaceTextGenInference(LLM):
typical_p=self.typical_p,
temperature=self.temperature,
repetition_penalty=self.repetition_penalty,
truncate=self.truncate,
seed=self.seed,
**kwargs,
)
@ -234,6 +239,7 @@ class HuggingFaceTextGenInference(LLM):
"typical_p": self.typical_p,
"temperature": self.temperature,
"repetition_penalty": self.repetition_penalty,
"truncate": self.truncate,
"seed": self.seed,
},
**kwargs,