You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/partners/ai21/tests/unit_tests/test_embeddings.py

103 lines
3.2 KiB
Python

"""Test embedding model integration."""
from typing import List
from unittest.mock import Mock
import pytest
from ai21 import AI21Client, MissingApiKeyError
from ai21.models import EmbedResponse, EmbedResult, EmbedType
from pytest_mock import MockerFixture
from langchain_ai21.embeddings import AI21Embeddings
from tests.unit_tests.conftest import DUMMY_API_KEY, temporarily_unset_api_key
# Three distinct fake vectors, so every position in a response is identifiable.
_EXAMPLE_EMBEDDING_0 = [1.0, 2.0, 3.0]
_EXAMPLE_EMBEDDING_1 = [4.0, 5.0, 6.0]
_EXAMPLE_EMBEDDING_2 = [7.0, 8.0, 9.0]

# Canned response object wrapping the three vectors above, in order.
_EXAMPLE_EMBEDDING_RESPONSE = EmbedResponse(
    results=[
        EmbedResult(vector)
        for vector in (
            _EXAMPLE_EMBEDDING_0,
            _EXAMPLE_EMBEDDING_1,
            _EXAMPLE_EMBEDDING_2,
        )
    ],
    id="test_id",
)
def test_initialization__when_no_api_key__should_raise_exception() -> None:
    """Check that building ``AI21Embeddings`` with the API key unset raises.

    The constructor is called with no arguments inside
    ``temporarily_unset_api_key()`` and is expected to raise
    ``MissingApiKeyError``.
    """
    with temporarily_unset_api_key(), pytest.raises(MissingApiKeyError):
        AI21Embeddings()
@pytest.fixture
def mock_client_with_embeddings(mocker: MockerFixture) -> Mock:
    """Return a mock spec'd on ``AI21Client`` with a stubbed ``embed.create``.

    ``embed.create`` always returns ``_EXAMPLE_EMBEDDING_RESPONSE``, whatever
    arguments it is called with.
    """
    # Configure the embed sub-mock first, then attach it to the client mock.
    embed_api = mocker.MagicMock()
    embed_api.create.return_value = _EXAMPLE_EMBEDDING_RESPONSE

    client = mocker.MagicMock(spec=AI21Client)
    client.embed = embed_api
    return client
def test_embed_query(mock_client_with_embeddings: Mock) -> None:
    """Check ``embed_query`` against the mocked client.

    Asserts that the returned value equals ``_EXAMPLE_EMBEDDING_0`` and that
    ``embed.create`` was called exactly once with the query wrapped in a list
    and ``EmbedType.QUERY``.
    """
    query = "Hello embeddings world!"
    embedder = AI21Embeddings(
        client=mock_client_with_embeddings,
        api_key=DUMMY_API_KEY,
    )

    embedding = embedder.embed_query(text=query)

    assert embedding == _EXAMPLE_EMBEDDING_0
    create_mock = mock_client_with_embeddings.embed.create
    create_mock.assert_called_once_with(texts=[query], type=EmbedType.QUERY)
def test_embed_documents(mock_client_with_embeddings: Mock) -> None:
    """Check ``embed_documents`` against the mocked client.

    Asserts that the three example embeddings come back in order and that
    ``embed.create`` was called exactly once with the input texts and
    ``EmbedType.SEGMENT``.
    """
    documents = ["Hello embeddings world!", "Some other text", "Some more text"]
    expected_embeddings = [
        _EXAMPLE_EMBEDDING_0,
        _EXAMPLE_EMBEDDING_1,
        _EXAMPLE_EMBEDDING_2,
    ]
    embedder = AI21Embeddings(
        client=mock_client_with_embeddings,
        api_key=DUMMY_API_KEY,
    )

    embeddings = embedder.embed_documents(texts=documents)

    assert embeddings == expected_embeddings
    create_mock = mock_client_with_embeddings.embed.create
    create_mock.assert_called_once_with(texts=documents, type=EmbedType.SEGMENT)
@pytest.mark.parametrize(
    # Every id must be unique so each case can be selected and reported
    # unambiguously; the last case previously reused the id of the second one.
    ids=[
        "empty_texts",
        "chunk_size_greater_than_texts_length",
        "chunk_size_equal_to_texts_length",
        "chunk_size_less_than_texts_length",
        "chunk_size_one_with_multiple_texts",
        "chunk_size_much_greater_than_texts_length",
    ],
    argnames=["texts", "chunk_size", "expected_internal_embeddings_calls"],
    # In every row the expected call count equals ceil(len(texts) / chunk_size).
    argvalues=[
        ([], 3, 0),
        (["text1", "text2", "text3"], 5, 1),
        (["text1", "text2", "text3"], 3, 1),
        (["text1", "text2", "text3", "text4", "text5"], 2, 3),
        (["text1", "text2", "text3"], 1, 3),
        (["text1", "text2", "text3"], 10, 1),
    ],
)
def test_get_len_safe_embeddings(
    mock_client_with_embeddings: Mock,
    texts: List[str],
    chunk_size: int,
    expected_internal_embeddings_calls: int,
) -> None:
    """Check how many client calls ``embed_documents`` makes for a batch size.

    Args:
        mock_client_with_embeddings: Mocked client whose ``embed.create``
            calls are counted.
        texts: Input texts passed to ``embed_documents``.
        chunk_size: Value passed as ``batch_size``.
        expected_internal_embeddings_calls: Expected number of calls to
            ``embed.create``.
    """
    llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY)
    llm.embed_documents(texts=texts, batch_size=chunk_size)
    assert (
        mock_client_with_embeddings.embed.create.call_count
        == expected_internal_embeddings_calls
    )