mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
feat(llms): add missing params to huggingface text-generation (#9724)
This small PR aims at supporting the following missing parameters in the `HuggingfaceTextGen` LLM: - `return_full_text` - sometimes useful for completion tasks - `do_sample` - quite handy to control the randomness of the model. - `watermark` @hwchase17 @baskaryan --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
491089754d
commit
a7c9bd30d4
@ -80,6 +80,7 @@ class HuggingFaceTextGenInference(LLM):
|
||||
typical_p: Optional[float] = 0.95
|
||||
temperature: float = 0.8
|
||||
repetition_penalty: Optional[float] = None
|
||||
return_full_text: bool = False
|
||||
truncate: Optional[int] = None
|
||||
stop_sequences: List[str] = Field(default_factory=list)
|
||||
seed: Optional[int] = None
|
||||
@ -87,6 +88,8 @@ class HuggingFaceTextGenInference(LLM):
|
||||
timeout: int = 120
|
||||
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
streaming: bool = False
|
||||
do_sample: bool = False
|
||||
watermark: bool = False
|
||||
client: Any
|
||||
async_client: Any
|
||||
|
||||
@ -134,9 +137,12 @@ class HuggingFaceTextGenInference(LLM):
|
||||
"typical_p": self.typical_p,
|
||||
"temperature": self.temperature,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"return_full_text": self.return_full_text,
|
||||
"truncate": self.truncate,
|
||||
"stop_sequences": self.stop_sequences,
|
||||
"seed": self.seed,
|
||||
"do_sample": self.do_sample,
|
||||
"watermark": self.watermark,
|
||||
}
|
||||
|
||||
def _invocation_params(
|
||||
|
Loading…
Reference in New Issue
Block a user