feat(llms): support vLLM's OpenAI-compatible server (#9179)

This PR aims at supporting [vLLM's OpenAI-compatible server
feature](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html#openai-compatible-server),
i.e. allowing to call vLLM's LLMs like if they were OpenAI's.

I've also udpated the related notebook providing an example usage. At
the moment, vLLM only supports the `Completion` API.
pull/9192/head
Massimiliano Pronesti 1 year ago committed by GitHub
parent 621da3c164
commit d95eeaedbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -170,6 +170,51 @@
"\n",
"llm(\"What is the future of AI?\")"
]
},
{
"cell_type": "markdown",
"id": "64e89be0-6ad7-43a8-9dac-1324dcd4e851",
"metadata": {
"tags": []
},
"source": [
"## OpenAI-Compatible Server\n",
"\n",
"vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.\n",
"\n",
"This server can be queried in the same format as OpenAI API.\n",
"\n",
"### OpenAI-Compatible Completion"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c3cbc428-0bb8-422a-913e-1c6fef8b89d4",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" a city that is filled with history, ancient buildings, and art around every corner\n"
]
}
],
"source": [
"from langchain.llms import VLLMOpenAI\n",
"\n",
"\n",
"llm = VLLMOpenAI(\n",
" openai_api_key=\"EMPTY\",\n",
" openai_api_base=\"http://localhost:8000/v1\",\n",
" model_name=\"tiiuae/falcon-7b\",\n",
" model_kwargs={\"stop\": [\".\"]}\n",
")\n",
"print(llm(\"Rome is\"))"
]
}
],
"metadata": {

@ -80,7 +80,7 @@ from langchain.llms.textgen import TextGen
from langchain.llms.titan_takeoff import TitanTakeoff
from langchain.llms.tongyi import Tongyi
from langchain.llms.vertexai import VertexAI
from langchain.llms.vllm import VLLM
from langchain.llms.vllm import VLLM, VLLMOpenAI
from langchain.llms.writer import Writer
from langchain.llms.xinference import Xinference
@ -149,6 +149,7 @@ __all__ = [
"Tongyi",
"VertexAI",
"VLLM",
"VLLMOpenAI",
"Writer",
"OctoAIEndpoint",
"Xinference",
@ -213,6 +214,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"openllm": OpenLLM,
"openllm_client": OpenLLM,
"vllm": VLLM,
"vllm_openai": VLLMOpenAI,
"writer": Writer,
"xinference": Xinference,
}

@ -4,6 +4,7 @@ from pydantic import root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import BaseLLM
from langchain.llms.openai import BaseOpenAI
from langchain.schema.output import Generation, LLMResult
@ -127,3 +128,27 @@ class VLLM(BaseLLM):
def _llm_type(self) -> str:
"""Return type of llm."""
return "vllm"
class VLLMOpenAI(BaseOpenAI):
"""vLLM OpenAI-compatible API client"""
@property
def _invocation_params(self) -> Dict[str, Any]:
"""Get the parameters used to invoke the model."""
openai_creds: Dict[str, Any] = {
"api_key": self.openai_api_key,
"api_base": self.openai_api_base,
}
return {
"model": self.model_name,
**openai_creds,
**self._default_params,
"logit_bias": None,
}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "vllm-openai"

Loading…
Cancel
Save