Add oobabooga/text-generation-webui support as a llm (#5997)

Add oobabooga/text-generation-webui support as an LLM. Currently,
supports using text-generation-webui's non-streaming API interface.
Allows users who already have text-gen running to use the same models
with langchain.

#### Before submitting

Simple usage, similar to existing LLM supported:

```
from langchain.llms import TextGen
llm = TextGen(model_url = "http://localhost:5000")
```
#### Who can review?

 @hwchase17 - project lead

---------

Co-authored-by: Hien Ngo <Hien.Ngo@adia.ae>
This commit is contained in:
lonestriker 2023-06-17 20:42:15 +04:00 committed by GitHub
parent 444ca3f669
commit 6f36f0f930
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 301 additions and 0 deletions

View File

@ -0,0 +1,87 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# TextGen\n",
"\n",
"[GitHub:oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, Pythia, OPT, and GALACTICA.\n",
"\n",
"This example goes over how to use LangChain to interact with LLM models via the `text-generation-webui` API integration.\n",
"\n",
"Please ensure that you have `text-generation-webui` configured and an LLM installed. Recommended installation via the [one-click installer appropriate](https://github.com/oobabooga/text-generation-webui#one-click-installers) for your OS.\n",
"\n",
"Once `text-generation-webui` is installed and confirmed working via the web interface, please enable the `api` option either through the web model configuration tab, or by adding the run-time arg `--api` to your start command."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set model_url and run the example"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model_url = \"http://localhost:5000\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import langchain\n",
"from langchain import PromptTemplate, LLMChain\n",
"from langchain.llms import TextGen\n",
"\n",
"langchain.debug = True\n",
"\n",
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
"llm = TextGen(model_url=model_url)\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"question = \"What NFL team won the Super Bowl in the year Justin Bieber was born?\"\n",
"\n",
"llm_chain.run(question)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -42,6 +42,7 @@ from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
from langchain.llms.self_hosted import SelfHostedPipeline from langchain.llms.self_hosted import SelfHostedPipeline
from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM
from langchain.llms.stochasticai import StochasticAI from langchain.llms.stochasticai import StochasticAI
from langchain.llms.textgen import TextGen
from langchain.llms.vertexai import VertexAI from langchain.llms.vertexai import VertexAI
from langchain.llms.writer import Writer from langchain.llms.writer import Writer
@ -64,6 +65,7 @@ __all__ = [
"GooseAI", "GooseAI",
"GPT4All", "GPT4All",
"LlamaCpp", "LlamaCpp",
"TextGen",
"Modal", "Modal",
"MosaicML", "MosaicML",
"NLPCloud", "NLPCloud",
@ -114,6 +116,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"huggingface_hub": HuggingFaceHub, "huggingface_hub": HuggingFaceHub,
"huggingface_endpoint": HuggingFaceEndpoint, "huggingface_endpoint": HuggingFaceEndpoint,
"llamacpp": LlamaCpp, "llamacpp": LlamaCpp,
"textgen": TextGen,
"modal": Modal, "modal": Modal,
"mosaic": MosaicML, "mosaic": MosaicML,
"sagemaker_endpoint": SagemakerEndpoint, "sagemaker_endpoint": SagemakerEndpoint,

211
langchain/llms/textgen.py Normal file
View File

@ -0,0 +1,211 @@
"""Wrapper around text-generation-webui."""
import logging
from typing import Any, Dict, List, Optional
import requests
from pydantic import Field
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
logger = logging.getLogger(__name__)
class TextGen(LLM):
"""Wrapper around the text-generation-webui model.
To use, you should have the text-generation-webui installed, a model loaded,
and --api added as a command-line option.
Suggested installation, use one-click installer for your OS:
https://github.com/oobabooga/text-generation-webui#one-click-installers
Paremeters below taken from text-generation-webui api example:
https://github.com/oobabooga/text-generation-webui/blob/main/api-examples/api-example.py
Example:
.. code-block:: python
from langchain.llms import TextGen
llm = TextGen(model_url="http://localhost:8500")
"""
model_url: str
"""The full URL to the textgen webui including http[s]://host:port """
max_new_tokens: Optional[int] = 250
"""The maximum number of tokens to generate."""
do_sample: bool = Field(True, alias="do_sample")
"""Do sample"""
temperature: Optional[float] = 1.3
"""Primary factor to control randomness of outputs. 0 = deterministic
(only the most likely token is used). Higher value = more randomness."""
top_p: Optional[float] = 0.1
"""If not set to 1, select tokens with probabilities adding up to less than this
number. Higher value = higher range of possible random results."""
typical_p: Optional[float] = 1
"""If not set to 1, select only tokens that are at least this much more likely to
appear than random tokens, given the prior text."""
epsilon_cutoff: Optional[float] = 0 # In units of 1e-4
"""Epsilon cutoff"""
eta_cutoff: Optional[float] = 0 # In units of 1e-4
"""ETA cutoff"""
repetition_penalty: Optional[float] = 1.18
"""Exponential penalty factor for repeating prior tokens. 1 means no penalty,
higher value = less repetition, lower value = more repetition."""
top_k: Optional[float] = 40
"""Similar to top_p, but select instead only the top_k most likely tokens.
Higher value = higher range of possible random results."""
min_length: Optional[int] = 0
"""Minimum generation length in tokens."""
no_repeat_ngram_size: Optional[int] = 0
"""If not set to 0, specifies the length of token sets that are completely blocked
from repeating at all. Higher values = blocks larger phrases,
lower values = blocks words or letters from repeating.
Only 0 or high values are a good idea in most cases."""
num_beams: Optional[int] = 1
"""Number of beams"""
penalty_alpha: Optional[float] = 0
"""Penalty Alpha"""
length_penalty: Optional[float] = 1
"""Length Penalty"""
early_stopping: bool = Field(False, alias="early_stopping")
"""Early stopping"""
seed: int = Field(-1, alias="seed")
"""Seed (-1 for random)"""
add_bos_token: bool = Field(True, alias="add_bos_token")
"""Add the bos_token to the beginning of prompts.
Disabling this can make the replies more creative."""
truncation_length: Optional[int] = 2048
"""Truncate the prompt up to this length. The leftmost tokens are removed if
the prompt exceeds this length. Most models require this to be at most 2048."""
ban_eos_token: bool = Field(False, alias="ban_eos_token")
"""Ban the eos_token. Forces the model to never end the generation prematurely."""
skip_special_tokens: bool = Field(True, alias="skip_special_tokens")
"""Skip special tokens. Some specific models need this unset."""
stopping_strings: Optional[List[str]] = []
"""A list of strings to stop generation when encountered."""
streaming: bool = False
"""Whether to stream the results, token by token (currently unimplemented)."""
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling textgen."""
return {
"max_new_tokens": self.max_new_tokens,
"do_sample": self.do_sample,
"temperature": self.temperature,
"top_p": self.top_p,
"typical_p": self.typical_p,
"epsilon_cutoff": self.epsilon_cutoff,
"eta_cutoff": self.eta_cutoff,
"repetition_penalty": self.repetition_penalty,
"top_k": self.top_k,
"min_length": self.min_length,
"no_repeat_ngram_size": self.no_repeat_ngram_size,
"num_beams": self.num_beams,
"penalty_alpha": self.penalty_alpha,
"length_penalty": self.length_penalty,
"early_stopping": self.early_stopping,
"seed": self.seed,
"add_bos_token": self.add_bos_token,
"truncation_length": self.truncation_length,
"ban_eos_token": self.ban_eos_token,
"skip_special_tokens": self.skip_special_tokens,
"stopping_strings": self.stopping_strings,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**{"model_url": self.model_url}, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "textgen"
def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
"""
Performs sanity check, preparing paramaters in format needed by textgen.
Args:
stop (Optional[List[str]]): List of stop sequences for textgen.
Returns:
Dictionary containing the combined parameters.
"""
# Raise error if stop sequences are in both input and default params
# if self.stop and stop is not None:
if self.stopping_strings and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
params = self._default_params
# then sets it as configured, or default to an empty list:
params["stop"] = self.stopping_strings or stop or []
return params
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call the textgen web API and return the output.
Args:
prompt: The prompt to use for generation.
stop: A list of strings to stop generation when encountered.
Returns:
The generated text.
Example:
.. code-block:: python
from langchain.llms import TextGen
llm = TextGen(model_url="http://localhost:5000")
llm("Write a story about llamas.")
"""
if self.streaming:
raise ValueError("`streaming` option currently unsupported.")
url = f"{self.model_url}/api/v1/generate"
params = self._get_parameters(stop)
request = params.copy()
request["prompt"] = prompt
response = requests.post(url, json=request)
if response.status_code == 200:
result = response.json()["results"][0]["text"]
print(prompt + result)
else:
print(f"ERROR: response: {response}")
result = ""
return result