|
|
@ -25,6 +25,7 @@ class HuggingFaceTextGenInference(LLM):
|
|
|
|
- typical_p: The typical probability threshold for generating text.
|
|
|
|
- typical_p: The typical probability threshold for generating text.
|
|
|
|
- temperature: The temperature to use when generating text.
|
|
|
|
- temperature: The temperature to use when generating text.
|
|
|
|
- repetition_penalty: The repetition penalty to use when generating text.
|
|
|
|
- repetition_penalty: The repetition penalty to use when generating text.
|
|
|
|
|
|
|
|
- truncate: Truncate input tokens to the given size.
|
|
|
|
- stop_sequences: A list of stop sequences to use when generating text.
|
|
|
|
- stop_sequences: A list of stop sequences to use when generating text.
|
|
|
|
- seed: The seed to use when generating text.
|
|
|
|
- seed: The seed to use when generating text.
|
|
|
|
- inference_server_url: The URL of the inference server to use.
|
|
|
|
- inference_server_url: The URL of the inference server to use.
|
|
|
@ -80,6 +81,7 @@ class HuggingFaceTextGenInference(LLM):
|
|
|
|
typical_p: Optional[float] = 0.95
|
|
|
|
typical_p: Optional[float] = 0.95
|
|
|
|
temperature: float = 0.8
|
|
|
|
temperature: float = 0.8
|
|
|
|
repetition_penalty: Optional[float] = None
|
|
|
|
repetition_penalty: Optional[float] = None
|
|
|
|
|
|
|
|
truncate: Optional[int] = None
|
|
|
|
stop_sequences: List[str] = Field(default_factory=list)
|
|
|
|
stop_sequences: List[str] = Field(default_factory=list)
|
|
|
|
seed: Optional[int] = None
|
|
|
|
seed: Optional[int] = None
|
|
|
|
inference_server_url: str = ""
|
|
|
|
inference_server_url: str = ""
|
|
|
@ -145,6 +147,7 @@ class HuggingFaceTextGenInference(LLM):
|
|
|
|
typical_p=self.typical_p,
|
|
|
|
typical_p=self.typical_p,
|
|
|
|
temperature=self.temperature,
|
|
|
|
temperature=self.temperature,
|
|
|
|
repetition_penalty=self.repetition_penalty,
|
|
|
|
repetition_penalty=self.repetition_penalty,
|
|
|
|
|
|
|
|
truncate=self.truncate,
|
|
|
|
seed=self.seed,
|
|
|
|
seed=self.seed,
|
|
|
|
**kwargs,
|
|
|
|
**kwargs,
|
|
|
|
)
|
|
|
|
)
|
|
|
@ -169,6 +172,7 @@ class HuggingFaceTextGenInference(LLM):
|
|
|
|
"typical_p": self.typical_p,
|
|
|
|
"typical_p": self.typical_p,
|
|
|
|
"temperature": self.temperature,
|
|
|
|
"temperature": self.temperature,
|
|
|
|
"repetition_penalty": self.repetition_penalty,
|
|
|
|
"repetition_penalty": self.repetition_penalty,
|
|
|
|
|
|
|
|
"truncate": self.truncate,
|
|
|
|
"seed": self.seed,
|
|
|
|
"seed": self.seed,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
text = ""
|
|
|
|
text = ""
|
|
|
@ -209,6 +213,7 @@ class HuggingFaceTextGenInference(LLM):
|
|
|
|
typical_p=self.typical_p,
|
|
|
|
typical_p=self.typical_p,
|
|
|
|
temperature=self.temperature,
|
|
|
|
temperature=self.temperature,
|
|
|
|
repetition_penalty=self.repetition_penalty,
|
|
|
|
repetition_penalty=self.repetition_penalty,
|
|
|
|
|
|
|
|
truncate=self.truncate,
|
|
|
|
seed=self.seed,
|
|
|
|
seed=self.seed,
|
|
|
|
**kwargs,
|
|
|
|
**kwargs,
|
|
|
|
)
|
|
|
|
)
|
|
|
@ -234,6 +239,7 @@ class HuggingFaceTextGenInference(LLM):
|
|
|
|
"typical_p": self.typical_p,
|
|
|
|
"typical_p": self.typical_p,
|
|
|
|
"temperature": self.temperature,
|
|
|
|
"temperature": self.temperature,
|
|
|
|
"repetition_penalty": self.repetition_penalty,
|
|
|
|
"repetition_penalty": self.repetition_penalty,
|
|
|
|
|
|
|
|
"truncate": self.truncate,
|
|
|
|
"seed": self.seed,
|
|
|
|
"seed": self.seed,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
**kwargs,
|
|
|
|
**kwargs,
|
|
|
|