langchain/tests/unit_tests/chains/test_base.py

66 lines
1.6 KiB
Python
Raw Normal View History

2022-10-24 21:51:15 +00:00
"""Test logic on base chain class."""
from typing import Dict, List
import pytest
from pydantic import BaseModel
from langchain.chains.base import Chain
class FakeChain(Chain, BaseModel):
    """Minimal stand-in chain used to exercise ``Chain``'s key validation.

    ``be_correct`` decides whether ``_call`` emits the declared output key,
    and ``the_input_keys`` lets individual tests change which input keys the
    chain reports as required.
    """

    be_correct: bool = True
    the_input_keys: List[str] = ["foo"]

    @property
    def input_keys(self) -> List[str]:
        """Return the configurable list of required input keys."""
        return self.the_input_keys

    @property
    def output_keys(self) -> List[str]:
        """Return the single declared output key, ``"bar"``."""
        return ["bar"]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Ignore ``inputs`` and return a fixed mapping.

        With ``be_correct`` unset, the result is keyed by ``"baz"`` rather
        than the declared ``"bar"``; ``test_bad_outputs`` relies on this
        mismatch to provoke a ``ValueError``.
        """
        if not self.be_correct:
            return {"baz": "bar"}
        return {"bar": "baz"}
def test_bad_inputs() -> None:
    """Calling the chain without its required ``foo`` key raises ValueError."""
    fake = FakeChain()
    inputs_missing_foo = {"foobar": "baz"}
    with pytest.raises(ValueError):
        fake(inputs_missing_foo)
def test_bad_outputs() -> None:
    """A chain whose ``_call`` omits its declared output key raises ValueError."""
    misbehaving = FakeChain(be_correct=False)
    valid_inputs = {"foo": "baz"}
    with pytest.raises(ValueError):
        misbehaving(valid_inputs)
def test_correct_call() -> None:
    """A well-formed dict call returns the inputs together with the outputs."""
    expected = {"foo": "bar", "bar": "baz"}
    fake = FakeChain()
    assert fake({"foo": "bar"}) == expected
def test_single_input_correct() -> None:
    """A bare (non-dict) input is mapped onto the chain's sole key ``foo``."""
    expected = {"foo": "bar", "bar": "baz"}
    fake = FakeChain()
    assert fake("bar") == expected
def test_single_input_error() -> None:
    """A bare input is rejected with ValueError when two input keys exist."""
    two_key_chain = FakeChain(the_input_keys=["foo", "bar"])
    with pytest.raises(ValueError):
        two_key_chain("bar")