mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
75cba742e5
Co-authored-by: Asaf Gardin <asafg@ai21.com> Co-authored-by: Erick Friis <erick@langchain.dev>
200 lines
5.7 KiB
Python
200 lines
5.7 KiB
Python
import os
|
||
from contextlib import contextmanager
|
||
from typing import Generator
|
||
from unittest.mock import Mock
|
||
|
||
import pytest
|
||
from ai21 import AI21Client
|
||
from ai21.models import (
|
||
AnswerResponse,
|
||
ChatOutput,
|
||
ChatResponse,
|
||
Completion,
|
||
CompletionData,
|
||
CompletionFinishReason,
|
||
CompletionsResponse,
|
||
FinishReason,
|
||
Penalty,
|
||
RoleType,
|
||
SegmentationResponse,
|
||
)
|
||
from ai21.models.responses.segmentation_response import Segment
|
||
from pytest_mock import MockerFixture
|
||
|
||
# Model identifiers and a placeholder API key shared by the tests.
J2_CHAT_MODEL_NAME: str = "j2-ultra"
JAMBA_CHAT_MODEL_NAME: str = "jamba-instruct-preview"
DUMMY_API_KEY: str = "test_api_key"
|
||
|
||
# Example LLM parameter set; the three penalty entries are ``Penalty`` objects.
BASIC_EXAMPLE_LLM_PARAMETERS = {
    # Sampling / length settings.
    "num_results": 3,
    "max_tokens": 20,
    "min_tokens": 10,
    "temperature": 0.5,
    "top_p": 0.5,
    "top_k_return": 0,
    # Penalties, each with scale 0.2 and a different "apply_to_*" target.
    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
    "count_penalty": Penalty(
        scale=0.2,
        apply_to_punctuation=True,
        apply_to_emojis=True,
    ),
}
|
||
|
||
# Example chat parameter set: the same entries as BASIC_EXAMPLE_LLM_PARAMETERS
# (built from fresh ``Penalty`` instances) plus a trailing "n" entry.
BASIC_EXAMPLE_CHAT_PARAMETERS = {
    # Sampling / length settings.
    "num_results": 3,
    "max_tokens": 20,
    "min_tokens": 10,
    "temperature": 0.5,
    "top_p": 0.5,
    "top_k_return": 0,
    # Penalties, each with scale 0.2 and a different "apply_to_*" target.
    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
    "count_penalty": Penalty(
        scale=0.2,
        apply_to_punctuation=True,
        apply_to_emojis=True,
    ),
    # Chat-only entry.
    "n": 3,
}
|
||
|
||
# Two canned "normal_text" segments used as the segmentation mock's payload.
SEGMENTS = [
    # First segment: origin of the franchise name.
    Segment(
        segment_type="normal_text",
        segment_text=(
            "The original full name of the franchise is Pocket Monsters "
            "(ポケットモンスター, Poketto Monsutā), which was abbreviated to "
            "Pokemon during development of the original games.\n\nWhen the "
            "franchise was released internationally, the short form of the "
            "title was used, with an acute accent (´) over the e to aid "
            "in pronunciation."
        ),
    ),
    # Second segment: grammar and pronunciation of the name.
    Segment(
        segment_type="normal_text",
        segment_text=(
            "Pokémon refers to both the franchise itself and the creatures "
            "within its fictional universe.\n\nAs a noun, it is identical in "
            "both the singular and plural, as is every individual species "
            'name;[10] it is grammatically correct to say "one Pokémon" '
            'and "many Pokémon", as well as "one Pikachu" and "many '
            'Pikachu".\n\nIn English, Pokémon may be pronounced either '
            "/'powkɛmon/ (poe-keh-mon) or /'powkɪmon/ (poe-key-mon)."
        ),
    ),
]
|
||
|
||
|
||
# Same entries as BASIC_EXAMPLE_LLM_PARAMETERS, but every ``Penalty`` is
# converted with ``to_dict()``.
BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
    # Sampling / length settings.
    "num_results": 3,
    "max_tokens": 20,
    "min_tokens": 10,
    "temperature": 0.5,
    "top_p": 0.5,
    "top_k_return": 0,
    # Penalties in their ``to_dict()`` form.
    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
    "count_penalty": Penalty(
        scale=0.2,
        apply_to_punctuation=True,
        apply_to_emojis=True,
    ).to_dict(),
}
|
||
|
||
# Same entries as BASIC_EXAMPLE_CHAT_PARAMETERS, but every ``Penalty`` is
# converted with ``to_dict()``.
BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT = {
    # Sampling / length settings.
    "num_results": 3,
    "max_tokens": 20,
    "min_tokens": 10,
    "temperature": 0.5,
    "top_p": 0.5,
    "top_k_return": 0,
    # Penalties in their ``to_dict()`` form.
    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
    "count_penalty": Penalty(
        scale=0.2,
        apply_to_punctuation=True,
        apply_to_emojis=True,
    ).to_dict(),
    # Chat-only entry.
    "n": 3,
}
|
||
|
||
|
||
@pytest.fixture
def mocked_completion_response(mocker: MockerFixture) -> Mock:
    """Build a mock spec'd on CompletionsResponse with one canned completion.

    The mock's ``prompt`` is a fixed test string and ``completions`` holds a
    single ``Completion`` whose text is ``"test"`` with no tokens and an
    empty finish reason.
    """
    canned_completion = Completion(
        data=CompletionData(text="test", tokens=[]),
        finish_reason=CompletionFinishReason(reason=None, length=None),
    )

    response = mocker.MagicMock(spec=CompletionsResponse)
    response.prompt = "this is a test prompt"
    response.completions = [canned_completion]
    return response
|
||
|
||
|
||
@pytest.fixture
def mock_client_with_completion(
    mocker: MockerFixture, mocked_completion_response: Mock
) -> Mock:
    """Build an AI21Client-spec'd mock wired for completion calls.

    ``completion.create`` hands back ``mocked_completion_response`` on each of
    two consecutive calls, and ``count_tokens`` yields 10 and then 20. Both
    use list ``side_effect`` values, so one item is consumed per call.
    """
    client = mocker.MagicMock(spec=AI21Client)
    client.completion = mocker.MagicMock()
    # The same response object is served twice.
    client.completion.create.side_effect = [mocked_completion_response] * 2
    client.count_tokens.side_effect = [10, 20]
    return client
|
||
|
||
|
||
@pytest.fixture
def mock_client_with_chat(mocker: MockerFixture) -> Mock:
    """Build an AI21Client-spec'd mock whose ``chat.create`` returns a reply.

    The reply is a ``ChatResponse`` with a single assistant ``ChatOutput``.
    """
    canned_response = ChatResponse(
        outputs=[
            ChatOutput(
                text="Hello Pickle Rick!",
                role=RoleType.ASSISTANT,
                finish_reason=FinishReason(reason="testing"),
            )
        ]
    )

    client = mocker.MagicMock(spec=AI21Client)
    client.chat = mocker.MagicMock()
    client.chat.create.return_value = canned_response
    return client
|
||
|
||
|
||
@contextmanager
def temporarily_unset_api_key() -> Generator[None, None, None]:
    """Remove ``AI21_API_KEY`` from the environment for the ``with`` body.

    Used to test behaviour when no API key is set. On exit the previous value
    is put back if there was one; if the variable was not set beforehand,
    nothing is restored.

    Yields:
        None. The context manager is used only for its side effect.
    """
    api_key = os.environ.pop("AI21_API_KEY", None)
    try:
        yield
    finally:
        # Restore even when the body raises, so a failing test cannot leak
        # the unset state into tests that run afterwards.
        if api_key is not None:
            os.environ["AI21_API_KEY"] = api_key
|
||
|
||
|
||
@pytest.fixture
def mock_client_with_contextual_answers(mocker: MockerFixture) -> Mock:
    """Build an AI21Client-spec'd mock whose ``answer.create`` returns an answer.

    The canned ``AnswerResponse`` has ``answer_in_context`` set to False.
    """
    canned_answer = AnswerResponse(
        id="some_id",
        answer="some answer",
        answer_in_context=False,
    )

    client = mocker.MagicMock(spec=AI21Client)
    client.answer = mocker.MagicMock()
    client.answer.create.return_value = canned_answer
    return client
|
||
|
||
|
||
@pytest.fixture
def mock_client_with_semantic_text_splitter(mocker: MockerFixture) -> Mock:
    """Build an AI21Client-spec'd mock whose ``segmentation.create`` is canned.

    The returned ``SegmentationResponse`` carries the module-level ``SEGMENTS``.
    """
    canned_segmentation = SegmentationResponse(id="12345", segments=SEGMENTS)

    client = mocker.MagicMock(spec=AI21Client)
    client.segmentation = mocker.MagicMock()
    client.segmentation.create.return_value = canned_segmentation
    return client
|