diff --git a/docs/modules/models/llms/integrations/textgen.ipynb b/docs/modules/models/llms/integrations/textgen.ipynb
new file mode 100644
index 00000000..490e3a4b
--- /dev/null
+++ b/docs/modules/models/llms/integrations/textgen.ipynb
@@ -0,0 +1,87 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# TextGen\n",
+    "\n",
+    "[GitHub:oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) is a Gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, Pythia, OPT, and GALACTICA.\n",
+    "\n",
+    "This example goes over how to use LangChain to interact with LLMs via the `text-generation-webui` API integration.\n",
+    "\n",
+    "Please ensure that you have `text-generation-webui` configured and an LLM installed. Installation via the [one-click installer](https://github.com/oobabooga/text-generation-webui#one-click-installers) appropriate for your OS is recommended.\n",
+    "\n",
+    "Once `text-generation-webui` is installed and confirmed working via the web interface, please enable the `api` option either through the web model configuration tab or by adding the run-time argument `--api` to your start command."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Set model_url and run the example"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "model_url = \"http://localhost:5000\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "import langchain\n",
+    "from langchain import PromptTemplate, LLMChain\n",
+    "from langchain.llms import TextGen\n",
+    "\n",
+    "langchain.debug = True\n",
+    "\n",
+    "template = \"\"\"Question: {question}\n",
+    "\n",
+    "Answer: Let's think step by step.\"\"\"\n",
+    "\n",
+    "\n",
+    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
+    "llm = TextGen(model_url=model_url)\n",
+    "llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
+    "question = \"What NFL team won the Super Bowl in the year Justin Bieber was born?\"\n",
+    "\n",
+    "llm_chain.run(question)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py
index 42d7c030..d7f0cda8 100644
--- a/langchain/llms/__init__.py
+++ b/langchain/llms/__init__.py
@@ -42,6 +42,7 @@ from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
 from langchain.llms.self_hosted import SelfHostedPipeline
 from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM
 from langchain.llms.stochasticai import StochasticAI
+from langchain.llms.textgen import TextGen
 from langchain.llms.vertexai import VertexAI
 from langchain.llms.writer import Writer
 
@@ -64,6 +65,7 @@ __all__ = [
     "GooseAI",
     "GPT4All",
     "LlamaCpp",
+    "TextGen",
     "Modal",
     "MosaicML",
     "NLPCloud",
@@ -114,6 +116,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "huggingface_hub": HuggingFaceHub,
     "huggingface_endpoint": HuggingFaceeEndpoint,
     "llamacpp": LlamaCpp,
+    "textgen": TextGen,
     "modal": Modal,
     "mosaic": MosaicML,
     "sagemaker_endpoint": SagemakerEndpoint,
diff --git a/langchain/llms/textgen.py b/langchain/llms/textgen.py
new file mode 100644
index 00000000..29fa7ce0
--- /dev/null
+++ b/langchain/llms/textgen.py
@@ -0,0 +1,211 @@
+"""Wrapper around text-generation-webui."""
+import logging
+from typing import Any, Dict, List, Optional
+
+import requests
+from pydantic import Field
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+
+logger = logging.getLogger(__name__)
+
+
+class TextGen(LLM):
+    """Wrapper around the text-generation-webui model.
+
+    To use, you should have the text-generation-webui installed, a model loaded,
+    and --api added as a command-line option.
+
+    Suggested installation is via the one-click installer for your OS:
+    https://github.com/oobabooga/text-generation-webui#one-click-installers
+
+    Parameters below are taken from the text-generation-webui API example:
+    https://github.com/oobabooga/text-generation-webui/blob/main/api-examples/api-example.py
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import TextGen
+            llm = TextGen(model_url="http://localhost:8500")
+    """
+
+    model_url: str
+    """The full URL to the textgen webui, including http[s]://host:port."""
+
+    max_new_tokens: Optional[int] = 250
+    """The maximum number of tokens to generate."""
+
+    do_sample: bool = Field(True, alias="do_sample")
+    """Whether to use sampling instead of greedy decoding."""
+
+    temperature: Optional[float] = 1.3
+    """Primary factor to control randomness of outputs. 0 = deterministic
+    (only the most likely token is used). Higher value = more randomness."""
+
+    top_p: Optional[float] = 0.1
+    """If not set to 1, select tokens with probabilities adding up to less than this
+    number. Higher value = higher range of possible random results."""
+
+    typical_p: Optional[float] = 1
+    """If not set to 1, select only tokens that are at least this much more likely to
+    appear than random tokens, given the prior text."""
+
+    epsilon_cutoff: Optional[float] = 0  # In units of 1e-4
+    """Epsilon cutoff"""
+
+    eta_cutoff: Optional[float] = 0  # In units of 1e-4
+    """ETA cutoff"""
+
+    repetition_penalty: Optional[float] = 1.18
+    """Exponential penalty factor for repeating prior tokens. 1 means no penalty,
+    higher value = less repetition, lower value = more repetition."""
+
+    top_k: Optional[float] = 40
+    """Similar to top_p, but select instead only the top_k most likely tokens.
+    Higher value = higher range of possible random results."""
+
+    min_length: Optional[int] = 0
+    """Minimum generation length in tokens."""
+
+    no_repeat_ngram_size: Optional[int] = 0
+    """If not set to 0, specifies the length of token sets that are completely blocked
+    from repeating at all. Higher values = blocks larger phrases,
+    lower values = blocks words or letters from repeating.
+    Only 0 or high values are a good idea in most cases."""
+
+    num_beams: Optional[int] = 1
+    """Number of beams"""
+
+    penalty_alpha: Optional[float] = 0
+    """Penalty Alpha"""
+
+    length_penalty: Optional[float] = 1
+    """Length Penalty"""
+
+    early_stopping: bool = Field(False, alias="early_stopping")
+    """Early stopping"""
+
+    seed: int = Field(-1, alias="seed")
+    """Seed (-1 for random)"""
+
+    add_bos_token: bool = Field(True, alias="add_bos_token")
+    """Add the bos_token to the beginning of prompts.
+    Disabling this can make the replies more creative."""
+
+    truncation_length: Optional[int] = 2048
+    """Truncate the prompt up to this length. The leftmost tokens are removed if
+    the prompt exceeds this length. Most models require this to be at most 2048."""
+
+    ban_eos_token: bool = Field(False, alias="ban_eos_token")
+    """Ban the eos_token. Forces the model to never end the generation prematurely."""
+
+    skip_special_tokens: bool = Field(True, alias="skip_special_tokens")
+    """Skip special tokens. Some specific models need this unset."""
+
+    stopping_strings: Optional[List[str]] = []
+    """A list of strings to stop generation when encountered."""
+
+    streaming: bool = False
+    """Whether to stream the results, token by token (currently unimplemented)."""
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling textgen."""
+        return {
+            "max_new_tokens": self.max_new_tokens,
+            "do_sample": self.do_sample,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "typical_p": self.typical_p,
+            "epsilon_cutoff": self.epsilon_cutoff,
+            "eta_cutoff": self.eta_cutoff,
+            "repetition_penalty": self.repetition_penalty,
+            "top_k": self.top_k,
+            "min_length": self.min_length,
+            "no_repeat_ngram_size": self.no_repeat_ngram_size,
+            "num_beams": self.num_beams,
+            "penalty_alpha": self.penalty_alpha,
+            "length_penalty": self.length_penalty,
+            "early_stopping": self.early_stopping,
+            "seed": self.seed,
+            "add_bos_token": self.add_bos_token,
+            "truncation_length": self.truncation_length,
+            "ban_eos_token": self.ban_eos_token,
+            "skip_special_tokens": self.skip_special_tokens,
+            "stopping_strings": self.stopping_strings,
+        }
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model_url": self.model_url}, **self._default_params}
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "textgen"
+
+    def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
+        """
+        Perform a sanity check and prepare parameters in the format textgen expects.
+
+        Args:
+            stop (Optional[List[str]]): List of stop sequences for textgen.
+
+        Returns:
+            Dictionary containing the combined parameters.
+        """
+
+        # Raise an error if stop sequences are supplied both as an argument and
+        # via the class-level stopping_strings default.
+        if self.stopping_strings and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+
+        params = self._default_params
+
+        # Use whichever stop list was provided, defaulting to an empty list.
+        params["stop"] = self.stopping_strings or stop or []
+
+        return params
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call the textgen web API and return the output.
+
+        Args:
+            prompt: The prompt to use for generation.
+            stop: A list of strings to stop generation when encountered.
+
+        Returns:
+            The generated text.
+
+        Example:
+            .. code-block:: python
+
+                from langchain.llms import TextGen
+                llm = TextGen(model_url="http://localhost:5000")
+                llm("Write a story about llamas.")
+        """
+        if self.streaming:
+            raise ValueError("`streaming` option currently unsupported.")
+
+        url = f"{self.model_url}/api/v1/generate"
+        params = self._get_parameters(stop)
+        request = params.copy()
+        request["prompt"] = prompt
+        response = requests.post(url, json=request)
+
+        if response.status_code == 200:
+            result = response.json()["results"][0]["text"]
+            print(prompt + result)
+        else:
+            print(f"ERROR: response: {response}")
+            result = ""
+
+        return result
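For reference, the snippet below is a minimal usage sketch showing how the new wrapper could be combined with the chain from the notebook above while overriding a few of the generation parameters defined in langchain/llms/textgen.py. It assumes text-generation-webui is running locally on port 5000 with --api enabled and a model loaded; the parameter values and stop string are illustrative choices, not recommended defaults.

    from langchain import LLMChain, PromptTemplate
    from langchain.llms import TextGen

    # Point the wrapper at a locally running text-generation-webui API (assumed URL).
    llm = TextGen(
        model_url="http://localhost:5000",
        temperature=0.7,                    # override the wrapper's 1.3 default
        max_new_tokens=200,                 # cap the completion length
        stopping_strings=["\nQuestion:"],   # illustrative stop sequence
    )

    prompt = PromptTemplate(
        template="Question: {question}\n\nAnswer: Let's think step by step.",
        input_variables=["question"],
    )
    chain = LLMChain(prompt=prompt, llm=llm)

    # Under the hood, _call posts the combined parameters and prompt to
    # {model_url}/api/v1/generate and returns the generated text.
    question = "What NFL team won the Super Bowl in the year Justin Bieber was born?"
    print(chain.run(question))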