feat(llms): add missing params to huggingface text-generation (#9724)

This small PR adds support for the following previously missing parameters
in the `HuggingFaceTextGenInference` LLM:
- `return_full_text` - sometimes useful for completion tasks.
- `do_sample` - quite handy to control the randomness of the model.
- `watermark` - toggles watermarking of the generated text.

@hwchase17 @baskaryan

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Massimiliano Pronesti 2023-09-01 22:16:27 +02:00 committed by GitHub
parent 491089754d
commit a7c9bd30d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -80,6 +80,7 @@ class HuggingFaceTextGenInference(LLM):
typical_p: Optional[float] = 0.95
temperature: float = 0.8
repetition_penalty: Optional[float] = None
return_full_text: bool = False
truncate: Optional[int] = None
stop_sequences: List[str] = Field(default_factory=list)
seed: Optional[int] = None
@ -87,6 +88,8 @@ class HuggingFaceTextGenInference(LLM):
timeout: int = 120
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
streaming: bool = False
do_sample: bool = False
watermark: bool = False
client: Any
async_client: Any
@ -134,9 +137,12 @@ class HuggingFaceTextGenInference(LLM):
"typical_p": self.typical_p,
"temperature": self.temperature,
"repetition_penalty": self.repetition_penalty,
"return_full_text": self.return_full_text,
"truncate": self.truncate,
"stop_sequences": self.stop_sequences,
"seed": self.seed,
"do_sample": self.do_sample,
"watermark": self.watermark,
}
def _invocation_params(