forked from Archives/langchain
feat: Added class to support huggingface text generation inference server (#4447)
[Text Generation Inference](https://github.com/huggingface/text-generation-inference) is a Rust, Python and gRPC server for generating text using LLMs. This pull request adds support for self-hosted Text Generation Inference servers. feature: #4280 --------- Co-authored-by: Your Name <you@example.com> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent 258c319855
commit cf4c1394a2
@ -0,0 +1,77 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Huggingface TextGen Inference\n",
    "\n",
    "[Text Generation Inference](https://github.com/huggingface/text-generation-inference) is a Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co/) to power LLMs api-inference widgets.\n",
    "\n",
    "This notebook goes over how to use a self-hosted LLM using `Text Generation Inference`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use, you should have the `text_generation` python package installed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# !pip3 install text_generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = HuggingFaceTextGenInference(\n",
    "    inference_server_url='http://localhost:8010/',\n",
    "    max_new_tokens=512,\n",
    "    top_k=10,\n",
    "    top_p=0.95,\n",
    "    typical_p=0.95,\n",
    "    temperature=0.01,\n",
    "    repetition_penalty=1.03,\n",
    ")\n",
    "llm(\"What did foo say about bar?\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
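Note that the notebook cell above assumes `HuggingFaceTextGenInference` has already been imported. Outside the notebook, the same usage needs the import; a minimal, self-contained sketch, assuming a TGI server is already listening on localhost:8010 (URL and prompt are placeholders):

from langchain import LLMChain, PromptTemplate
from langchain.llms import HuggingFaceTextGenInference

# Point the wrapper at a running Text Generation Inference server.
llm = HuggingFaceTextGenInference(
    inference_server_url="http://localhost:8010/",
    max_new_tokens=512,
    temperature=0.01,
)

# The LLM can be called directly or composed into a chain.
prompt = PromptTemplate(
    template="Question: {question}\n\nAnswer:",
    input_variables=["question"],
)
chain = LLMChain(prompt=prompt, llm=llm)
print(chain.run("What did foo say about bar?"))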
langchain/__init__.py
@ -26,6 +26,7 @@ from langchain.llms import (
     ForefrontAI,
     GooseAI,
     HuggingFaceHub,
+    HuggingFaceTextGenInference,
     LlamaCpp,
     Modal,
     OpenAI,
@ -114,4 +115,5 @@ __all__ = [
     "QAWithSourcesChain",
     "PALChain",
     "LlamaCpp",
+    "HuggingFaceTextGenInference",
 ]
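With this re-export in place, the class is reachable from the package root as well as from `langchain.llms`. A quick sanity check, as a hypothetical session:

import langchain
from langchain.llms import HuggingFaceTextGenInference

# The name is listed in langchain.__all__ after this change.
assert "HuggingFaceTextGenInference" in langchain.__all__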
langchain/llms/__init__.py
@ -16,6 +16,7 @@ from langchain.llms.gpt4all import GPT4All
 from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
 from langchain.llms.huggingface_hub import HuggingFaceHub
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
 from langchain.llms.human import HumanInputLLM
 from langchain.llms.llamacpp import LlamaCpp
 from langchain.llms.modal import Modal
@ -67,6 +68,7 @@ __all__ = [
     "RWKV",
     "PredictionGuard",
     "HumanInputLLM",
+    "HuggingFaceTextGenInference",
 ]
 
 type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
@ -99,4 +101,5 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "stochasticai": StochasticAI,
     "writer": Writer,
     "rwkv": RWKV,
+    "huggingface_textgen_inference": HuggingFaceTextGenInference,
 }
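The `type_to_cls_dict` registry is what lets LangChain resolve an LLM class from a string key, for example when loading a serialized LLM config. A minimal sketch of a manual lookup, assuming a server at http://localhost:8010/:

from langchain.llms import type_to_cls_dict

# Resolve the wrapper class from its registry key and instantiate it.
llm_cls = type_to_cls_dict["huggingface_textgen_inference"]
llm = llm_cls(inference_server_url="http://localhost:8010/")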
langchain/llms/huggingface_text_gen_inference.py (new file, 118 lines)
@ -0,0 +1,118 @@
"""Wrapper around Huggingface text generation inference API."""
from typing import Any, Dict, List, Optional

from pydantic import Extra, Field, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class HuggingFaceTextGenInference(LLM):
    """
    HuggingFace text generation inference API.

    This class is a wrapper around the HuggingFace text generation inference API.
    It is used to generate text from a given prompt.

    Attributes:
    - max_new_tokens: The maximum number of tokens to generate.
    - top_k: The number of top-k tokens to consider when generating text.
    - top_p: The cumulative probability threshold for generating text.
    - typical_p: The typical probability threshold for generating text.
    - temperature: The temperature to use when generating text.
    - repetition_penalty: The repetition penalty to use when generating text.
    - stop_sequences: A list of stop sequences to use when generating text.
    - seed: The seed to use when generating text.
    - inference_server_url: The URL of the inference server to use.
    - timeout: The timeout in seconds for requests to the inference server.
    - client: The client object used to communicate with the inference server.

    Methods:
    - _call: Generates text based on a given prompt and stop sequences.
    - _llm_type: Returns the type of LLM.

    Example:
        .. code-block:: python

            llm = HuggingFaceTextGenInference(
                inference_server_url="http://localhost:8010/",
                max_new_tokens=512,
                top_k=10,
                top_p=0.95,
                typical_p=0.95,
                temperature=0.01,
                repetition_penalty=1.03,
            )
    """

    max_new_tokens: int = 512
    top_k: Optional[int] = None
    top_p: Optional[float] = 0.95
    typical_p: Optional[float] = 0.95
    temperature: float = 0.8
    repetition_penalty: Optional[float] = None
    stop_sequences: List[str] = Field(default_factory=list)
    seed: Optional[int] = None
    inference_server_url: str = ""
    timeout: int = 120
    client: Any

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the text_generation python package exists in the environment."""
        try:
            import text_generation

            values["client"] = text_generation.Client(
                values["inference_server_url"], timeout=values["timeout"]
            )
        except ImportError:
            raise ValueError(
                "Could not import text_generation python package. "
                "Please install it with `pip install text_generation`."
            )
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "hf_textgen_inference"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        if stop is None:
            stop = self.stop_sequences
        else:
            # Combine per-call stops with defaults without mutating the caller's list.
            stop = stop + self.stop_sequences

        res = self.client.generate(
            prompt,
            stop_sequences=stop,
            max_new_tokens=self.max_new_tokens,
            top_k=self.top_k,
            top_p=self.top_p,
            typical_p=self.typical_p,
            temperature=self.temperature,
            repetition_penalty=self.repetition_penalty,
            seed=self.seed,
        )
        # Truncate the generated text at the first stop sequence, if one appears.
        for stop_seq in stop:
            if stop_seq in res.generated_text:
                res.generated_text = res.generated_text[
                    : res.generated_text.index(stop_seq)
                ]

        return res.generated_text
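Since `_call` merges per-call stop sequences with the configured defaults and truncates the output at the first match, stop words can be supplied either at construction time or per request. A usage sketch, with the server URL and prompt as placeholders:

from langchain.llms import HuggingFaceTextGenInference

llm = HuggingFaceTextGenInference(
    inference_server_url="http://localhost:8010/",
    max_new_tokens=128,
    stop_sequences=["\n\n"],  # default stop sequence applied to every call
)

# A per-call stop list is merged with stop_sequences before generation,
# and the returned text is cut at the first stop sequence that appears.
text = llm("List three colors:", stop=["4."])
print(text)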