# langchain/libs/community/tests/unit_tests/chat_models/test_ollama.py
# (36 lines, 1.1 KiB, Python)

from typing import List, Literal, Optional
import pytest
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_community.chat_models import ChatOllama
def test_standard_params() -> None:
    """Check that ChatOllama reports well-formed LangSmith tracing params.

    The dict returned by ``_get_ls_params`` is validated against a pydantic
    schema, first for a model built with only a model name, then for one built
    with optional generation settings (``num_predict``, ``stop``,
    ``temperature``). The latter values must be surfaced under their
    ``ls_*`` keys.
    """

    class ExpectedParams(BaseModel):
        # Schema the ls_* params must satisfy; the Optional[...] fields may be
        # absent from the dict.
        ls_provider: str
        ls_model_name: str
        ls_model_type: Literal["chat"]
        ls_temperature: Optional[float]
        ls_max_tokens: Optional[int]
        ls_stop: Optional[List[str]]

    def _validated_ls_params(chat_model: ChatOllama) -> dict:
        """Fetch ``chat_model``'s ls params, failing the test if they don't fit the schema."""
        params = chat_model._get_ls_params()
        try:
            ExpectedParams(**params)
        except ValidationError as err:
            pytest.fail(f"Validation error: {err}")
        return params

    # Minimal construction: only the model name is supplied.
    default_params = _validated_ls_params(ChatOllama(model="llama3"))
    assert default_params["ls_model_name"] == "llama3"

    # Optional generation settings must be reflected in the ls_* keys.
    tuned_params = _validated_ls_params(
        ChatOllama(num_predict=10, stop=["test"], temperature=0.33)
    )
    expected_values = {
        "ls_max_tokens": 10,
        "ls_stop": ["test"],
        "ls_temperature": 0.33,
    }
    for key, value in expected_values.items():
        assert tuned_params[key] == value