feat(llms): add model_kwargs to hf tgi (#10139)

@baskaryan
Following what we discussed in #9724 and your suggestion, I've added a
`model_kwargs` parameter to hf tgi.
This commit is contained in:
Massimiliano Pronesti 2023-09-04 09:24:13 +02:00 committed by GitHub
parent e0f6ba08d6
commit 10e0431e48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,3 +1,4 @@
import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
@ -7,89 +8,91 @@ from langchain.callbacks.manager import (
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.pydantic_v1 import Extra, Field, root_validator from langchain.pydantic_v1 import Extra, Field, root_validator
from langchain.schema.output import GenerationChunk from langchain.schema.output import GenerationChunk
from langchain.utils import get_pydantic_field_names
logger = logging.getLogger(__name__)
class HuggingFaceTextGenInference(LLM): class HuggingFaceTextGenInference(LLM):
""" """
HuggingFace text generation API. HuggingFace text generation API.
It generates text from a given prompt. To use, you should have the `text-generation` python package installed and
a text-generation server running.
Attributes:
- max_new_tokens: The maximum number of tokens to generate.
- top_k: The number of top-k tokens to consider when generating text.
- top_p: The cumulative probability threshold for generating text.
- typical_p: The typical probability threshold for generating text.
- temperature: The temperature to use when generating text.
- repetition_penalty: The repetition penalty to use when generating text.
- truncate: truncate inputs tokens to the given size
- stop_sequences: A list of stop sequences to use when generating text.
- seed: The seed to use when generating text.
- inference_server_url: The URL of the inference server to use.
- timeout: The timeout value in seconds to use while connecting to inference server.
- server_kwargs: The keyword arguments to pass to the inference server.
- client: The client object used to communicate with the inference server.
- async_client: The async client object used to communicate with the server.
Methods:
- _call: Generates text based on a given prompt and stop sequences.
- _acall: Async generates text based on a given prompt and stop sequences.
- _llm_type: Returns the type of LLM.
- _default_params: Returns the default parameters for calling text generation
inference API.
"""
"""
Example: Example:
.. code-block:: python .. code-block:: python
# Basic Example (no streaming) # Basic Example (no streaming)
llm = HuggingFaceTextGenInference( llm = HuggingFaceTextGenInference(
inference_server_url = "http://localhost:8010/", inference_server_url="http://localhost:8010/",
max_new_tokens = 512, max_new_tokens=512,
top_k = 10, top_k=10,
top_p = 0.95, top_p=0.95,
typical_p = 0.95, typical_p=0.95,
temperature = 0.01, temperature=0.01,
repetition_penalty = 1.03, repetition_penalty=1.03,
) )
print(llm("What is Deep Learning?")) print(llm("What is Deep Learning?"))
# Streaming response example # Streaming response example
from langchain.callbacks import streaming_stdout from langchain.callbacks import streaming_stdout
callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()] callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
llm = HuggingFaceTextGenInference( llm = HuggingFaceTextGenInference(
inference_server_url = "http://localhost:8010/", inference_server_url="http://localhost:8010/",
max_new_tokens = 512, max_new_tokens=512,
top_k = 10, top_k=10,
top_p = 0.95, top_p=0.95,
typical_p = 0.95, typical_p=0.95,
temperature = 0.01, temperature=0.01,
repetition_penalty = 1.03, repetition_penalty=1.03,
callbacks = callbacks, callbacks=callbacks,
streaming = True streaming=True
) )
print(llm("What is Deep Learning?")) print(llm("What is Deep Learning?"))
""" """
max_new_tokens: int = 512 max_new_tokens: int = 512
"""Maximum number of generated tokens"""
top_k: Optional[int] = None top_k: Optional[int] = None
"""The number of highest probability vocabulary tokens to keep for
top-k-filtering."""
top_p: Optional[float] = 0.95 top_p: Optional[float] = 0.95
"""If set to < 1, only the smallest set of most probable tokens with probabilities
that add up to `top_p` or higher are kept for generation."""
typical_p: Optional[float] = 0.95 typical_p: Optional[float] = 0.95
"""Typical Decoding mass. See [Typical Decoding for Natural Language
Generation](https://arxiv.org/abs/2202.00666) for more information."""
temperature: float = 0.8 temperature: float = 0.8
"""The value used to module the logits distribution."""
repetition_penalty: Optional[float] = None repetition_penalty: Optional[float] = None
"""The parameter for repetition penalty. 1.0 means no penalty.
See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details."""
return_full_text: bool = False return_full_text: bool = False
"""Whether to prepend the prompt to the generated text"""
truncate: Optional[int] = None truncate: Optional[int] = None
"""Truncate inputs tokens to the given size"""
stop_sequences: List[str] = Field(default_factory=list) stop_sequences: List[str] = Field(default_factory=list)
"""Stop generating tokens if a member of `stop_sequences` is generated"""
seed: Optional[int] = None seed: Optional[int] = None
"""Random sampling seed"""
inference_server_url: str = "" inference_server_url: str = ""
"""text-generation-inference instance base url"""
timeout: int = 120 timeout: int = 120
server_kwargs: Dict[str, Any] = Field(default_factory=dict) """Timeout in seconds"""
streaming: bool = False streaming: bool = False
"""Whether to generate a stream of tokens asynchronously"""
do_sample: bool = False do_sample: bool = False
"""Activate logits sampling"""
watermark: bool = False watermark: bool = False
"""Watermarking with [A Watermark for Large Language Models]
(https://arxiv.org/abs/2301.10226)"""
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any text-generation-inference server parameters not explicitly specified"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `call` not explicitly specified"""
client: Any client: Any
async_client: Any async_client: Any
@ -98,6 +101,32 @@ class HuggingFaceTextGenInference(LLM):
extra = Extra.forbid extra = Extra.forbid
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
logger.warning(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that python package exists in environment.""" """Validate that python package exists in environment."""
@ -143,6 +172,7 @@ class HuggingFaceTextGenInference(LLM):
"seed": self.seed, "seed": self.seed,
"do_sample": self.do_sample, "do_sample": self.do_sample,
"watermark": self.watermark, "watermark": self.watermark,
**self.model_kwargs,
} }
def _invocation_params( def _invocation_params(