Mirror of https://github.com/hwchase17/langchain
feat(llms): add model_kwargs to hf tgi (#10139)
@baskaryan Following what we discussed in #9724 and your suggestion, I've added a `model_kwargs` parameter to hf tgi.
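
For context, a minimal sketch of the new parameter in use (assuming a text-generation-inference server at `http://localhost:8010/`; `best_of` is just an illustrative generate-time parameter that the class does not declare as a field):

```python
from langchain.llms import HuggingFaceTextGenInference

llm = HuggingFaceTextGenInference(
    inference_server_url="http://localhost:8010/",
    max_new_tokens=512,
    # Anything the class has no explicit field for can now ride along here:
    model_kwargs={"best_of": 2},
)
print(llm("What is Deep Learning?"))
```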
This commit is contained in:
parent e0f6ba08d6
commit 10e0431e48
@@ -1,3 +1,4 @@
+import logging
 from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
 
 from langchain.callbacks.manager import (
@@ -7,89 +8,91 @@ from langchain.callbacks.manager import (
 from langchain.llms.base import LLM
 from langchain.pydantic_v1 import Extra, Field, root_validator
 from langchain.schema.output import GenerationChunk
+from langchain.utils import get_pydantic_field_names
+
+logger = logging.getLogger(__name__)
 
 
 class HuggingFaceTextGenInference(LLM):
     """
     HuggingFace text generation API.
 
-    It generates text from a given prompt.
+    To use, you should have the `text-generation` python package installed and
+    a text-generation server running.
 
-    Attributes:
-    - max_new_tokens: The maximum number of tokens to generate.
-    - top_k: The number of top-k tokens to consider when generating text.
-    - top_p: The cumulative probability threshold for generating text.
-    - typical_p: The typical probability threshold for generating text.
-    - temperature: The temperature to use when generating text.
-    - repetition_penalty: The repetition penalty to use when generating text.
-    - truncate: truncate inputs tokens to the given size
-    - stop_sequences: A list of stop sequences to use when generating text.
-    - seed: The seed to use when generating text.
-    - inference_server_url: The URL of the inference server to use.
-    - timeout: The timeout value in seconds to use while connecting to inference server.
-    - server_kwargs: The keyword arguments to pass to the inference server.
-    - client: The client object used to communicate with the inference server.
-    - async_client: The async client object used to communicate with the server.
-
-    Methods:
-    - _call: Generates text based on a given prompt and stop sequences.
-    - _acall: Async generates text based on a given prompt and stop sequences.
-    - _llm_type: Returns the type of LLM.
-    - _default_params: Returns the default parameters for calling text generation
-    inference API.
-    """
-
-    """
     Example:
         .. code-block:: python
 
             # Basic Example (no streaming)
             llm = HuggingFaceTextGenInference(
-                inference_server_url = "http://localhost:8010/",
-                max_new_tokens = 512,
-                top_k = 10,
-                top_p = 0.95,
-                typical_p = 0.95,
-                temperature = 0.01,
-                repetition_penalty = 1.03,
+                inference_server_url="http://localhost:8010/",
+                max_new_tokens=512,
+                top_k=10,
+                top_p=0.95,
+                typical_p=0.95,
+                temperature=0.01,
+                repetition_penalty=1.03,
             )
             print(llm("What is Deep Learning?"))
 
             # Streaming response example
             from langchain.callbacks import streaming_stdout
 
             callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
             llm = HuggingFaceTextGenInference(
-                inference_server_url = "http://localhost:8010/",
-                max_new_tokens = 512,
-                top_k = 10,
-                top_p = 0.95,
-                typical_p = 0.95,
-                temperature = 0.01,
-                repetition_penalty = 1.03,
-                callbacks = callbacks,
-                streaming = True
+                inference_server_url="http://localhost:8010/",
+                max_new_tokens=512,
+                top_k=10,
+                top_p=0.95,
+                typical_p=0.95,
+                temperature=0.01,
+                repetition_penalty=1.03,
+                callbacks=callbacks,
+                streaming=True
             )
             print(llm("What is Deep Learning?"))
 
     """
 
     max_new_tokens: int = 512
+    """Maximum number of generated tokens"""
     top_k: Optional[int] = None
+    """The number of highest probability vocabulary tokens to keep for
+    top-k-filtering."""
     top_p: Optional[float] = 0.95
+    """If set to < 1, only the smallest set of most probable tokens with probabilities
+    that add up to `top_p` or higher are kept for generation."""
     typical_p: Optional[float] = 0.95
+    """Typical Decoding mass. See [Typical Decoding for Natural Language
+    Generation](https://arxiv.org/abs/2202.00666) for more information."""
     temperature: float = 0.8
+    """The value used to module the logits distribution."""
     repetition_penalty: Optional[float] = None
+    """The parameter for repetition penalty. 1.0 means no penalty.
+    See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details."""
     return_full_text: bool = False
+    """Whether to prepend the prompt to the generated text"""
     truncate: Optional[int] = None
+    """Truncate inputs tokens to the given size"""
     stop_sequences: List[str] = Field(default_factory=list)
+    """Stop generating tokens if a member of `stop_sequences` is generated"""
     seed: Optional[int] = None
+    """Random sampling seed"""
     inference_server_url: str = ""
+    """text-generation-inference instance base url"""
     timeout: int = 120
-    server_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Timeout in seconds"""
     streaming: bool = False
+    """Whether to generate a stream of tokens asynchronously"""
     do_sample: bool = False
+    """Activate logits sampling"""
     watermark: bool = False
+    """Watermarking with [A Watermark for Large Language Models]
+    (https://arxiv.org/abs/2301.10226)"""
+    server_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Holds any text-generation-inference server parameters not explicitly specified"""
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Holds any model parameters valid for `call` not explicitly specified"""
     client: Any
     async_client: Any
 
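Note that the class now carries two catch-all dicts. Going by the field docstrings above, `server_kwargs` holds options for the text-generation-inference client itself, while `model_kwargs` holds options merged into each generate call. A hedged sketch of the intended split (the header value and `best_of` are placeholders, not defaults, and the forwarding of `server_kwargs` to the underlying `text_generation` client is inferred from the docstring, not shown in this diff):

```python
llm = HuggingFaceTextGenInference(
    inference_server_url="http://localhost:8010/",
    # Connection-level options, e.g. auth headers for the server:
    server_kwargs={"headers": {"Authorization": "Bearer my-token"}},
    # Generation-level options, merged into every call:
    model_kwargs={"best_of": 2},
)
```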
@@ -98,6 +101,32 @@ class HuggingFaceTextGenInference(LLM):
 
         extra = Extra.forbid
 
+    @root_validator(pre=True)
+    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Build extra kwargs from additional params that were passed in."""
+        all_required_field_names = get_pydantic_field_names(cls)
+        extra = values.get("model_kwargs", {})
+        for field_name in list(values):
+            if field_name in extra:
+                raise ValueError(f"Found {field_name} supplied twice.")
+            if field_name not in all_required_field_names:
+                logger.warning(
+                    f"""WARNING! {field_name} is not default parameter.
+                    {field_name} was transferred to model_kwargs.
+                    Please confirm that {field_name} is what you intended."""
+                )
+                extra[field_name] = values.pop(field_name)
+
+        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
+        if invalid_model_kwargs:
+            raise ValueError(
+                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+                f"Instead they were passed in as part of `model_kwargs` parameter."
+            )
+
+        values["model_kwargs"] = extra
+        return values
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that python package exists in environment."""
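The validator above has two paths worth seeing side by side: unknown constructor keywords are warned about and swept into `model_kwargs`, while declared fields smuggled in through `model_kwargs` are rejected. A sketch (assumes the `text-generation` package is installed so construction succeeds; `best_of` is again an illustrative non-field parameter):

```python
from langchain.llms import HuggingFaceTextGenInference

# Unknown keyword: logged as a warning, then moved into model_kwargs.
llm = HuggingFaceTextGenInference(
    inference_server_url="http://localhost:8010/",
    best_of=2,
)
print(llm.model_kwargs)  # {'best_of': 2}

# Declared field inside model_kwargs: rejected with a ValueError.
try:
    HuggingFaceTextGenInference(
        inference_server_url="http://localhost:8010/",
        model_kwargs={"temperature": 0.5},
    )
except ValueError as err:
    print(err)  # Parameters {'temperature'} should be specified explicitly. ...
```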
@@ -143,6 +172,7 @@ class HuggingFaceTextGenInference(LLM):
             "seed": self.seed,
             "do_sample": self.do_sample,
             "watermark": self.watermark,
+            **self.model_kwargs,
         }
 
     def _invocation_params(
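This last line is what makes the new field take effect: everything in `model_kwargs` is merged into the parameter dict that `_default_params` builds for each call. A quick hedged check (reusing the `best_of` placeholder from above):

```python
llm = HuggingFaceTextGenInference(
    inference_server_url="http://localhost:8010/",
    temperature=0.01,
    model_kwargs={"best_of": 2},
)
params = llm._default_params
print(params["temperature"])  # 0.01, from the explicit field
print(params["best_of"])      # 2, merged in from model_kwargs
```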