You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/tests/unit_tests/chat_models/test_baichuan.py

137 lines
4.3 KiB
Python

from typing import cast
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
ChatMessage,
FunctionMessage,
HumanMessage,
HumanMessageChunk,
SystemMessage,
)
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
community[major], core[patch], langchain[patch], experimental[patch]: Create langchain-community (#14463) Moved the following modules to new package langchain-community in a backwards compatible fashion: ``` mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv 
langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community ``` Moved the following to core ``` mv langchain/langchain/utils/json_schema.py core/langchain_core/utils mv langchain/langchain/utils/html.py core/langchain_core/utils mv langchain/langchain/utils/strings.py core/langchain_core/utils cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py rm langchain/langchain/utils/env.py ``` See .scripts/community_split/script_integrations.sh for all changes
9 months ago
from langchain_community.chat_models.baichuan import (
ChatBaichuan,
_convert_delta_to_message_chunk,
_convert_dict_to_message,
_convert_message_to_dict,
)
def test_initialization() -> None:
    """Build the chat model with ``timeout`` and with ``request_timeout``.

    Both instances are expected to report the same model name and a
    ``request_timeout`` of 40.
    """
    built_with_timeout = ChatBaichuan(
        model="Baichuan2-Turbo-192K", baichuan_api_key="test-api-key", timeout=40
    )
    built_with_request_timeout = ChatBaichuan(
        model="Baichuan2-Turbo-192K",
        baichuan_api_key="test-api-key",
        request_timeout=40,
    )
    for chat in (built_with_timeout, built_with_request_timeout):
        assert chat.model == "Baichuan2-Turbo-192K"
        assert chat.request_timeout == 40
def test__convert_message_to_dict_human() -> None:
    """A ``HumanMessage`` should convert to a dict with the ``user`` role."""
    converted = _convert_message_to_dict(HumanMessage(content="foo"))
    assert converted == {"role": "user", "content": "foo"}
def test__convert_message_to_dict_ai() -> None:
    """An ``AIMessage`` should convert to a dict with the ``assistant`` role."""
    converted = _convert_message_to_dict(AIMessage(content="foo"))
    assert converted == {"role": "assistant", "content": "foo"}
def test__convert_message_to_dict_system() -> None:
    """Converting a ``SystemMessage`` is expected to raise ``TypeError``.

    The raised exception's message must mention "Got unknown type".
    """
    message = SystemMessage(content="foo")
    with pytest.raises(TypeError) as exc_info:
        _convert_message_to_dict(message)
    # Inspect the exception itself (``.value``), not the ExceptionInfo wrapper,
    # whose string form is a repr of the wrapper rather than the message.
    assert "Got unknown type" in str(exc_info.value)
def test__convert_message_to_dict_function() -> None:
    """Converting a ``FunctionMessage`` is expected to raise ``TypeError``.

    The raised exception's message must mention "Got unknown type".
    """
    message = FunctionMessage(name="foo", content="bar")
    with pytest.raises(TypeError) as exc_info:
        _convert_message_to_dict(message)
    # Inspect the exception itself (``.value``), not the ExceptionInfo wrapper,
    # whose string form is a repr of the wrapper rather than the message.
    assert "Got unknown type" in str(exc_info.value)
def test__convert_dict_to_message_human() -> None:
    """A ``user``-role dict should convert to an equal ``HumanMessage``."""
    converted = _convert_dict_to_message({"role": "user", "content": "foo"})
    assert converted == HumanMessage(content="foo")
def test__convert_dict_to_message_ai() -> None:
    """An ``assistant``-role dict should convert to an equal ``AIMessage``."""
    converted = _convert_dict_to_message({"role": "assistant", "content": "foo"})
    assert converted == AIMessage(content="foo")
def test__convert_dict_to_message_other_role() -> None:
    """A ``system``-role dict should convert to a ``ChatMessage`` with that role."""
    converted = _convert_dict_to_message({"role": "system", "content": "foo"})
    assert converted == ChatMessage(role="system", content="foo")
def test__convert_delta_to_message_assistant() -> None:
    """An ``assistant`` delta should convert to an equal ``AIMessageChunk``."""
    chunk = _convert_delta_to_message_chunk(
        {"role": "assistant", "content": "foo"}, AIMessageChunk
    )
    assert chunk == AIMessageChunk(content="foo")
def test__convert_delta_to_message_human() -> None:
    """A ``user`` delta should convert to an equal ``HumanMessageChunk``."""
    chunk = _convert_delta_to_message_chunk(
        {"role": "user", "content": "foo"}, HumanMessageChunk
    )
    assert chunk == HumanMessageChunk(content="foo")
def test_baichuan_key_masked_when_passed_from_env(
    monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
    """Printing a key set through ``BAICHUAN_API_KEY`` must emit only asterisks."""
    monkeypatch.setenv("BAICHUAN_API_KEY", "test-api-key")
    model = ChatBaichuan()
    print(model.baichuan_api_key, end="")  # noqa: T201
    # Compare captured stdout against the mask string.
    assert capsys.readouterr().out == "**********"
def test_baichuan_key_masked_when_passed_via_constructor(
    capsys: CaptureFixture,
) -> None:
    """Printing a key passed to the constructor must emit only asterisks."""
    model = ChatBaichuan(baichuan_api_key="test-api-key")
    print(model.baichuan_api_key, end="")  # noqa: T201
    # Compare captured stdout against the mask string.
    assert capsys.readouterr().out == "**********"
def test_uses_actual_secret_value_from_secret_str() -> None:
    """Check that ``.get_secret_value()`` returns the strings given at construction."""
    model = ChatBaichuan(
        baichuan_api_key="test-api-key",
        baichuan_secret_key="test-secret-key",  # For backward compatibility
    )

    api_key = cast(SecretStr, model.baichuan_api_key)
    assert api_key.get_secret_value() == "test-api-key"

    secret_key = cast(SecretStr, model.baichuan_secret_key)
    assert secret_key.get_secret_value() == "test-secret-key"