mirror of
https://github.com/hwchase17/langchain
synced 2024-11-11 19:11:02 +00:00
332ffed393
While integrating xinference_embedding, we observed that the downloaded dependency package is quite large. To optimize resource usage and efficiency, we recommend migrating to the xinference_client package when a project only needs its vector-processing (embedding) capabilities. This package is more streamlined: it significantly reduces the project's storage requirements while keeping a focused feature set, which makes it particularly suitable for scenarios that demand lightweight integration. This approach not only improves deployment efficiency but also enhances the application's maintainability, making it the best choice for our current context. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
86 lines
2.4 KiB
Python
86 lines
2.4 KiB
Python
"""Test Xinference embeddings."""
|
|
|
|
import time
|
|
from typing import AsyncGenerator, Tuple
|
|
|
|
import pytest_asyncio
|
|
|
|
from langchain_community.embeddings import XinferenceEmbeddings
|
|
|
|
|
|
@pytest_asyncio.fixture
async def setup() -> AsyncGenerator[Tuple[str, str], None]:
    """Start an in-process Xinference supervisor and worker for the tests.

    Creates a xoscar worker actor pool on a ``test://`` address, then starts
    the supervisor components and the worker components on that same pool.
    It waits a few seconds for the REST API to come up before yielding.

    Yields:
        A ``(endpoint, pool_address)`` tuple. ``endpoint`` is the value
        returned by ``start_supervisor_components`` (the tests pass it to
        ``RESTfulClient`` and ``XinferenceEmbeddings`` as the server URL), and
        ``pool_address`` is the actor pool's external address.
    """
    import asyncio

    import xoscar as xo
    from xinference.deploy.supervisor import start_supervisor_components
    from xinference.deploy.utils import create_worker_actor_pool
    from xinference.deploy.worker import start_worker_components

    pool = await create_worker_actor_pool(
        f"test://127.0.0.1:{xo.utils.get_next_port()}"
    )
    print(f"Pool running on localhost:{pool.external_address}")  # noqa: T201

    endpoint = await start_supervisor_components(
        pool.external_address, "127.0.0.1", xo.utils.get_next_port()
    )
    # The supervisor and worker share one pool, so the worker registers with
    # the supervisor at the pool's own address.
    await start_worker_components(
        address=pool.external_address, supervisor_address=pool.external_address
    )

    # Wait for the API. This is a coroutine, so use asyncio.sleep: a blocking
    # time.sleep would freeze the event loop (and any actor work scheduled
    # on it) for the entire wait.
    await asyncio.sleep(3)
    # The pool is entered as an async context manager around the yield, so
    # its exit runs at fixture teardown (presumably shutting the pool down).
    async with pool:
        yield endpoint, pool.external_address
|
|
|
|
|
|
def test_xinference_embedding_documents(setup: Tuple[str, str]) -> None:
    """Test xinference embeddings for documents.

    Launches the 7B ``q4_0`` ``ggmlv3`` build of ``vicuna-v1.3`` on the
    server started by the ``setup`` fixture. It then checks that
    ``embed_documents`` returns one 4096-dimensional vector per input
    document.
    """
    from xinference.client import RESTfulClient

    endpoint, _ = setup

    client = RESTfulClient(endpoint)

    model_uid = client.launch_model(
        model_name="vicuna-v1.3",
        model_size_in_billions=7,
        model_format="ggmlv3",
        quantization="q4_0",
    )

    xinference = XinferenceEmbeddings(server_url=endpoint, model_uid=model_uid)

    documents = ["foo bar", "bar foo"]
    output = xinference.embed_documents(documents)
    # Exactly one embedding per input document.
    assert len(output) == len(documents)
    # Verify every vector's dimension, not only the first one.
    for embedding in output:
        assert len(embedding) == 4096
|
|
|
|
|
|
def test_xinference_embedding_query(setup: Tuple[str, str]) -> None:
    """Test xinference embeddings for query.

    Launches the same model as the documents test (7B ``q4_0`` ``ggmlv3``
    build of ``vicuna-v1.3``) on the server started by the ``setup`` fixture.
    It then checks that ``embed_query`` returns a single 4096-dimensional
    vector.
    """
    from xinference.client import RESTfulClient

    endpoint, _ = setup

    client = RESTfulClient(endpoint)

    # Pin model_format explicitly so this test launches the same model build
    # as the documents test, instead of leaving format selection to the
    # server.
    model_uid = client.launch_model(
        model_name="vicuna-v1.3",
        model_size_in_billions=7,
        model_format="ggmlv3",
        quantization="q4_0",
    )

    xinference = XinferenceEmbeddings(server_url=endpoint, model_uid=model_uid)

    document = "foo bar"
    output = xinference.embed_query(document)
    assert len(output) == 4096
|
|
|
|
|
|
def test_xinference_embedding() -> None:
    """Embed two documents against a fixed server URL and model uid.

    This test does not use the ``setup`` fixture. It constructs
    ``XinferenceEmbeddings`` directly with
    ``server_url="http://xinference-hostname:9997"`` and ``model_uid="foo"``.
    NOTE(review): presumably it expects an externally provided Xinference
    server at that hostname that serves a model with uid "foo"; confirm
    where that server comes from.
    """
    embedding_model = XinferenceEmbeddings(
        server_url="http://xinference-hostname:9997", model_uid="foo"
    )

    texts = ["hello", "i'm trying to upgrade xinference embedding"]
    output = embedding_model.embed_documents(texts=texts)

    # Previously the test only checked that no exception was raised; also
    # verify that one non-empty vector comes back per input text.
    assert len(output) == len(texts)
    for embedding in output:
        assert len(embedding) > 0
|