You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py

146 lines
4.0 KiB
Python

from abc import ABC, abstractmethod
from typing import Any, List, Literal, Optional, Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_core.runnables import RunnableBinding
from langchain_core.tools import tool
class Person(BaseModel):
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
@tool
def my_adder_tool(a: int, b: int) -> int:
"""Takes two integers, a and b, and returns their sum."""
return a + b
class ChatModelTests(ABC):
@property
@abstractmethod
def chat_model_class(self) -> Type[BaseChatModel]:
...
@property
def chat_model_params(self) -> dict:
return {}
@property
def standard_chat_model_params(self) -> dict:
return {
"temperature": 0,
"max_tokens": 100,
"timeout": 60,
"stop": [],
"max_retries": 2,
}
@pytest.fixture
def model(self) -> BaseChatModel:
return self.chat_model_class(
**{**self.standard_chat_model_params, **self.chat_model_params}
)
@property
def has_tool_calling(self) -> bool:
return self.chat_model_class.bind_tools is not BaseChatModel.bind_tools
@property
def has_structured_output(self) -> bool:
return (
self.chat_model_class.with_structured_output
is not BaseChatModel.with_structured_output
)
@property
def supports_image_inputs(self) -> bool:
return False
@property
def supports_video_inputs(self) -> bool:
return False
@property
def returns_usage_metadata(self) -> bool:
return True
class ChatModelUnitTests(ChatModelTests):
@property
def standard_chat_model_params(self) -> dict:
params = super().standard_chat_model_params
params["api_key"] = "test"
return params
def test_init(self) -> None:
model = self.chat_model_class(
**{**self.standard_chat_model_params, **self.chat_model_params}
)
assert model is not None
def test_init_streaming(
self,
) -> None:
model = self.chat_model_class(
**{
**self.standard_chat_model_params,
**self.chat_model_params,
"streaming": True,
}
)
assert model is not None
def test_bind_tool_pydantic(
self,
model: BaseChatModel,
) -> None:
if not self.has_tool_calling:
return
tool_model = model.bind_tools(
[Person, Person.schema(), my_adder_tool], tool_choice="any"
)
assert isinstance(tool_model, RunnableBinding)
@pytest.mark.parametrize("schema", [Person, Person.schema()])
def test_with_structured_output(
self,
model: BaseChatModel,
schema: Any,
) -> None:
if not self.has_structured_output:
return
assert model.with_structured_output(schema) is not None
def test_standard_params(self, model: BaseChatModel) -> None:
class ExpectedParams(BaseModel):
ls_provider: str
ls_model_name: str
ls_model_type: Literal["chat"]
ls_temperature: Optional[float]
ls_max_tokens: Optional[int]
ls_stop: Optional[List[str]]
ls_params = model._get_ls_params()
try:
ExpectedParams(**ls_params)
except ValidationError as e:
pytest.fail(f"Validation error: {e}")
# Test optional params
model = self.chat_model_class(
max_tokens=10, stop=["test"], **self.chat_model_params
)
ls_params = model._get_ls_params()
try:
ExpectedParams(**ls_params)
except ValidationError as e:
pytest.fail(f"Validation error: {e}")