mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
feat(llms): add model_kwargs to hf tgi (#10139)
@baskaryan Following what we discussed in #9724 and your suggestion, I've added a `model_kwargs` parameter to hf tgi.
This commit is contained in:
parent
e0f6ba08d6
commit
10e0431e48
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
@ -7,51 +8,30 @@ from langchain.callbacks.manager import (
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.pydantic_v1 import Extra, Field, root_validator
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from langchain.utils import get_pydantic_field_names
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuggingFaceTextGenInference(LLM):
|
||||
"""
|
||||
HuggingFace text generation API.
|
||||
|
||||
It generates text from a given prompt.
|
||||
To use, you should have the `text-generation` python package installed and
|
||||
a text-generation server running.
|
||||
|
||||
Attributes:
|
||||
- max_new_tokens: The maximum number of tokens to generate.
|
||||
- top_k: The number of top-k tokens to consider when generating text.
|
||||
- top_p: The cumulative probability threshold for generating text.
|
||||
- typical_p: The typical probability threshold for generating text.
|
||||
- temperature: The temperature to use when generating text.
|
||||
- repetition_penalty: The repetition penalty to use when generating text.
|
||||
- truncate: truncate inputs tokens to the given size
|
||||
- stop_sequences: A list of stop sequences to use when generating text.
|
||||
- seed: The seed to use when generating text.
|
||||
- inference_server_url: The URL of the inference server to use.
|
||||
- timeout: The timeout value in seconds to use while connecting to inference server.
|
||||
- server_kwargs: The keyword arguments to pass to the inference server.
|
||||
- client: The client object used to communicate with the inference server.
|
||||
- async_client: The async client object used to communicate with the server.
|
||||
|
||||
Methods:
|
||||
- _call: Generates text based on a given prompt and stop sequences.
|
||||
- _acall: Async generates text based on a given prompt and stop sequences.
|
||||
- _llm_type: Returns the type of LLM.
|
||||
- _default_params: Returns the default parameters for calling text generation
|
||||
inference API.
|
||||
"""
|
||||
|
||||
"""
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# Basic Example (no streaming)
|
||||
llm = HuggingFaceTextGenInference(
|
||||
inference_server_url = "http://localhost:8010/",
|
||||
max_new_tokens = 512,
|
||||
top_k = 10,
|
||||
top_p = 0.95,
|
||||
typical_p = 0.95,
|
||||
temperature = 0.01,
|
||||
repetition_penalty = 1.03,
|
||||
inference_server_url="http://localhost:8010/",
|
||||
max_new_tokens=512,
|
||||
top_k=10,
|
||||
top_p=0.95,
|
||||
typical_p=0.95,
|
||||
temperature=0.01,
|
||||
repetition_penalty=1.03,
|
||||
)
|
||||
print(llm("What is Deep Learning?"))
|
||||
|
||||
@ -60,36 +40,59 @@ class HuggingFaceTextGenInference(LLM):
|
||||
|
||||
callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
|
||||
llm = HuggingFaceTextGenInference(
|
||||
inference_server_url = "http://localhost:8010/",
|
||||
max_new_tokens = 512,
|
||||
top_k = 10,
|
||||
top_p = 0.95,
|
||||
typical_p = 0.95,
|
||||
temperature = 0.01,
|
||||
repetition_penalty = 1.03,
|
||||
callbacks = callbacks,
|
||||
streaming = True
|
||||
inference_server_url="http://localhost:8010/",
|
||||
max_new_tokens=512,
|
||||
top_k=10,
|
||||
top_p=0.95,
|
||||
typical_p=0.95,
|
||||
temperature=0.01,
|
||||
repetition_penalty=1.03,
|
||||
callbacks=callbacks,
|
||||
streaming=True
|
||||
)
|
||||
print(llm("What is Deep Learning?"))
|
||||
|
||||
"""
|
||||
|
||||
max_new_tokens: int = 512
|
||||
"""Maximum number of generated tokens"""
|
||||
top_k: Optional[int] = None
|
||||
"""The number of highest probability vocabulary tokens to keep for
|
||||
top-k-filtering."""
|
||||
top_p: Optional[float] = 0.95
|
||||
"""If set to < 1, only the smallest set of most probable tokens with probabilities
|
||||
that add up to `top_p` or higher are kept for generation."""
|
||||
typical_p: Optional[float] = 0.95
|
||||
"""Typical Decoding mass. See [Typical Decoding for Natural Language
|
||||
Generation](https://arxiv.org/abs/2202.00666) for more information."""
|
||||
temperature: float = 0.8
|
||||
"""The value used to module the logits distribution."""
|
||||
repetition_penalty: Optional[float] = None
|
||||
"""The parameter for repetition penalty. 1.0 means no penalty.
|
||||
See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details."""
|
||||
return_full_text: bool = False
|
||||
"""Whether to prepend the prompt to the generated text"""
|
||||
truncate: Optional[int] = None
|
||||
"""Truncate inputs tokens to the given size"""
|
||||
stop_sequences: List[str] = Field(default_factory=list)
|
||||
"""Stop generating tokens if a member of `stop_sequences` is generated"""
|
||||
seed: Optional[int] = None
|
||||
"""Random sampling seed"""
|
||||
inference_server_url: str = ""
|
||||
"""text-generation-inference instance base url"""
|
||||
timeout: int = 120
|
||||
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Timeout in seconds"""
|
||||
streaming: bool = False
|
||||
"""Whether to generate a stream of tokens asynchronously"""
|
||||
do_sample: bool = False
|
||||
"""Activate logits sampling"""
|
||||
watermark: bool = False
|
||||
"""Watermarking with [A Watermark for Large Language Models]
|
||||
(https://arxiv.org/abs/2301.10226)"""
|
||||
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any text-generation-inference server parameters not explicitly specified"""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `call` not explicitly specified"""
|
||||
client: Any
|
||||
async_client: Any
|
||||
|
||||
@ -98,6 +101,32 @@ class HuggingFaceTextGenInference(LLM):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
logger.warning(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that python package exists in environment."""
|
||||
@ -143,6 +172,7 @@ class HuggingFaceTextGenInference(LLM):
|
||||
"seed": self.seed,
|
||||
"do_sample": self.do_sample,
|
||||
"watermark": self.watermark,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
def _invocation_params(
|
||||
|
Loading…
Reference in New Issue
Block a user