forked from Archives/langchain
Add oobabooga/text-generation-webui support as a llm (#5997)
Add oobabooga/text-generation-webui support as an LLM. Currently, supports using text-generation-webui's non-streaming API interface. Allows users who already have text-gen running to use the same models with langchain. #### Before submitting Simple usage, similar to existing LLM supported: ``` from langchain.llms import TextGen llm = TextGen(model_url = "http://localhost:5000") ``` #### Who can review? @hwchase17 - project lead --------- Co-authored-by: Hien Ngo <Hien.Ngo@adia.ae>
This commit is contained in:
parent
444ca3f669
commit
6f36f0f930
87
docs/modules/models/llms/integrations/textgen.ipynb
Normal file
87
docs/modules/models/llms/integrations/textgen.ipynb
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# TextGen\n",
|
||||||
|
"\n",
|
||||||
|
"[GitHub:oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, Pythia, OPT, and GALACTICA.\n",
|
||||||
|
"\n",
|
||||||
|
"This example goes over how to use LangChain to interact with LLM models via the `text-generation-webui` API integration.\n",
|
||||||
|
"\n",
|
||||||
|
"Please ensure that you have `text-generation-webui` configured and an LLM installed. Recommended installation via the [one-click installer appropriate](https://github.com/oobabooga/text-generation-webui#one-click-installers) for your OS.\n",
|
||||||
|
"\n",
|
||||||
|
"Once `text-generation-webui` is installed and confirmed working via the web interface, please enable the `api` option either through the web model configuration tab, or by adding the run-time arg `--api` to your start command."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Set model_url and run the example"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model_url = \"http://localhost:5000\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import langchain\n",
|
||||||
|
"from langchain import PromptTemplate, LLMChain\n",
|
||||||
|
"from langchain.llms import TextGen\n",
|
||||||
|
"\n",
|
||||||
|
"langchain.debug = True\n",
|
||||||
|
"\n",
|
||||||
|
"template = \"\"\"Question: {question}\n",
|
||||||
|
"\n",
|
||||||
|
"Answer: Let's think step by step.\"\"\"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
|
||||||
|
"llm = TextGen(model_url=model_url)\n",
|
||||||
|
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
|
||||||
|
"question = \"What NFL team won the Super Bowl in the year Justin Bieber was born?\"\n",
|
||||||
|
"\n",
|
||||||
|
"llm_chain.run(question)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
@ -42,6 +42,7 @@ from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
|
|||||||
from langchain.llms.self_hosted import SelfHostedPipeline
|
from langchain.llms.self_hosted import SelfHostedPipeline
|
||||||
from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM
|
from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM
|
||||||
from langchain.llms.stochasticai import StochasticAI
|
from langchain.llms.stochasticai import StochasticAI
|
||||||
|
from langchain.llms.textgen import TextGen
|
||||||
from langchain.llms.vertexai import VertexAI
|
from langchain.llms.vertexai import VertexAI
|
||||||
from langchain.llms.writer import Writer
|
from langchain.llms.writer import Writer
|
||||||
|
|
||||||
@ -64,6 +65,7 @@ __all__ = [
|
|||||||
"GooseAI",
|
"GooseAI",
|
||||||
"GPT4All",
|
"GPT4All",
|
||||||
"LlamaCpp",
|
"LlamaCpp",
|
||||||
|
"TextGen",
|
||||||
"Modal",
|
"Modal",
|
||||||
"MosaicML",
|
"MosaicML",
|
||||||
"NLPCloud",
|
"NLPCloud",
|
||||||
@ -114,6 +116,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
|||||||
"huggingface_hub": HuggingFaceHub,
|
"huggingface_hub": HuggingFaceHub,
|
||||||
"huggingface_endpoint": HuggingFaceEndpoint,
|
"huggingface_endpoint": HuggingFaceEndpoint,
|
||||||
"llamacpp": LlamaCpp,
|
"llamacpp": LlamaCpp,
|
||||||
|
"textgen": TextGen,
|
||||||
"modal": Modal,
|
"modal": Modal,
|
||||||
"mosaic": MosaicML,
|
"mosaic": MosaicML,
|
||||||
"sagemaker_endpoint": SagemakerEndpoint,
|
"sagemaker_endpoint": SagemakerEndpoint,
|
||||||
|
211
langchain/llms/textgen.py
Normal file
211
langchain/llms/textgen.py
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
"""Wrapper around text-generation-webui."""
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TextGen(LLM):
|
||||||
|
"""Wrapper around the text-generation-webui model.
|
||||||
|
|
||||||
|
To use, you should have the text-generation-webui installed, a model loaded,
|
||||||
|
and --api added as a command-line option.
|
||||||
|
|
||||||
|
Suggested installation, use one-click installer for your OS:
|
||||||
|
https://github.com/oobabooga/text-generation-webui#one-click-installers
|
||||||
|
|
||||||
|
Paremeters below taken from text-generation-webui api example:
|
||||||
|
https://github.com/oobabooga/text-generation-webui/blob/main/api-examples/api-example.py
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.llms import TextGen
|
||||||
|
llm = TextGen(model_url="http://localhost:8500")
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_url: str
|
||||||
|
"""The full URL to the textgen webui including http[s]://host:port """
|
||||||
|
|
||||||
|
max_new_tokens: Optional[int] = 250
|
||||||
|
"""The maximum number of tokens to generate."""
|
||||||
|
|
||||||
|
do_sample: bool = Field(True, alias="do_sample")
|
||||||
|
"""Do sample"""
|
||||||
|
|
||||||
|
temperature: Optional[float] = 1.3
|
||||||
|
"""Primary factor to control randomness of outputs. 0 = deterministic
|
||||||
|
(only the most likely token is used). Higher value = more randomness."""
|
||||||
|
|
||||||
|
top_p: Optional[float] = 0.1
|
||||||
|
"""If not set to 1, select tokens with probabilities adding up to less than this
|
||||||
|
number. Higher value = higher range of possible random results."""
|
||||||
|
|
||||||
|
typical_p: Optional[float] = 1
|
||||||
|
"""If not set to 1, select only tokens that are at least this much more likely to
|
||||||
|
appear than random tokens, given the prior text."""
|
||||||
|
|
||||||
|
epsilon_cutoff: Optional[float] = 0 # In units of 1e-4
|
||||||
|
"""Epsilon cutoff"""
|
||||||
|
|
||||||
|
eta_cutoff: Optional[float] = 0 # In units of 1e-4
|
||||||
|
"""ETA cutoff"""
|
||||||
|
|
||||||
|
repetition_penalty: Optional[float] = 1.18
|
||||||
|
"""Exponential penalty factor for repeating prior tokens. 1 means no penalty,
|
||||||
|
higher value = less repetition, lower value = more repetition."""
|
||||||
|
|
||||||
|
top_k: Optional[float] = 40
|
||||||
|
"""Similar to top_p, but select instead only the top_k most likely tokens.
|
||||||
|
Higher value = higher range of possible random results."""
|
||||||
|
|
||||||
|
min_length: Optional[int] = 0
|
||||||
|
"""Minimum generation length in tokens."""
|
||||||
|
|
||||||
|
no_repeat_ngram_size: Optional[int] = 0
|
||||||
|
"""If not set to 0, specifies the length of token sets that are completely blocked
|
||||||
|
from repeating at all. Higher values = blocks larger phrases,
|
||||||
|
lower values = blocks words or letters from repeating.
|
||||||
|
Only 0 or high values are a good idea in most cases."""
|
||||||
|
|
||||||
|
num_beams: Optional[int] = 1
|
||||||
|
"""Number of beams"""
|
||||||
|
|
||||||
|
penalty_alpha: Optional[float] = 0
|
||||||
|
"""Penalty Alpha"""
|
||||||
|
|
||||||
|
length_penalty: Optional[float] = 1
|
||||||
|
"""Length Penalty"""
|
||||||
|
|
||||||
|
early_stopping: bool = Field(False, alias="early_stopping")
|
||||||
|
"""Early stopping"""
|
||||||
|
|
||||||
|
seed: int = Field(-1, alias="seed")
|
||||||
|
"""Seed (-1 for random)"""
|
||||||
|
|
||||||
|
add_bos_token: bool = Field(True, alias="add_bos_token")
|
||||||
|
"""Add the bos_token to the beginning of prompts.
|
||||||
|
Disabling this can make the replies more creative."""
|
||||||
|
|
||||||
|
truncation_length: Optional[int] = 2048
|
||||||
|
"""Truncate the prompt up to this length. The leftmost tokens are removed if
|
||||||
|
the prompt exceeds this length. Most models require this to be at most 2048."""
|
||||||
|
|
||||||
|
ban_eos_token: bool = Field(False, alias="ban_eos_token")
|
||||||
|
"""Ban the eos_token. Forces the model to never end the generation prematurely."""
|
||||||
|
|
||||||
|
skip_special_tokens: bool = Field(True, alias="skip_special_tokens")
|
||||||
|
"""Skip special tokens. Some specific models need this unset."""
|
||||||
|
|
||||||
|
stopping_strings: Optional[List[str]] = []
|
||||||
|
"""A list of strings to stop generation when encountered."""
|
||||||
|
|
||||||
|
streaming: bool = False
|
||||||
|
"""Whether to stream the results, token by token (currently unimplemented)."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the default parameters for calling textgen."""
|
||||||
|
return {
|
||||||
|
"max_new_tokens": self.max_new_tokens,
|
||||||
|
"do_sample": self.do_sample,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"typical_p": self.typical_p,
|
||||||
|
"epsilon_cutoff": self.epsilon_cutoff,
|
||||||
|
"eta_cutoff": self.eta_cutoff,
|
||||||
|
"repetition_penalty": self.repetition_penalty,
|
||||||
|
"top_k": self.top_k,
|
||||||
|
"min_length": self.min_length,
|
||||||
|
"no_repeat_ngram_size": self.no_repeat_ngram_size,
|
||||||
|
"num_beams": self.num_beams,
|
||||||
|
"penalty_alpha": self.penalty_alpha,
|
||||||
|
"length_penalty": self.length_penalty,
|
||||||
|
"early_stopping": self.early_stopping,
|
||||||
|
"seed": self.seed,
|
||||||
|
"add_bos_token": self.add_bos_token,
|
||||||
|
"truncation_length": self.truncation_length,
|
||||||
|
"ban_eos_token": self.ban_eos_token,
|
||||||
|
"skip_special_tokens": self.skip_special_tokens,
|
||||||
|
"stopping_strings": self.stopping_strings,
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the identifying parameters."""
|
||||||
|
return {**{"model_url": self.model_url}, **self._default_params}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of llm."""
|
||||||
|
return "textgen"
|
||||||
|
|
||||||
|
def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Performs sanity check, preparing paramaters in format needed by textgen.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stop (Optional[List[str]]): List of stop sequences for textgen.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing the combined parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Raise error if stop sequences are in both input and default params
|
||||||
|
# if self.stop and stop is not None:
|
||||||
|
if self.stopping_strings and stop is not None:
|
||||||
|
raise ValueError("`stop` found in both the input and default params.")
|
||||||
|
|
||||||
|
params = self._default_params
|
||||||
|
|
||||||
|
# then sets it as configured, or default to an empty list:
|
||||||
|
params["stop"] = self.stopping_strings or stop or []
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
def _call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""Call the textgen web API and return the output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: The prompt to use for generation.
|
||||||
|
stop: A list of strings to stop generation when encountered.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The generated text.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.llms import TextGen
|
||||||
|
llm = TextGen(model_url="http://localhost:5000")
|
||||||
|
llm("Write a story about llamas.")
|
||||||
|
"""
|
||||||
|
if self.streaming:
|
||||||
|
raise ValueError("`streaming` option currently unsupported.")
|
||||||
|
|
||||||
|
url = f"{self.model_url}/api/v1/generate"
|
||||||
|
params = self._get_parameters(stop)
|
||||||
|
request = params.copy()
|
||||||
|
request["prompt"] = prompt
|
||||||
|
response = requests.post(url, json=request)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
result = response.json()["results"][0]["text"]
|
||||||
|
print(prompt + result)
|
||||||
|
else:
|
||||||
|
print(f"ERROR: response: {response}")
|
||||||
|
result = ""
|
||||||
|
|
||||||
|
return result
|
Loading…
Reference in New Issue
Block a user