langchain/libs/partners/ai21/tests/integration_tests/test_chat_models.py


"""Test ChatAI21 chat model."""
from langchain_core.messages import HumanMessage
from langchain_core.outputs import ChatGeneration

from langchain_ai21.chat_models import ChatAI21


def test_invoke() -> None:
    """Test invoke tokens from AI21."""
    llm = ChatAI21(model="j2-ultra")

    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
    assert isinstance(result.content, str)


def test_generation() -> None:
    """Test batch generation from AI21."""
    llm = ChatAI21(model="j2-ultra")
    message = HumanMessage(content="Hello")

    # `generate` takes tags directly; it does not accept the Runnable-style `config` kwarg.
    result = llm.generate([[message], [message]], tags=["foo"])

    for generations in result.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


async def test_ageneration() -> None:
    """Test async batch generation from AI21."""
    llm = ChatAI21(model="j2-ultra")
    message = HumanMessage(content="Hello")

    # Same checks as above, but through the async `agenerate` API.
    result = await llm.agenerate([[message], [message]], tags=["foo"])

    for generations in result.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content