From a616e19975796ff6e3cde24597ba90eee714d57a Mon Sep 17 00:00:00 2001
From: Massimiliano Pronesti
Date: Mon, 7 Aug 2023 16:32:02 +0200
Subject: [PATCH] feat(llms): add support for vLLM (#8806)

Hello langchain maintainers, this PR aims at integrating
[vllm](https://vllm.readthedocs.io/en/latest/#) into langchain. This PR
closes #8729.

This feature clearly depends on `vllm`, but I've seen other models supported
here depend on packages that are not included in the pyproject.toml
(e.g. `gpt4all`, `text-generation`), so I assumed the same applies here.

@hwchase17, @baskaryan

---------

Co-authored-by: Harrison Chase
---
 docs/extras/integrations/llms/vllm.ipynb  | 196 ++++++++++++++++++++++
 libs/langchain/langchain/llms/__init__.py |   3 +
 libs/langchain/langchain/llms/vllm.py     | 123 ++++++++++++++
 3 files changed, 322 insertions(+)
 create mode 100644 docs/extras/integrations/llms/vllm.ipynb
 create mode 100644 libs/langchain/langchain/llms/vllm.py

diff --git a/docs/extras/integrations/llms/vllm.ipynb b/docs/extras/integrations/llms/vllm.ipynb
new file mode 100644
index 0000000000..7e3c5068c6
--- /dev/null
+++ b/docs/extras/integrations/llms/vllm.ipynb
@@ -0,0 +1,196 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "499c3142-2033-437d-a60a-731988ac6074",
+   "metadata": {},
+   "source": [
+    "# vLLM\n",
+    "\n",
+    "[vLLM](https://vllm.readthedocs.io/en/latest/index.html) is a fast and easy-to-use library for LLM inference and serving, offering:\n",
+    "* State-of-the-art serving throughput \n",
+    "* Efficient management of attention key and value memory with PagedAttention\n",
+    "* Continuous batching of incoming requests\n",
+    "* Optimized CUDA kernels\n",
+    "\n",
+    "This notebook goes over how to use an LLM with langchain and vLLM.\n",
+    "\n",
+    "To use, you should have the `vllm` python package installed."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "8a3f2666-5c75-4797-967a-7915a247bf33",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "#!pip install vllm -q"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "84e350f7-21f6-455b-b1f0-8b0116a2fd49",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "INFO 08-06 11:37:33 llm_engine.py:70] Initializing an LLM engine with config: model='mosaicml/mpt-7b', tokenizer='mosaicml/mpt-7b', tokenizer_mode=auto, trust_remote_code=True, dtype=torch.bfloat16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
+      "INFO 08-06 11:37:41 llm_engine.py:196] # GPU blocks: 861, # CPU blocks: 512\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 2.00it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "What is the capital of France ? The capital of France is Paris.\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "from langchain.llms import VLLM\n",
+    "\n",
+    "llm = VLLM(model=\"mosaicml/mpt-7b\",\n",
+    "           trust_remote_code=True,  # mandatory for hf models\n",
+    "           max_new_tokens=128,\n",
+    "           top_k=10,\n",
+    "           top_p=0.95,\n",
+    "           temperature=0.8,\n",
+    ")\n",
+    "\n",
+    "print(llm(\"What is the capital of France ?\"))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "94a3b41d-8329-4f8f-94f9-453d7f132214",
+   "metadata": {},
+   "source": [
+    "## Integrate the model in an LLMChain"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "5605b7a1-fa63-49c1-934d-8b4ef8d71dd5",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Processed prompts: 100%|██████████| 1/1 [00:01<00:00, 1.34s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "1. The first Pokemon game was released in 1996.\n",
+      "2. The president was Bill Clinton.\n",
+      "3. Clinton was president from 1993 to 2001.\n",
+      "4. The answer is Clinton.\n",
+      "\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "from langchain import PromptTemplate, LLMChain\n",
+    "\n",
+    "template = \"\"\"Question: {question}\n",
+    "\n",
+    "Answer: Let's think step by step.\"\"\"\n",
+    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
+    "\n",
+    "llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
+    "\n",
+    "question = \"Who was the US president in the year the first Pokemon game was released?\"\n",
+    "\n",
+    "print(llm_chain.run(question))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "56826aba-d08b-4838-8bfa-ca96e463b25d",
+   "metadata": {},
+   "source": [
+    "## Distributed Inference\n",
+    "\n",
+    "vLLM supports distributed tensor-parallel inference and serving. \n",
+    "\n",
+    "To run multi-GPU inference with the LLM class, set the `tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f8c25c35-47b5-459d-9985-3cf546e9ac16",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.llms import VLLM\n",
+    "\n",
+    "llm = VLLM(model=\"mosaicml/mpt-30b\",\n",
+    "           tensor_parallel_size=4,\n",
+    "           trust_remote_code=True,  # mandatory for hf models\n",
+    ")\n",
+    "\n",
+    "llm(\"What is the future of AI?\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "conda_pytorch_p310",
+   "language": "python",
+   "name": "conda_pytorch_p310"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/libs/langchain/langchain/llms/__init__.py b/libs/langchain/langchain/llms/__init__.py
index e22c8ad3b7..53612278c9 100644
--- a/libs/langchain/langchain/llms/__init__.py
+++ b/libs/langchain/langchain/llms/__init__.py
@@ -76,6 +76,7 @@ from langchain.llms.stochasticai import StochasticAI
 from langchain.llms.textgen import TextGen
 from langchain.llms.tongyi import Tongyi
 from langchain.llms.vertexai import VertexAI
+from langchain.llms.vllm import VLLM
 from langchain.llms.writer import Writer
 from langchain.llms.xinference import Xinference
 
@@ -139,6 +140,7 @@ __all__ = [
     "StochasticAI",
     "Tongyi",
     "VertexAI",
+    "VLLM",
     "Writer",
     "OctoAIEndpoint",
     "Xinference",
@@ -198,6 +200,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "vertexai": VertexAI,
     "openllm": OpenLLM,
     "openllm_client": OpenLLM,
+    "vllm": VLLM,
     "writer": Writer,
     "xinference": Xinference,
 }
diff --git a/libs/langchain/langchain/llms/vllm.py b/libs/langchain/langchain/llms/vllm.py
new file mode 100644
index 0000000000..d02b6ff02f
--- /dev/null
+++ b/libs/langchain/langchain/llms/vllm.py
@@ -0,0 +1,123 @@
+from typing import Any, Dict, List, Optional
+
+from pydantic import root_validator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import BaseLLM
+from langchain.schema.output import Generation, LLMResult
+
+
+class VLLM(BaseLLM):
+    model: str = ""
+    """The name or path of a HuggingFace Transformers model."""
+
+    tensor_parallel_size: Optional[int] = 1
+    """The number of GPUs to use for distributed execution with tensor parallelism."""
+
+    trust_remote_code: Optional[bool] = False
+    """Trust remote code (e.g., from HuggingFace) when downloading the model
+    and tokenizer."""
+
+    n: int = 1
+    """Number of output sequences to return for the given prompt."""
+
+    best_of: Optional[int] = None
+    """Number of output sequences that are generated from the prompt."""
+
+    presence_penalty: float = 0.0
+    """Float that penalizes new tokens based on whether they appear in the
+    generated text so far"""
+
+    frequency_penalty: float = 0.0
+    """Float that penalizes new tokens based on their frequency in the
+    generated text so far"""
+
+    temperature: float = 1.0
+    """Float that controls the randomness of the sampling."""
+
+    top_p: float = 1.0
+    """Float that controls the cumulative probability of the top tokens to consider."""
+
+    top_k: int = -1
+    """Integer that controls the number of top tokens to consider."""
+
+    use_beam_search: bool = False
+    """Whether to use beam search instead of sampling."""
+
+    stop: Optional[List[str]] = None
+    """List of strings that stop the generation when they are generated."""
+
+    ignore_eos: bool = False
+    """Whether to ignore the EOS token and continue generating tokens after
+    the EOS token is generated."""
+
+    max_new_tokens: int = 512
+    """Maximum number of tokens to generate per output sequence."""
+
+    client: Any  #: :meta private:
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that python package exists in environment."""
+
+        try:
+            from vllm import LLM as VLLModel
+        except ImportError:
+            raise ImportError(
+                "Could not import vllm python package. "
+                "Please install it with `pip install vllm`."
+            )
+
+        values["client"] = VLLModel(
+            model=values["model"],
+            tensor_parallel_size=values["tensor_parallel_size"],
+            trust_remote_code=values["trust_remote_code"],
+        )
+
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling vllm."""
+        return {
+            "n": self.n,
+            "best_of": self.best_of,
+            "max_tokens": self.max_new_tokens,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "temperature": self.temperature,
+            "presence_penalty": self.presence_penalty,
+            "frequency_penalty": self.frequency_penalty,
+            "stop": self.stop,
+            "ignore_eos": self.ignore_eos,
+            "use_beam_search": self.use_beam_search,
+        }
+
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Run the LLM on the given prompt and input."""
+
+        from vllm import SamplingParams
+
+        # build sampling parameters
+        params = {**self._default_params, **kwargs, "stop": stop}
+        sampling_params = SamplingParams(**params)
+        # call the model
+        outputs = self.client.generate(prompts, sampling_params)
+
+        generations = []
+        for output in outputs:
+            text = output.outputs[0].text
+            generations.append([Generation(text=text)])
+
+        return LLMResult(generations=generations)
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "vllm"
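
Note for reviewers (not part of the diff): below is a minimal usage sketch of the wrapper added in this patch, exercising the batched `BaseLLM.generate` path that the notebook above does not show. It assumes `vllm` is installed on a CUDA-capable machine; the model id and stop sequence are illustrative. `_generate` forwards the sampling fields to `vllm.SamplingParams` and wraps each output text in a `Generation`; note that a stop list passed at call time takes precedence over the `stop` field set on the constructor, since `_generate` merges the call-time value last.

from langchain.llms import VLLM

# Minimal sketch (illustrative model id); requires `pip install vllm` and a GPU.
llm = VLLM(
    model="mosaicml/mpt-7b",
    trust_remote_code=True,  # needed for models that ship custom code, such as MPT
    max_new_tokens=64,
    temperature=0.0,
)

# BaseLLM.generate accepts a batch of prompts and returns an LLMResult whose
# `generations` field holds one list of Generation objects per prompt.
result = llm.generate(
    ["What is the capital of France?", "What is the capital of Italy?"],
    stop=["\n\n"],
)
for prompt_generations in result.generations:
    print(prompt_generations[0].text)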