langchain/libs/community/tests/integration_tests/embeddings/test_xinference.py
Liuww 332ffed393
community[patch]: Adopting the lighter-weight xinference_client (#21900)
While integrating the Xinference embeddings, we observed that the
downloaded dependency package is quite large. To optimize resource usage
and efficiency when a project only needs Xinference's embedding (vector)
capabilities, we recommend migrating to the xinference_client package.
This package is more streamlined: it significantly reduces the project's
storage requirements while keeping the feature set focused, which makes it
particularly suitable for scenarios that demand lightweight integration.
This approach not only improves deployment efficiency but also enhances
the application's maintainability, making it the optimal choice for our
current context.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
2024-05-20 22:05:09 +00:00

86 lines
2.4 KiB
Python

"""Test Xinference embeddings."""
import time
from typing import AsyncGenerator, Tuple
import pytest_asyncio
from langchain_community.embeddings import XinferenceEmbeddings
@pytest_asyncio.fixture
async def setup() -> AsyncGenerator[Tuple[str, str], None]:
    """Start an in-process Xinference supervisor and worker for one test.

    Yields:
        A ``(endpoint, pool_address)`` tuple: the endpoint returned by
        ``start_supervisor_components`` (the tests pass it to
        ``RESTfulClient``) and the external address of the xoscar actor pool
        that hosts both the supervisor and the worker.
    """
    import asyncio

    import xoscar as xo
    from xinference.deploy.supervisor import start_supervisor_components
    from xinference.deploy.utils import create_worker_actor_pool
    from xinference.deploy.worker import start_worker_components

    pool = await create_worker_actor_pool(
        f"test://127.0.0.1:{xo.utils.get_next_port()}"
    )
    print(f"Pool running on localhost:{pool.external_address}")  # noqa: T201
    endpoint = await start_supervisor_components(
        pool.external_address, "127.0.0.1", xo.utils.get_next_port()
    )
    # The worker lives in the same pool, so it points at that pool as its
    # supervisor address.
    await start_worker_components(
        address=pool.external_address, supervisor_address=pool.external_address
    )
    # Wait for the API to come up. This must be a non-blocking sleep: a
    # blocking time.sleep inside this coroutine would stall the event loop
    # the pool and components were awaited on for the entire wait.
    await asyncio.sleep(3)
    # NOTE(review): presumably the pool's async context manager shuts it
    # down on fixture teardown — confirm against xoscar.
    async with pool:
        yield endpoint, pool.external_address
def test_xinference_embedding_documents(setup: Tuple[str, str]) -> None:
    """Test xinference embeddings for documents.

    Launches a quantized vicuna-v1.3 model on the fixture's server and checks
    that embedding two documents yields two 4096-dimensional vectors.
    """
    from xinference.client import RESTfulClient

    server_url, _pool_address = setup
    launch_spec = dict(
        model_name="vicuna-v1.3",
        model_size_in_billions=7,
        model_format="ggmlv3",
        quantization="q4_0",
    )
    uid = RESTfulClient(server_url).launch_model(**launch_spec)

    embedder = XinferenceEmbeddings(server_url=server_url, model_uid=uid)
    vectors = embedder.embed_documents(["foo bar", "bar foo"])

    # One vector per input document, each with the model's embedding width.
    assert len(vectors) == 2
    assert len(vectors[0]) == 4096
def test_xinference_embedding_query(setup: Tuple[str, str]) -> None:
    """Test xinference embeddings for query.

    Launches a quantized vicuna-v1.3 model on the fixture's server and checks
    that embedding a single query yields a 4096-dimensional vector.
    """
    from xinference.client import RESTfulClient

    endpoint, _ = setup
    client = RESTfulClient(endpoint)
    # Pin the model format explicitly, matching the launch spec used by
    # test_xinference_embedding_documents, so both tests request the same
    # model build instead of relying on the server to choose a format.
    model_uid = client.launch_model(
        model_name="vicuna-v1.3",
        model_size_in_billions=7,
        model_format="ggmlv3",
        quantization="q4_0",
    )
    xinference = XinferenceEmbeddings(server_url=endpoint, model_uid=model_uid)
    document = "foo bar"
    output = xinference.embed_query(document)
    assert len(output) == 4096
def test_xinference_embedding() -> None:
    """Test embedding documents against an already-running Xinference server.

    The server URL and model UID default to placeholder values; set the
    ``XINFERENCE_SERVER_URL`` and ``XINFERENCE_MODEL_UID`` environment
    variables to point the test at a real deployment.
    """
    import os

    embedding_model = XinferenceEmbeddings(
        server_url=os.environ.get(
            "XINFERENCE_SERVER_URL", "http://xinference-hostname:9997"
        ),
        model_uid=os.environ.get("XINFERENCE_MODEL_UID", "foo"),
    )
    texts = ["hello", "i'm trying to upgrade xinference embedding"]
    output = embedding_model.embed_documents(texts=texts)
    # Check the returned embeddings, not merely that the call did not raise:
    # one non-empty vector per input text.
    assert len(output) == len(texts)
    assert all(len(vector) > 0 for vector in output)