langchain/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py

118 lines
3.7 KiB
Python
Raw Normal View History

from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_core.tools import tool
class Person(BaseModel):
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
@tool
def my_adder_tool(a: int, b: int) -> int:
"""Takes two integers, a and b, and returns their sum."""
return a + b
class ChatModelUnitTests(ABC):
@abstractmethod
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
...
@pytest.fixture
def chat_model_params(self) -> dict:
return {}
@pytest.fixture
def chat_model_has_tool_calling(
self, chat_model_class: Type[BaseChatModel]
) -> bool:
return chat_model_class.bind_tools is not BaseChatModel.bind_tools
@pytest.fixture
def chat_model_has_structured_output(
self, chat_model_class: Type[BaseChatModel]
) -> bool:
return (
chat_model_class.with_structured_output
is not BaseChatModel.with_structured_output
)
def test_chat_model_init(
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
model = chat_model_class(**chat_model_params)
assert model is not None
def test_chat_model_init_api_key(
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
params = {**chat_model_params, "api_key": "test"}
model = chat_model_class(**params) # type: ignore
assert model is not None
def test_chat_model_init_streaming(
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
model = chat_model_class(streaming=True, **chat_model_params) # type: ignore
assert model is not None
def test_chat_model_bind_tool_pydantic(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
chat_model_has_tool_calling: bool,
) -> None:
if not chat_model_has_tool_calling:
return
model = chat_model_class(**chat_model_params)
assert hasattr(model, "bind_tools")
tool_model = model.bind_tools([Person])
assert tool_model is not None
def test_chat_model_with_structured_output(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
chat_model_has_structured_output: bool,
) -> None:
if not chat_model_has_structured_output:
return
model = chat_model_class(**chat_model_params)
assert model is not None
assert model.with_structured_output(Person) is not None
def test_standard_params(
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
class ExpectedParams(BaseModel):
ls_provider: str
ls_model_name: str
ls_model_type: Literal["chat"]
ls_temperature: Optional[float]
ls_max_tokens: Optional[int]
ls_stop: Optional[List[str]]
model = chat_model_class(**chat_model_params)
ls_params = model._get_ls_params()
try:
ExpectedParams(**ls_params)
except ValidationError as e:
pytest.fail(f"Validation error: {e}")
# Test optional params
model = chat_model_class(max_tokens=10, stop=["test"], **chat_model_params)
ls_params = model._get_ls_params()
try:
ExpectedParams(**ls_params)
except ValidationError as e:
pytest.fail(f"Validation error: {e}")