mirror of https://github.com/hwchase17/langchain
FEAT: Integrate Xinference LLMs and Embeddings (#8171)
- [Xorbits Inference (Xinference)](https://github.com/xorbitsai/inference) is a powerful and versatile library designed to serve language, speech recognition, and multimodal models. Xinference supports a variety of GGML-compatible models including chatglm, whisper, and vicuna, and utilizes heterogeneous hardware and a distributed architecture for seamless cross-device and cross-server model deployment.
- This PR integrates Xinference models and Xinference embeddings into LangChain.
- Dependencies: To install the dependencies for this integration, run `pip install "xinference[all]"`.
- Example Usage:

  To start a local instance of Xinference, run `xinference`.

  To deploy Xinference in a distributed cluster, first start an Xinference supervisor using `xinference-supervisor`:

  `xinference-supervisor -H "${supervisor_host}"`

  Then, start the Xinference workers using `xinference-worker` on each server you want to run them on:

  `xinference-worker -e "http://${supervisor_host}:9997"`

  To use Xinference with LangChain, you also need to launch a model. You can use the command line interface (CLI) to do so. For example: `xinference launch -n vicuna-v1.3 -f ggmlv3 -q q4_0`. This launches a model named vicuna-v1.3 with `model_format="ggmlv3"` and `quantization="q4_0"`. A model UID is returned for you to use.

  Now you can use Xinference with LangChain:

  ```python
  from langchain.llms import Xinference

  llm = Xinference(
      server_url="http://0.0.0.0:9997",  # suppose the supervisor_host is "0.0.0.0"
      model_uid={model_uid},  # model UID returned from launching a model
  )

  llm(
      prompt="Q: where can we visit in the capital of France? A:",
      generate_config={"max_tokens": 1024},
  )
  ```

  You can also use the RESTful client to launch a model:

  ```python
  from xinference.client import RESTfulClient

  client = RESTfulClient("http://0.0.0.0:9997")
  model_uid = client.launch_model(
      model_name="vicuna-v1.3",
      model_size_in_billions=7,
      quantization="q4_0",
  )
  ```

  The following code blocks demonstrate how to use Xinference embeddings with LangChain:

  ```python
  from langchain.embeddings import XinferenceEmbeddings

  xinference = XinferenceEmbeddings(
      server_url="http://0.0.0.0:9997",
      model_uid=model_uid,
  )
  ```

  ```python
  query_result = xinference.embed_query("This is a test query")
  ```

  ```python
  doc_result = xinference.embed_documents(["text A", "text B"])
  ```

  Xinference is still under rapid development. Feel free to [join our Slack community](https://xorbitsio.slack.com/join/shared_invite/zt-1z3zsm9ep-87yI9YZ_B79HLB2ccTq4WA) to get the latest updates!

- Request for review: @hwchase17, @baskaryan
- Twitter handle: https://twitter.com/Xorbitsio

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
parent
877d384bc9
commit
1efb9bae5f
@ -0,0 +1,176 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Xorbits Inference (Xinference)\n",
|
||||
"\n",
|
||||
"[Xinference](https://github.com/xorbitsai/inference) is a powerful and versatile library designed to serve LLMs, \n",
|
||||
"speech recognition models, and multimodal models, even on your laptop. It supports a variety of models compatible with GGML, such as chatglm, baichuan, whisper, vicuna, orca, and many others. This notebook demonstrates how to use Xinference with LangChain."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Installation\n",
|
||||
"\n",
|
||||
"Install `Xinference` through PyPI:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install \"xinference[all]\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Deploy Xinference Locally or in a Distributed Cluster.\n",
|
||||
"\n",
|
||||
"For local deployment, run `xinference`. \n",
|
||||
"\n",
|
||||
"To deploy Xinference in a cluster, first start an Xinference supervisor using the `xinference-supervisor`. You can also use the option -p to specify the port and -H to specify the host. The default port is 9997.\n",
|
||||
"\n",
|
||||
"Then, start the Xinference workers using `xinference-worker` on each server you want to run them on. \n",
|
||||
"\n",
|
||||
"You can consult the README file from [Xinference](https://github.com/xorbitsai/inference) for more information.\n",
|
||||
"## Wrapper\n",
|
||||
"\n",
|
||||
"To use Xinference with LangChain, you need to first launch a model. You can use command line interface (CLI) to do so:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model uid: 7167b2b0-2a04-11ee-83f0-d29396a3f064\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!xinference launch -n vicuna-v1.3 -f ggmlv3 -q q4_0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"A model UID is returned for you to use. Now you can use Xinference with LangChain:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' You can visit the Eiffel Tower, Notre-Dame Cathedral, the Louvre Museum, and many other historical sites in Paris, the capital of France.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.llms import Xinference\n",
|
||||
"\n",
|
||||
"llm = Xinference(\n",
|
||||
" server_url=\"http://0.0.0.0:9997\",\n",
|
||||
" model_uid = \"7167b2b0-2a04-11ee-83f0-d29396a3f064\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"llm(\n",
|
||||
" prompt=\"Q: where can we visit in the capital of France? A:\",\n",
|
||||
" generate_config={\"max_tokens\": 1024, \"stream\": True},\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Integrate with a LLMChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"A: You can visit many places in Paris, such as the Eiffel Tower, the Louvre Museum, Notre-Dame Cathedral, the Champs-Elysées, Montmartre, Sacré-Cœur, and the Palace of Versailles.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain import PromptTemplate, LLMChain\n",
|
||||
"\n",
|
||||
"template = \"Where can we visit in the capital of {country}?\"\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(template=template, input_variables=[\"country\"])\n",
|
||||
"\n",
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
|
||||
"\n",
|
||||
"generated = llm_chain.run(country=\"France\")\n",
|
||||
"print(generated)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Lastly, terminate the model when you do not need to use it:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!xinference terminate --model-uid \"7167b2b0-2a04-11ee-83f0-d29396a3f064\""
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "myenv3.9",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,102 @@
# Xorbits Inference (Xinference)

This page demonstrates how to use [Xinference](https://github.com/xorbitsai/inference)
with LangChain.

`Xinference` is a powerful and versatile library designed to serve LLMs,
speech recognition models, and multimodal models, even on your laptop.
With Xorbits Inference, you can effortlessly deploy and serve your own or
state-of-the-art built-in models using just a single command.

## Installation and Setup

Xinference can be installed via pip from PyPI:

```bash
pip install "xinference[all]"
```

## LLM

Xinference supports various models compatible with GGML, including chatglm, baichuan, whisper,
vicuna, and orca. To view the builtin models, run the command:

```bash
xinference list --all
```

### Wrapper for Xinference

You can start a local instance of Xinference by running:

```bash
xinference
```

You can also deploy Xinference in a distributed cluster. To do so, first start an Xinference supervisor
on the server where you want to run it:

```bash
xinference-supervisor -H "${supervisor_host}"
```

Then, start the Xinference workers on each of the other servers where you want to run them:

```bash
xinference-worker -e "http://${supervisor_host}:9997"
```

Once Xinference is running, an endpoint is accessible for model management via the CLI or
the Xinference client.

For local deployment, the endpoint is http://localhost:9997.

For cluster deployment, the endpoint is http://${supervisor_host}:9997.

Next, you need to launch a model. You can specify the model name and other attributes,
including model_size_in_billions and quantization, using the command line interface (CLI).
For example:

```bash
xinference launch -n orca -s 3 -q q4_0
```

A model UID will be returned.
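
If you prefer to launch the model from Python rather than the CLI, a minimal sketch using the
Xinference RESTful client looks like this (it mirrors the CLI command above; the local default
endpoint and the model choice are assumptions for illustration):

```python
from xinference.client import RESTfulClient

# Connect to a running Xinference endpoint (local default assumed).
client = RESTfulClient("http://0.0.0.0:9997")

# Launch the same model as in the CLI example and keep the returned UID.
model_uid = client.launch_model(
    model_name="orca",
    model_size_in_billions=3,
    quantization="q4_0",
)
```
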
Example usage:

```python
from langchain.llms import Xinference

llm = Xinference(
    server_url="http://0.0.0.0:9997",
    model_uid={model_uid},  # replace model_uid with the model UID returned from launching the model
)

llm(
    prompt="Q: where can we visit in the capital of France? A:",
    generate_config={"max_tokens": 1024, "stream": True},
)
```
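
The wrapper also works inside chains. The sketch below mirrors the LLMChain example from the
example notebook and reuses the `llm` object created above:

```python
from langchain import PromptTemplate, LLMChain

template = "Where can we visit in the capital of {country}?"
prompt = PromptTemplate(template=template, input_variables=["country"])

llm_chain = LLMChain(prompt=prompt, llm=llm)

# Prints a short answer about sights in Paris.
generated = llm_chain.run(country="France")
print(generated)
```
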
### Usage

For more information and detailed examples, refer to the
[example notebook for xinference](../modules/models/llms/integrations/xinference.ipynb).

### Embeddings

Xinference also supports embedding queries and documents. See the
[example notebook for xinference embeddings](../modules/data_connection/text_embedding/integrations/xinference.ipynb)
for a more detailed demo.
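
As a quick sketch (assuming an embedding-capable model has already been launched and its UID is
at hand, as in the LLM section above):

```python
from langchain.embeddings import XinferenceEmbeddings

xinference = XinferenceEmbeddings(
    server_url="http://0.0.0.0:9997",  # local default endpoint assumed
    model_uid={model_uid},  # replace with the model UID returned from launching the model
)

query_result = xinference.embed_query("This is a test query")
doc_result = xinference.embed_documents(["text A", "text B"])
```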
@ -0,0 +1,144 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Xorbits inference (Xinference)\n",
|
||||
"\n",
|
||||
"This notebook goes over how to use Xinference embeddings within LangChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Installation\n",
|
||||
"\n",
|
||||
"Install `Xinference` through PyPI:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install \"xinference[all]\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Deploy Xinference Locally or in a Distributed Cluster.\n",
|
||||
"\n",
|
||||
"For local deployment, run `xinference`. \n",
|
||||
"\n",
|
||||
"To deploy Xinference in a cluster, first start an Xinference supervisor using the `xinference-supervisor`. You can also use the option -p to specify the port and -H to specify the host. The default port is 9997.\n",
|
||||
"\n",
|
||||
"Then, start the Xinference workers using `xinference-worker` on each server you want to run them on. \n",
|
||||
"\n",
|
||||
"You can consult the README file from [Xinference](https://github.com/xorbitsai/inference) for more information.\n",
|
||||
"\n",
|
||||
"## Wrapper\n",
|
||||
"\n",
|
||||
"To use Xinference with LangChain, you need to first launch a model. You can use command line interface (CLI) to do so:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model uid: 915845ee-2a04-11ee-8ed4-d29396a3f064\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!xinference launch -n vicuna-v1.3 -f ggmlv3 -q q4_0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"A model UID is returned for you to use. Now you can use Xinference embeddings with LangChain:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings import XinferenceEmbeddings\n",
|
||||
"\n",
|
||||
"xinference = XinferenceEmbeddings(\n",
|
||||
" server_url=\"http://0.0.0.0:9997\",\n",
|
||||
" model_uid = \"915845ee-2a04-11ee-8ed4-d29396a3f064\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query_result = xinference.embed_query(\"This is a test query\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"doc_result = xinference.embed_documents([\"text A\", \"text B\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Lastly, terminate the model when you do not need to use it:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!xinference terminate --model-uid \"915845ee-2a04-11ee-8ed4-d29396a3f064\""
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "base",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,113 @@
|
||||
"""Wrapper around Xinference embedding models."""
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
class XinferenceEmbeddings(Embeddings):
|
||||
|
||||
"""Wrapper around xinference embedding models.
|
||||
To use, you should have the xinference library installed:
|
||||
.. code-block:: bash
|
||||
|
||||
pip install xinference
|
||||
|
||||
Check out: https://github.com/xorbitsai/inference
|
||||
To run, you need to start an Xinference supervisor on one server and Xinference workers on the other servers.
|
||||
Example:
|
||||
To start a local instance of Xinference, run
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference
|
||||
You can also deploy Xinference in a distributed cluster. Here are the steps:
|
||||
Starting the supervisor:
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference-supervisor
|
||||
Starting the worker:
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference-worker
|
||||
|
||||
Then, launch a model using the command line interface (CLI).
|
||||
|
||||
Example:
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference launch -n orca -s 3 -q q4_0
|
||||
|
||||
It will return a model UID. Then you can use Xinference Embedding with LangChain.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
|
||||
xinference = XinferenceEmbeddings(
|
||||
server_url="http://0.0.0.0:9997",
|
||||
model_uid = {model_uid}  # replace model_uid with the model UID returned from launching the model
|
||||
)
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
client: Any
|
||||
server_url: Optional[str]
|
||||
"""URL of the xinference server"""
|
||||
model_uid: Optional[str]
|
||||
"""UID of the launched model"""
|
||||
|
||||
def __init__(
|
||||
self, server_url: Optional[str] = None, model_uid: Optional[str] = None
|
||||
):
|
||||
try:
|
||||
from xinference.client import RESTfulClient
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import RESTfulClient from xinference. Please install it"
|
||||
" with `pip install xinference`."
|
||||
) from e
|
||||
|
||||
super().__init__()
|
||||
|
||||
if server_url is None:
|
||||
raise ValueError("Please provide server URL")
|
||||
|
||||
if model_uid is None:
|
||||
raise ValueError("Please provide the model UID")
|
||||
|
||||
self.server_url = server_url
|
||||
|
||||
self.model_uid = model_uid
|
||||
|
||||
self.client = RESTfulClient(server_url)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a list of documents using Xinference.
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
|
||||
model = self.client.get_model(self.model_uid)
|
||||
|
||||
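# create_embedding returns an OpenAI-style response dict; the embedding
# vector for each input text is read from ["data"][0]["embedding"].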
embeddings = [
|
||||
model.create_embedding(text)["data"][0]["embedding"] for text in texts
|
||||
]
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed a query of documents using Xinference.
|
||||
Args:
|
||||
text: The text to embed.
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
|
||||
model = self.client.get_model(self.model_uid)
|
||||
|
||||
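# Same response shape as in embed_documents: a single-item "data" list
# holding the embedding vector for the query text.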
embedding_res = model.create_embedding(text)
|
||||
|
||||
embedding = embedding_res["data"][0]["embedding"]
|
||||
|
||||
return list(map(float, embedding))
|
@ -0,0 +1,185 @@
|
||||
from typing import TYPE_CHECKING, Any, Generator, List, Mapping, Optional, Union
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from xinference.client import RESTfulChatModelHandle, RESTfulGenerateModelHandle
|
||||
from xinference.model.llm.core import LlamaCppGenerateConfig
|
||||
|
||||
|
||||
class Xinference(LLM):
|
||||
"""Wrapper for accessing Xinference's large-scale model inference service.
|
||||
To use, you should have the xinference library installed:
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "xinference[all]"
|
||||
|
||||
Check out: https://github.com/xorbitsai/inference
|
||||
To run, you need to start an Xinference supervisor on one server and Xinference workers on the other servers.
|
||||
Example:
|
||||
To start a local instance of Xinference, run
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference
|
||||
|
||||
You can also deploy Xinference in a distributed cluster. Here are the steps:
|
||||
Starting the supervisor:
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference-supervisor
|
||||
|
||||
Starting the worker:
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference-worker
|
||||
|
||||
Then, launch a model using the command line interface (CLI).
|
||||
|
||||
Example:
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference launch -n orca -s 3 -q q4_0
|
||||
|
||||
It will return a model UID. Then, you can use Xinference with LangChain.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import Xinference
|
||||
|
||||
llm = Xinference(
|
||||
server_url="http://0.0.0.0:9997",
|
||||
model_uid = {model_uid}  # replace model_uid with the model UID returned from launching the model
|
||||
)
|
||||
|
||||
llm(
|
||||
prompt="Q: where can we visit in the capital of France? A:",
|
||||
generate_config={"max_tokens": 1024, "stream": True},
|
||||
)
|
||||
|
||||
To view all the supported builtin models, run:

.. code-block:: bash

$ xinference list --all
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
client: Any
|
||||
server_url: Optional[str]
|
||||
"""URL of the xinference server"""
|
||||
model_uid: Optional[str]
|
||||
"""UID of the launched model"""
|
||||
|
||||
def __init__(
|
||||
self, server_url: Optional[str] = None, model_uid: Optional[str] = None
|
||||
):
|
||||
try:
|
||||
from xinference.client import RESTfulClient
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import RESTfulClient from xinference. Please install it"
|
||||
" with `pip install xinference`."
|
||||
) from e
|
||||
|
||||
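# LangChain's LLM base class is a pydantic model, so server_url and model_uid
# are passed through the base initializer rather than assigned directly.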
super().__init__(
|
||||
**{
|
||||
"server_url": server_url,
|
||||
"model_uid": model_uid,
|
||||
}
|
||||
)
|
||||
|
||||
if self.server_url is None:
|
||||
raise ValueError("Please provide server URL")
|
||||
|
||||
if self.model_uid is None:
|
||||
raise ValueError("Please provide the model UID")
|
||||
|
||||
self.client = RESTfulClient(server_url)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "xinference"
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
**{"server_url": self.server_url},
|
||||
**{"model_uid": self.model_uid},
|
||||
}
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call the xinference model and return the output.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to use for generation.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
generate_config: Optional dictionary for the configuration used for
|
||||
generation.
|
||||
|
||||
Returns:
|
||||
The generated string by the model.
|
||||
"""
|
||||
model = self.client.get_model(self.model_uid)
|
||||
|
||||
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
|
||||
|
||||
if stop:
|
||||
generate_config["stop"] = stop
|
||||
|
||||
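# When streaming is requested, accumulate the streamed tokens into a single
# string, since _call must return the full completion at once.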
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
return combined_text_output
|
||||
|
||||
else:
|
||||
completion = model.generate(prompt=prompt, generate_config=generate_config)
|
||||
return completion["choices"][0]["text"]
|
||||
|
||||
def _stream_generate(
|
||||
self,
|
||||
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle"],
|
||||
prompt: str,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
generate_config: Optional["LlamaCppGenerateConfig"] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Args:
|
||||
prompt: The prompt to use for generation.
|
||||
model: The model used for generation.
|
||||
|
||||
generate_config: Optional dictionary for the configuration used for
|
||||
generation.
|
||||
|
||||
Yields:
|
||||
A string token.
|
||||
"""
|
||||
streaming_response = model.generate(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
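# Each streamed chunk is a completion-style dict; pull the text piece out of
# choices[0], notify the callback manager, then yield it.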
for chunk in streaming_response:
|
||||
if isinstance(chunk, dict):
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
choice = choices[0]
|
||||
if isinstance(choice, dict):
|
||||
token = choice.get("text", "")
|
||||
log_probs = choice.get("logprobs")
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
token=token, verbose=self.verbose, log_probs=log_probs
|
||||
)
|
||||
yield token
|
File diff suppressed because it is too large
@ -0,0 +1,74 @@
|
||||
"""Test Xinference embeddings."""
|
||||
import time
|
||||
from typing import AsyncGenerator, Tuple
|
||||
|
||||
import pytest_asyncio
|
||||
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup() -> AsyncGenerator[Tuple[str, str], None]:
|
||||
import xoscar as xo
|
||||
from xinference.deploy.supervisor import start_supervisor_components
|
||||
from xinference.deploy.utils import create_worker_actor_pool
|
||||
from xinference.deploy.worker import start_worker_components
|
||||
|
||||
pool = await create_worker_actor_pool(
|
||||
f"test://127.0.0.1:{xo.utils.get_next_port()}"
|
||||
)
|
||||
print(f"Pool running on localhost:{pool.external_address}")
|
||||
|
||||
endpoint = await start_supervisor_components(
|
||||
pool.external_address, "127.0.0.1", xo.utils.get_next_port()
|
||||
)
|
||||
await start_worker_components(
|
||||
address=pool.external_address, supervisor_address=pool.external_address
|
||||
)
|
||||
|
||||
# wait for the api.
|
||||
time.sleep(3)
|
||||
async with pool:
|
||||
yield endpoint, pool.external_address
|
||||
|
||||
|
||||
def test_xinference_embedding_documents(setup: Tuple[str, str]) -> None:
|
||||
"""Test xinference embeddings for documents."""
|
||||
from xinference.client import RESTfulClient
|
||||
|
||||
endpoint, _ = setup
|
||||
|
||||
client = RESTfulClient(endpoint)
|
||||
|
||||
model_uid = client.launch_model(
|
||||
model_name="vicuna-v1.3",
|
||||
model_size_in_billions=7,
|
||||
model_format="ggmlv3",
|
||||
quantization="q4_0",
|
||||
)
|
||||
|
||||
xinference = XinferenceEmbeddings(server_url=endpoint, model_uid=model_uid)
|
||||
|
||||
documents = ["foo bar", "bar foo"]
|
||||
output = xinference.embed_documents(documents)
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) == 4096
|
||||
|
||||
|
||||
def test_xinference_embedding_query(setup: Tuple[str, str]) -> None:
|
||||
"""Test xinference embeddings for query."""
|
||||
from xinference.client import RESTfulClient
|
||||
|
||||
endpoint, _ = setup
|
||||
|
||||
client = RESTfulClient(endpoint)
|
||||
|
||||
model_uid = client.launch_model(
|
||||
model_name="vicuna-v1.3", model_size_in_billions=7, quantization="q4_0"
|
||||
)
|
||||
|
||||
xinference = XinferenceEmbeddings(server_url=endpoint, model_uid=model_uid)
|
||||
|
||||
document = "foo bar"
|
||||
output = xinference.embed_query(document)
|
||||
assert len(output) == 4096
|
@ -0,0 +1,57 @@
|
||||
"""Test Xinference wrapper."""
|
||||
import time
|
||||
from typing import AsyncGenerator, Tuple
|
||||
|
||||
import pytest_asyncio
|
||||
|
||||
from langchain.llms import Xinference
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup() -> AsyncGenerator[Tuple[str, str], None]:
|
||||
import xoscar as xo
|
||||
from xinference.deploy.supervisor import start_supervisor_components
|
||||
from xinference.deploy.utils import create_worker_actor_pool
|
||||
from xinference.deploy.worker import start_worker_components
|
||||
|
||||
pool = await create_worker_actor_pool(
|
||||
f"test://127.0.0.1:{xo.utils.get_next_port()}"
|
||||
)
|
||||
print(f"Pool running on localhost:{pool.external_address}")
|
||||
|
||||
endpoint = await start_supervisor_components(
|
||||
pool.external_address, "127.0.0.1", xo.utils.get_next_port()
|
||||
)
|
||||
await start_worker_components(
|
||||
address=pool.external_address, supervisor_address=pool.external_address
|
||||
)
|
||||
|
||||
# wait for the api.
|
||||
time.sleep(3)
|
||||
async with pool:
|
||||
yield endpoint, pool.external_address
|
||||
|
||||
|
||||
def test_xinference_llm_(setup: Tuple[str, str]) -> None:
|
||||
from xinference.client import RESTfulClient
|
||||
|
||||
endpoint, _ = setup
|
||||
|
||||
client = RESTfulClient(endpoint)
|
||||
|
||||
model_uid = client.launch_model(
|
||||
model_name="vicuna-v1.3", model_size_in_billions=7, quantization="q4_0"
|
||||
)
|
||||
|
||||
llm = Xinference(server_url=endpoint, model_uid=model_uid)
|
||||
|
||||
answer = llm(prompt="Q: What food can we try in the capital of France? A:")
|
||||
|
||||
assert isinstance(answer, str)
|
||||
|
||||
answer = llm(
|
||||
prompt="Q: where can we visit in the capital of France? A:",
|
||||
generate_config={"max_tokens": 1024, "stream": True},
|
||||
)
|
||||
|
||||
assert isinstance(answer, str)
|