langchain/libs/partners/ai21/tests/integration_tests/test_chat_models.py
Asaf Joseph Gardin 642975dd9f
partners: AI21 Labs Jamba Support (#20815)
Description: Added support for AI21 new model - Jamba
Twitter handle: https://github.com/AI21Labs

---------

Co-authored-by: Asaf Gardin <asafg@ai21.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
2024-05-01 10:12:44 -07:00

88 lines
2.7 KiB
Python

"""Test ChatAI21 chat model."""
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.outputs import ChatGeneration
from langchain_ai21.chat_models import ChatAI21
from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
@pytest.mark.parametrize(
    ids=[
        "when_j2_model",
        "when_jamba_model",
    ],
    argnames=["model"],
    argvalues=[
        (J2_CHAT_MODEL_NAME,),
        (JAMBA_CHAT_MODEL_NAME,),
    ],
)
def test_invoke(model: str) -> None:
    """Smoke-test a single ``invoke`` round-trip against the AI21 API."""
    chat = ChatAI21(model=model)
    # Tag the run so it is traceable in callback/telemetry backends.
    response = chat.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
    assert isinstance(response.content, str)
@pytest.mark.parametrize(
    ids=[
        "when_j2_model_num_results_is_1",
        "when_j2_model_num_results_is_3",
        "when_jamba_model_n_is_1",
        "when_jamba_model_n_is_3",
    ],
    argnames=["model", "num_results"],
    argvalues=[
        (J2_CHAT_MODEL_NAME, 1),
        (J2_CHAT_MODEL_NAME, 3),
        (JAMBA_CHAT_MODEL_NAME, 1),
        (JAMBA_CHAT_MODEL_NAME, 3),
    ],
)
def test_generation(model: str, num_results: int) -> None:
    """Verify each model returns exactly the requested number of generations."""
    # Jamba exposes the completion count as ``n``; older J2 models use
    # ``num_results`` — pick the keyword matching the model under test.
    count_param = "num_results" if model != JAMBA_CHAT_MODEL_NAME else "n"
    chat = ChatAI21(model=model, **{count_param: num_results})

    prompt = HumanMessage(content="Hello, this is a test. Can you help me please?")
    outcome = chat.generate([[prompt]], config={"tags": ["foo"]})

    for batch in outcome.generations:
        assert len(batch) == num_results
        for candidate in batch:
            assert isinstance(candidate, ChatGeneration)
            assert isinstance(candidate.text, str)
            # The convenience ``text`` accessor must mirror the message body.
            assert candidate.text == candidate.message.content
@pytest.mark.parametrize(
    ids=[
        "when_j2_model",
        "when_jamba_model",
    ],
    argnames=["model"],
    argvalues=[
        (J2_CHAT_MODEL_NAME,),
        (JAMBA_CHAT_MODEL_NAME,),
    ],
)
async def test_ageneration(model: str) -> None:
    """Exercise the async ``agenerate`` path with a two-prompt batch."""
    chat = ChatAI21(model=model)
    prompt = HumanMessage(content="Hello")

    outcome = await chat.agenerate([[prompt], [prompt]], config={"tags": ["foo"]})

    # Default settings request one completion per prompt.
    for batch in outcome.generations:
        assert len(batch) == 1
        for candidate in batch:
            assert isinstance(candidate, ChatGeneration)
            assert isinstance(candidate.text, str)
            # ``text`` must stay in sync with the underlying message content.
            assert candidate.text == candidate.message.content