You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/partners/ai21/tests/unit_tests/test_embeddings.py

68 lines
2.1 KiB
Python

"""Test embedding model integration."""
from unittest.mock import Mock
import pytest
from ai21 import AI21Client, MissingApiKeyError
from ai21.models import EmbedResponse, EmbedResult, EmbedType
from pytest_mock import MockerFixture
from langchain_ai21.embeddings import AI21Embeddings
from tests.unit_tests.conftest import DUMMY_API_KEY, temporarily_unset_api_key
_EXAMPLE_EMBEDDING_0 = [1.0, 2.0, 3.0]
_EXAMPLE_EMBEDDING_1 = [4.0, 5.0, 6.0]
_EXAMPLE_EMBEDDING_2 = [7.0, 8.0, 9.0]
_EXAMPLE_EMBEDDING_RESPONSE = EmbedResponse(
results=[
EmbedResult(_EXAMPLE_EMBEDDING_0),
EmbedResult(_EXAMPLE_EMBEDDING_1),
EmbedResult(_EXAMPLE_EMBEDDING_2),
],
id="test_id",
)
def test_initialization__when_no_api_key__should_raise_exception() -> None:
"""Test integration initialization."""
with temporarily_unset_api_key():
with pytest.raises(MissingApiKeyError):
AI21Embeddings()
@pytest.fixture
def mock_client_with_embeddings(mocker: MockerFixture) -> Mock:
mock_client = mocker.MagicMock(spec=AI21Client)
mock_client.embed = mocker.MagicMock()
mock_client.embed.create.return_value = _EXAMPLE_EMBEDDING_RESPONSE
return mock_client
def test_embed_query(mock_client_with_embeddings: Mock) -> None:
llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY)
text = "Hello embeddings world!"
response = llm.embed_query(text=text)
assert response == _EXAMPLE_EMBEDDING_0
mock_client_with_embeddings.embed.create.assert_called_once_with(
texts=[text],
type=EmbedType.QUERY,
)
def test_embed_documents(mock_client_with_embeddings: Mock) -> None:
llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY)
texts = ["Hello embeddings world!", "Some other text", "Some more text"]
response = llm.embed_documents(texts=texts)
assert response == [
_EXAMPLE_EMBEDDING_0,
_EXAMPLE_EMBEDDING_1,
_EXAMPLE_EMBEDDING_2,
]
mock_client_with_embeddings.embed.create.assert_called_once_with(
texts=texts,
type=EmbedType.SEGMENT,
)