langchain/libs/community/tests/integration_tests/chat_models/test_yuan2.py
Eugene Yurtsev 25fbe356b4
community[patch]: upgrade to recent version of mypy (#21616)
This PR upgrades community to a recent version of mypy. It inserts type:
ignore on all existing failures.
2024-05-13 14:55:07 -04:00

153 lines
4.7 KiB
Python

"""Test ChatYuan2 wrapper."""
from typing import List
import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import (
ChatGeneration,
LLMResult,
)
from langchain_community.chat_models.yuan2 import ChatYuan2
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@pytest.mark.scheduled
def test_chat_yuan2() -> None:
"""Test ChatYuan2 wrapper."""
chat = ChatYuan2( # type: ignore[call-arg]
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=False,
)
messages = [
HumanMessage(content="Hello"),
]
response = chat.invoke(messages)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_chat_yuan2_system_message() -> None:
"""Test ChatYuan2 wrapper with system message."""
chat = ChatYuan2( # type: ignore[call-arg]
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=False,
)
messages = [
SystemMessage(content="You are an AI assistant."),
HumanMessage(content="Hello"),
]
response = chat.invoke(messages)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
@pytest.mark.scheduled
def test_chat_yuan2_generate() -> None:
"""Test ChatYuan2 wrapper with generate."""
chat = ChatYuan2( # type: ignore[call-arg]
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=False,
)
messages: List = [
HumanMessage(content="Hello"),
]
response = chat.generate([messages])
assert isinstance(response, LLMResult)
assert len(response.generations) == 1
assert response.llm_output
generation = response.generations[0]
for gen in generation:
assert isinstance(gen, ChatGeneration)
assert isinstance(gen.text, str)
assert gen.text == gen.message.content
@pytest.mark.scheduled
def test_chat_yuan2_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatYuan2( # type: ignore[call-arg]
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=True,
callbacks=callback_manager,
)
messages = [
HumanMessage(content="Hello"),
]
response = chat.invoke(messages)
assert callback_handler.llm_streams > 0
assert isinstance(response, BaseMessage)
@pytest.mark.asyncio
async def test_async_chat_yuan2() -> None:
"""Test async generation."""
chat = ChatYuan2( # type: ignore[call-arg]
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=False,
)
messages: List = [
HumanMessage(content="Hello"),
]
response = await chat.agenerate([messages])
assert isinstance(response, LLMResult)
assert len(response.generations) == 1
generations = response.generations[0]
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
@pytest.mark.asyncio
async def test_async_chat_yuan2_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatYuan2( # type: ignore[call-arg]
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=True,
callbacks=callback_manager,
)
messages: List = [
HumanMessage(content="Hello"),
]
response = await chat.agenerate([messages])
assert callback_handler.llm_streams > 0
assert isinstance(response, LLMResult)
assert len(response.generations) == 1
generations = response.generations[0]
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content