langchain/libs/partners/ai21/tests/unit_tests/test_embeddings.py
"""Test embedding model integration."""
from typing import List
from unittest.mock import Mock
import pytest
from ai21 import AI21Client, MissingApiKeyError
from ai21.models import EmbedResponse, EmbedResult, EmbedType
from pytest_mock import MockerFixture
from langchain_ai21.embeddings import AI21Embeddings
from tests.unit_tests.conftest import DUMMY_API_KEY, temporarily_unset_api_key
_EXAMPLE_EMBEDDING_0 = [1.0, 2.0, 3.0]
_EXAMPLE_EMBEDDING_1 = [4.0, 5.0, 6.0]
_EXAMPLE_EMBEDDING_2 = [7.0, 8.0, 9.0]
_EXAMPLE_EMBEDDING_RESPONSE = EmbedResponse(
results=[
EmbedResult(_EXAMPLE_EMBEDDING_0),
EmbedResult(_EXAMPLE_EMBEDDING_1),
EmbedResult(_EXAMPLE_EMBEDDING_2),
],
id="test_id",
)


def test_initialization__when_no_api_key__should_raise_exception() -> None:
    """Test integration initialization."""
    with temporarily_unset_api_key():
        with pytest.raises(MissingApiKeyError):
            AI21Embeddings()


@pytest.fixture
def mock_client_with_embeddings(mocker: MockerFixture) -> Mock:
    # The mocked AI21 client always returns the same fixed three-embedding
    # response, regardless of the texts passed to embed.create.
    mock_client = mocker.MagicMock(spec=AI21Client)
    mock_client.embed = mocker.MagicMock()
    mock_client.embed.create.return_value = _EXAMPLE_EMBEDDING_RESPONSE

    return mock_client


def test_embed_query(mock_client_with_embeddings: Mock) -> None:
    llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY)

    text = "Hello embeddings world!"
    response = llm.embed_query(text=text)

    assert response == _EXAMPLE_EMBEDDING_0
    mock_client_with_embeddings.embed.create.assert_called_once_with(
        texts=[text],
        type=EmbedType.QUERY,
    )


def test_embed_documents(mock_client_with_embeddings: Mock) -> None:
    llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY)

    texts = ["Hello embeddings world!", "Some other text", "Some more text"]
    response = llm.embed_documents(texts=texts)

    assert response == [
        _EXAMPLE_EMBEDDING_0,
        _EXAMPLE_EMBEDDING_1,
        _EXAMPLE_EMBEDDING_2,
    ]
    mock_client_with_embeddings.embed.create.assert_called_once_with(
        texts=texts,
        type=EmbedType.SEGMENT,
    )


@pytest.mark.parametrize(
    ids=[
        "empty_texts",
        "chunk_size_greater_than_texts_length",
        "chunk_size_equal_to_texts_length",
        "chunk_size_less_than_texts_length",
        "chunk_size_one_with_multiple_texts",
        "chunk_size_much_greater_than_texts_length",
    ],
    argnames=["texts", "chunk_size", "expected_internal_embeddings_calls"],
    argvalues=[
        ([], 3, 0),
        (["text1", "text2", "text3"], 5, 1),
        (["text1", "text2", "text3"], 3, 1),
        (["text1", "text2", "text3", "text4", "text5"], 2, 3),
        (["text1", "text2", "text3"], 1, 3),
        (["text1", "text2", "text3"], 10, 1),
    ],
)
def test_get_len_safe_embeddings(
    mock_client_with_embeddings: Mock,
    texts: List[str],
    chunk_size: int,
    expected_internal_embeddings_calls: int,
) -> None:
    llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY)
    llm.embed_documents(texts=texts, batch_size=chunk_size)

    assert (
        mock_client_with_embeddings.embed.create.call_count
        == expected_internal_embeddings_calls
    )
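

# A minimal illustrative sketch, not part of the original test module: the
# expected_internal_embeddings_calls values above assume embed_documents splits
# `texts` into ceil(len(texts) / batch_size) batches and issues one embed.create
# call per batch. The test name below is hypothetical; it reuses the same
# fixture and the batch_size keyword already exercised above.
def test_call_count_matches_ceil_division_sketch(
    mock_client_with_embeddings: Mock,
) -> None:
    import math

    texts = ["text1", "text2", "text3", "text4", "text5"]
    chunk_size = 2

    llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY)
    llm.embed_documents(texts=texts, batch_size=chunk_size)

    # ceil(5 / 2) == 3 batches, so 3 calls to embed.create are expected.
    assert mock_client_with_embeddings.embed.create.call_count == math.ceil(
        len(texts) / chunk_size
    )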