diff --git a/libs/langchain/langchain/llms/huggingface_text_gen_inference.py b/libs/langchain/langchain/llms/huggingface_text_gen_inference.py index 284890579b..6545078f16 100644 --- a/libs/langchain/langchain/llms/huggingface_text_gen_inference.py +++ b/libs/langchain/langchain/llms/huggingface_text_gen_inference.py @@ -1,3 +1,4 @@ +import logging from typing import Any, AsyncIterator, Dict, Iterator, List, Optional from langchain.callbacks.manager import ( @@ -7,89 +8,91 @@ from langchain.callbacks.manager import ( from langchain.llms.base import LLM from langchain.pydantic_v1 import Extra, Field, root_validator from langchain.schema.output import GenerationChunk +from langchain.utils import get_pydantic_field_names + +logger = logging.getLogger(__name__) class HuggingFaceTextGenInference(LLM): """ HuggingFace text generation API. - It generates text from a given prompt. + To use, you should have the `text-generation` python package installed and + a text-generation server running. - Attributes: - - max_new_tokens: The maximum number of tokens to generate. - - top_k: The number of top-k tokens to consider when generating text. - - top_p: The cumulative probability threshold for generating text. - - typical_p: The typical probability threshold for generating text. - - temperature: The temperature to use when generating text. - - repetition_penalty: The repetition penalty to use when generating text. - - truncate: truncate inputs tokens to the given size - - stop_sequences: A list of stop sequences to use when generating text. - - seed: The seed to use when generating text. - - inference_server_url: The URL of the inference server to use. - - timeout: The timeout value in seconds to use while connecting to inference server. - - server_kwargs: The keyword arguments to pass to the inference server. - - client: The client object used to communicate with the inference server. - - async_client: The async client object used to communicate with the server. - - Methods: - - _call: Generates text based on a given prompt and stop sequences. - - _acall: Async generates text based on a given prompt and stop sequences. - - _llm_type: Returns the type of LLM. - - _default_params: Returns the default parameters for calling text generation - inference API. - """ - - """ Example: .. code-block:: python # Basic Example (no streaming) llm = HuggingFaceTextGenInference( - inference_server_url = "http://localhost:8010/", - max_new_tokens = 512, - top_k = 10, - top_p = 0.95, - typical_p = 0.95, - temperature = 0.01, - repetition_penalty = 1.03, + inference_server_url="http://localhost:8010/", + max_new_tokens=512, + top_k=10, + top_p=0.95, + typical_p=0.95, + temperature=0.01, + repetition_penalty=1.03, ) print(llm("What is Deep Learning?")) - + # Streaming response example from langchain.callbacks import streaming_stdout - + callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()] llm = HuggingFaceTextGenInference( - inference_server_url = "http://localhost:8010/", - max_new_tokens = 512, - top_k = 10, - top_p = 0.95, - typical_p = 0.95, - temperature = 0.01, - repetition_penalty = 1.03, - callbacks = callbacks, - streaming = True + inference_server_url="http://localhost:8010/", + max_new_tokens=512, + top_k=10, + top_p=0.95, + typical_p=0.95, + temperature=0.01, + repetition_penalty=1.03, + callbacks=callbacks, + streaming=True ) print(llm("What is Deep Learning?")) - + """ max_new_tokens: int = 512 + """Maximum number of generated tokens""" top_k: Optional[int] = None + """The number of highest probability vocabulary tokens to keep for + top-k-filtering.""" top_p: Optional[float] = 0.95 + """If set to < 1, only the smallest set of most probable tokens with probabilities + that add up to `top_p` or higher are kept for generation.""" typical_p: Optional[float] = 0.95 + """Typical Decoding mass. See [Typical Decoding for Natural Language + Generation](https://arxiv.org/abs/2202.00666) for more information.""" temperature: float = 0.8 + """The value used to module the logits distribution.""" repetition_penalty: Optional[float] = None + """The parameter for repetition penalty. 1.0 means no penalty. + See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.""" return_full_text: bool = False + """Whether to prepend the prompt to the generated text""" truncate: Optional[int] = None + """Truncate inputs tokens to the given size""" stop_sequences: List[str] = Field(default_factory=list) + """Stop generating tokens if a member of `stop_sequences` is generated""" seed: Optional[int] = None + """Random sampling seed""" inference_server_url: str = "" + """text-generation-inference instance base url""" timeout: int = 120 - server_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Timeout in seconds""" streaming: bool = False + """Whether to generate a stream of tokens asynchronously""" do_sample: bool = False + """Activate logits sampling""" watermark: bool = False + """Watermarking with [A Watermark for Large Language Models] + (https://arxiv.org/abs/2301.10226)""" + server_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any text-generation-inference server parameters not explicitly specified""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any model parameters valid for `call` not explicitly specified""" client: Any async_client: Any @@ -98,6 +101,32 @@ class HuggingFaceTextGenInference(LLM): extra = Extra.forbid + @root_validator(pre=True) + def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = get_pydantic_field_names(cls) + extra = values.get("model_kwargs", {}) + for field_name in list(values): + if field_name in extra: + raise ValueError(f"Found {field_name} supplied twice.") + if field_name not in all_required_field_names: + logger.warning( + f"""WARNING! {field_name} is not default parameter. + {field_name} was transferred to model_kwargs. + Please confirm that {field_name} is what you intended.""" + ) + extra[field_name] = values.pop(field_name) + + invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) + if invalid_model_kwargs: + raise ValueError( + f"Parameters {invalid_model_kwargs} should be specified explicitly. " + f"Instead they were passed in as part of `model_kwargs` parameter." + ) + + values["model_kwargs"] = extra + return values + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that python package exists in environment.""" @@ -143,6 +172,7 @@ class HuggingFaceTextGenInference(LLM): "seed": self.seed, "do_sample": self.do_sample, "watermark": self.watermark, + **self.model_kwargs, } def _invocation_params(