"""Test ChatAI21 chat model."""
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.outputs import ChatGeneration

from langchain_ai21.chat_models import ChatAI21
from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME


@pytest.mark.parametrize(
    ids=[
        "when_j2_model",
        "when_jamba_model",
    ],
    argnames=["model"],
    argvalues=[
        (J2_CHAT_MODEL_NAME,),
        (JAMBA_CHAT_MODEL_NAME,),
    ],
)
def test_invoke(model: str) -> None:
    """Test invoke tokens from AI21."""
    llm = ChatAI21(model=model)
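    # The raw string prompt is coerced into a single human message by the chat
    # model interface; the "foo" tag is passed through to any callbacks.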
    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
    assert isinstance(result.content, str)


@pytest.mark.parametrize(
    ids=[
        "when_j2_model_num_results_is_1",
        "when_j2_model_num_results_is_3",
        "when_jamba_model_n_is_1",
        "when_jamba_model_n_is_3",
    ],
    argnames=["model", "num_results"],
    argvalues=[
        (J2_CHAT_MODEL_NAME, 1),
        (J2_CHAT_MODEL_NAME, 3),
        (JAMBA_CHAT_MODEL_NAME, 1),
        (JAMBA_CHAT_MODEL_NAME, 3),
    ],
)
def test_generation(model: str, num_results: int) -> None:
    """Test generation with multiple models and different result counts."""
    # Determine the configuration key based on the model type
    config_key = "n" if model == JAMBA_CHAT_MODEL_NAME else "num_results"

    # Create the model instance using the appropriate key for the result count
    llm = ChatAI21(model=model, **{config_key: num_results})

    message = HumanMessage(content="Hello, this is a test. Can you help me please?")

    result = llm.generate([[message]], config=dict(tags=["foo"]))
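    # A single prompt was submitted, so `result.generations` holds one list;
    # each list should contain `num_results` candidate generations.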
    for generations in result.generations:
        assert len(generations) == num_results
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


@pytest.mark.parametrize(
    ids=[
        "when_j2_model",
        "when_jamba_model",
    ],
    argnames=["model"],
    argvalues=[
        (J2_CHAT_MODEL_NAME,),
        (JAMBA_CHAT_MODEL_NAME,),
    ],
)
async def test_ageneration(model: str) -> None:
    """Test async generation from AI21."""
    llm = ChatAI21(model=model)
    message = HumanMessage(content="Hello")

    result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))
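    # Two prompts were submitted, so expect two generation lists, each with a
    # single candidate (no n/num_results override was configured).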
    for generations in result.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content