feat(llms): add support for vLLM (#8806)

Hello langchain maintainers, 
this PR aims at integrating
[vllm](https://vllm.readthedocs.io/en/latest/#) into langchain. This PR
closes #8729.

This feature clearly depends on `vllm`, but I've seen other models
supported here depend on packages that are not included in the
pyproject.toml (e.g. `gpt4all`, `text-generation`) so I thought it was
the case for this as well.

@hwchase17, @baskaryan

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
pull/8870/head
Massimiliano Pronesti 1 year ago committed by GitHub
parent 100d9ce4c7
commit a616e19975
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,196 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "499c3142-2033-437d-a60a-731988ac6074",
"metadata": {},
"source": [
"# vLLM\n",
"\n",
"[vLLM](https://vllm.readthedocs.io/en/latest/index.html) is a fast and easy-to-use library for LLM inference and serving, offering:\n",
"* State-of-the-art serving throughput \n",
"* Efficient management of attention key and value memory with PagedAttention\n",
"* Continuous batching of incoming requests\n",
"* Optimized CUDA kernels\n",
"\n",
"This notebooks goes over how to use a LLM with langchain and vLLM.\n",
"\n",
"To use, you should have the `vllm` python package installed."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8a3f2666-5c75-4797-967a-7915a247bf33",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"#!pip install vllm -q"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "84e350f7-21f6-455b-b1f0-8b0116a2fd49",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 08-06 11:37:33 llm_engine.py:70] Initializing an LLM engine with config: model='mosaicml/mpt-7b', tokenizer='mosaicml/mpt-7b', tokenizer_mode=auto, trust_remote_code=True, dtype=torch.bfloat16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-06 11:37:41 llm_engine.py:196] # GPU blocks: 861, # CPU blocks: 512\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 2.00it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"What is the capital of France ? The capital of France is Paris.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from langchain.llms import VLLM\n",
"\n",
"llm = VLLM(model=\"mosaicml/mpt-7b\",\n",
" trust_remote_code=True, # mandatory for hf models\n",
" max_new_tokens=128,\n",
" top_k=10,\n",
" top_p=0.95,\n",
" temperature=0.8,\n",
")\n",
"\n",
"print(llm(\"What is the capital of France ?\"))"
]
},
{
"cell_type": "markdown",
"id": "94a3b41d-8329-4f8f-94f9-453d7f132214",
"metadata": {},
"source": [
"## Integrate the model in an LLMChain"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5605b7a1-fa63-49c1-934d-8b4ef8d71dd5",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 1/1 [00:01<00:00, 1.34s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"1. The first Pokemon game was released in 1996.\n",
"2. The president was Bill Clinton.\n",
"3. Clinton was president from 1993 to 2001.\n",
"4. The answer is Clinton.\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from langchain import PromptTemplate, LLMChain\n",
"\n",
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
"\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"\n",
"question = \"Who was the US president in the year the first Pokemon game was released?\"\n",
"\n",
"print(llm_chain.run(question))"
]
},
{
"cell_type": "markdown",
"id": "56826aba-d08b-4838-8bfa-ca96e463b25d",
"metadata": {},
"source": [
"## Distributed Inference\n",
"\n",
"vLLM supports distributed tensor-parallel inference and serving. \n",
"\n",
"To run multi-GPU inference with the LLM class, set the `tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8c25c35-47b5-459d-9985-3cf546e9ac16",
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import VLLM\n",
"\n",
"llm = VLLM(model=\"mosaicml/mpt-30b\",\n",
" tensor_parallel_size=4,\n",
" trust_remote_code=True, # mandatory for hf models\n",
")\n",
"\n",
"llm(\"What is the future of AI?\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_pytorch_p310",
"language": "python",
"name": "conda_pytorch_p310"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -76,6 +76,7 @@ from langchain.llms.stochasticai import StochasticAI
from langchain.llms.textgen import TextGen
from langchain.llms.tongyi import Tongyi
from langchain.llms.vertexai import VertexAI
from langchain.llms.vllm import VLLM
from langchain.llms.writer import Writer
from langchain.llms.xinference import Xinference
@ -139,6 +140,7 @@ __all__ = [
"StochasticAI",
"Tongyi",
"VertexAI",
"VLLM",
"Writer",
"OctoAIEndpoint",
"Xinference",
@ -198,6 +200,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"vertexai": VertexAI,
"openllm": OpenLLM,
"openllm_client": OpenLLM,
"vllm": VLLM,
"writer": Writer,
"xinference": Xinference,
}

@ -0,0 +1,123 @@
from typing import Any, Dict, List, Optional
from pydantic import root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import BaseLLM
from langchain.schema.output import Generation, LLMResult
class VLLM(BaseLLM):
model: str = ""
"""The name or path of a HuggingFace Transformers model."""
tensor_parallel_size: Optional[int] = 1
"""The number of GPUs to use for distributed execution with tensor parallelism."""
trust_remote_code: Optional[bool] = False
"""Trust remote code (e.g., from HuggingFace) when downloading the model
and tokenizer."""
n: int = 1
"""Number of output sequences to return for the given prompt."""
best_of: Optional[int] = None
"""Number of output sequences that are generated from the prompt."""
presence_penalty: float = 0.0
"""Float that penalizes new tokens based on whether they appear in the
generated text so far"""
frequency_penalty: float = 0.0
"""Float that penalizes new tokens based on their frequency in the
generated text so far"""
temperature: float = 1.0
"""Float that controls the randomness of the sampling."""
top_p: float = 1.0
"""Float that controls the cumulative probability of the top tokens to consider."""
top_k: int = -1
"""Integer that controls the number of top tokens to consider."""
use_beam_search: bool = False
"""Whether to use beam search instead of sampling."""
stop: Optional[List[str]] = None
"""List of strings that stop the generation when they are generated."""
ignore_eos: bool = False
"""Whether to ignore the EOS token and continue generating tokens after
the EOS token is generated."""
max_new_tokens: int = 512
"""Maximum number of tokens to generate per output sequence."""
client: Any #: :meta private:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that python package exists in environment."""
try:
from vllm import LLM as VLLModel
except ImportError:
raise ImportError(
"Could not import vllm python package. "
"Please install it with `pip install vllm`."
)
values["client"] = VLLModel(
model=values["model"],
tensor_parallel_size=values["tensor_parallel_size"],
trust_remote_code=values["trust_remote_code"],
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling vllm."""
return {
"n": self.n,
"best_of": self.best_of,
"max_tokens": self.max_new_tokens,
"top_k": self.top_k,
"top_p": self.top_p,
"temperature": self.temperature,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"stop": self.stop,
"ignore_eos": self.ignore_eos,
"use_beam_search": self.use_beam_search,
}
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
from vllm import SamplingParams
# build sampling parameters
params = {**self._default_params, **kwargs, "stop": stop}
sampling_params = SamplingParams(**params)
# call the model
outputs = self.client.generate(prompts, sampling_params)
generations = []
for output in outputs:
text = output.outputs[0].text
generations.append([Generation(text=text)])
return LLMResult(generations=generations)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "vllm"
Loading…
Cancel
Save