langchain/tests/unit_tests/chains/test_base.py
Harrison Chase 8cf62ce06e
Harrison/single input (#347)
allow passing of single input into chain

Co-authored-by: thepok <richterthepok@yahoo.de>
2022-12-15 07:52:51 -08:00

66 lines
1.6 KiB
Python

"""Test logic on base chain class."""
from typing import Dict, List
import pytest
from pydantic import BaseModel
from langchain.chains.base import Chain
class FakeChain(Chain, BaseModel):
    """Minimal chain stand-in used to exercise the base chain logic."""

    # When True, ``_call`` answers with the declared output key ("bar");
    # when False it answers with "baz", a key not listed in ``output_keys``.
    be_correct: bool = True
    # Backing field for the ``input_keys`` property, so tests can vary it.
    the_input_keys: List[str] = ["foo"]

    @property
    def input_keys(self) -> List[str]:
        """Return the configured input keys."""
        return self.the_input_keys

    @property
    def output_keys(self) -> List[str]:
        """Return the single output key, ``bar``."""
        return ["bar"]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Return a fixed result; ``inputs`` is not consulted."""
        # Guard clause: the deliberately wrong answer comes first.
        if not self.be_correct:
            return {"baz": "bar"}
        return {"bar": "baz"}
def test_bad_inputs() -> None:
    """Check that calling with an undeclared input key raises ValueError."""
    fake = FakeChain()
    # "foobar" is not among the chain's input keys (default is ["foo"]).
    wrong_inputs = {"foobar": "baz"}
    with pytest.raises(ValueError):
        fake(wrong_inputs)
def test_bad_outputs() -> None:
    """Check that a chain returning an undeclared output key raises ValueError."""
    # be_correct=False makes the fake return "baz" instead of "bar".
    fake = FakeChain(be_correct=False)
    valid_inputs = {"foo": "baz"}
    with pytest.raises(ValueError):
        fake(valid_inputs)
def test_correct_call() -> None:
    """Check that a well-formed call returns inputs merged with outputs."""
    fake = FakeChain()
    expected = {"foo": "bar", "bar": "baz"}
    result = fake({"foo": "bar"})
    assert result == expected
def test_single_input_correct() -> None:
    """Check that a bare value is accepted when the chain has one input key."""
    fake = FakeChain()
    expected = {"foo": "bar", "bar": "baz"}
    # The bare string is expected to land under the sole input key, "foo".
    result = fake("bar")
    assert result == expected
def test_single_input_error() -> None:
    """Check that a bare value raises ValueError when there are two input keys."""
    two_keys = ["foo", "bar"]
    fake = FakeChain(the_input_keys=two_keys)
    with pytest.raises(ValueError):
        fake("bar")