You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/libs/community/tests/unit_tests/embeddings/test_infinity.py

102 lines
2.5 KiB
Python

from typing import Dict
from pytest_mock import MockerFixture
from langchain_community.embeddings import InfinityEmbeddings
# Model identifier handed to the embedder; the request mock expects to find it
# in the "model" field of every payload.
_MODEL_ID = "BAAI/bge-small"

# Base URL the embedder is pointed at; the request mock asserts that every
# requested URL starts with it.
_INFINITY_BASE_URL = "https://localhost/api"

# Sample inputs: they mix the "pizza" / "document" keywords with strings that
# are at most 10 characters long and strings that are longer.
_DOCUMENTS = [
    "pizza",
    "another pizza",
    "a document",
    "another pizza",
    "super long document with many tokens",
]
class MockResponse:
    """Minimal stand-in for the response object produced by ``requests.post``.

    It carries a canned JSON payload and a status code, and exposes them via
    the ``json()`` method and the ``status_code`` attribute.
    """

    def __init__(self, json_data: Dict, status_code: int):
        """Store the canned payload and status code.

        Args:
            json_data: Payload that ``json()`` will return.
            status_code: Value exposed as the ``status_code`` attribute.
        """
        self.json_data = json_data
        self.status_code = status_code

    def json(self) -> Dict:
        """Return the payload given at construction time (the same object)."""
        return self.json_data
def mocked_requests_post(
    url: str,
    headers: dict,
    json: dict,
) -> MockResponse:
    """Fake ``requests.post`` that returns deterministic, content-based embeddings.

    The request is validated with assertions, then one 3-dimensional vector
    is produced per input string, in the order the inputs were received:

    * contains ``"pizza"``     -> ``[1.0, 0.0, 0.0]``
    * contains ``"document"``  -> ``[0.0, 0.9, 0.0]``
    * anything else            -> ``[0.0, 0.0, -1.0]``

    Inputs longer than 10 characters additionally get ``0.1`` added to the
    last component. Because each vector depends only on its own input text,
    callers can check that results are matched to the right document.

    Args:
        url: Request URL; must start with ``_INFINITY_BASE_URL``.
        headers: Request headers; must be non-empty.
        json: Request payload; must contain a ``"model"`` entry that includes
            ``_MODEL_ID`` and an ``"input"`` list. The parameter name mirrors
            the keyword it is called with, so it is kept as-is.

    Returns:
        A ``MockResponse`` with status 200 whose payload is
        ``{"data": [{"embedding": [...]}, ...]}``.

    Raises:
        AssertionError: If any of the request checks fails.
    """
    assert url.startswith(_INFINITY_BASE_URL)
    assert headers
    # Check the payload is present before looking inside it, so a missing
    # payload fails as an assertion rather than a TypeError on the key lookup.
    assert json
    assert "model" in json and _MODEL_ID in json["model"]
    assert "input" in json and isinstance(json["input"], list)

    embeddings = []
    for inp in json["input"]:
        # verify correct ordering
        if "pizza" in inp:
            v = [1.0, 0.0, 0.0]
        elif "document" in inp:
            v = [0.0, 0.9, 0.0]
        else:
            v = [0.0, 0.0, -1.0]
        # Applies to every branch above: mark inputs longer than 10 characters.
        if len(inp) > 10:
            v[2] += 0.1
        embeddings.append({"embedding": v})

    return MockResponse(
        json_data={"data": embeddings},
        status_code=200,
    )
def test_infinity_emb_sync(
    mocker: MockerFixture,
) -> None:
    """Embed the sample documents through a patched ``requests.post``.

    Checks that the embedder keeps the configured URL and model, and that
    ``embed_documents`` returns one vector per document, each matching the
    document at the same position in ``_DOCUMENTS``.
    """
    # Route every requests.post call to the local fake instead of the network.
    mocker.patch("requests.post", side_effect=mocked_requests_post)

    embedder = InfinityEmbeddings(model=_MODEL_ID, infinity_api_url=_INFINITY_BASE_URL)

    assert embedder.infinity_api_url == _INFINITY_BASE_URL
    assert embedder.model == _MODEL_ID

    response = embedder.embed_documents(_DOCUMENTS)

    # One expected vector per entry of _DOCUMENTS, in the same order.
    want = [
        [1.0, 0.0, 0.0],  # pizza
        [1.0, 0.0, 0.1],  # pizza + long
        [0.0, 0.9, 0.0],  # doc
        [1.0, 0.0, 0.1],  # pizza + long
        [0.0, 0.9, 0.1],  # doc + long
    ]

    assert response == want
def test_infinity_large_batch_size(
    mocker: MockerFixture,
) -> None:
    """Embed a large input (``_DOCUMENTS`` repeated 1024 times) via the fake.

    Checks that the embedder keeps the configured URL and model, and that
    the returned vectors line up one-to-one, in order, with the 5120 inputs.
    """
    # NOTE(review): presumably this input is large enough to be split across
    # several requests; the embedder's batch size is not visible here - confirm.
    mocker.patch("requests.post", side_effect=mocked_requests_post)

    embedder = InfinityEmbeddings(
        infinity_api_url=_INFINITY_BASE_URL,
        model=_MODEL_ID,
    )

    assert embedder.infinity_api_url == _INFINITY_BASE_URL
    assert embedder.model == _MODEL_ID

    response = embedder.embed_documents(_DOCUMENTS * 1024)

    # Expected vectors for one pass over _DOCUMENTS, repeated to match the input.
    want = [
        [1.0, 0.0, 0.0],  # pizza
        [1.0, 0.0, 0.1],  # pizza + long
        [0.0, 0.9, 0.0],  # doc
        [1.0, 0.0, 0.1],  # pizza + long
        [0.0, 0.9, 0.1],  # doc + long
    ] * 1024

    assert response == want