langchain/tests/unit_tests/chains/test_base.py
2022-10-24 14:51:15 -07:00

51 lines
1.2 KiB
Python

"""Test logic on base chain class."""
from typing import Dict, List
import pytest
from pydantic import BaseModel
from langchain.chains.base import Chain
class FakeChain(Chain, BaseModel):
    """Stub chain used to exercise the base ``Chain`` call path in tests.

    It declares a single input key (``"foo"``) and a single output key
    (``"bar"``). When ``be_correct`` is False, ``_run`` deliberately returns a
    key that is not among the declared outputs.
    """

    # Toggles whether _run emits the declared "bar" output or a mismatched key.
    be_correct: bool = True

    @property
    def input_keys(self) -> List[str]:
        """Return the declared input keys: only ``"foo"``."""
        return ["foo"]

    @property
    def output_keys(self) -> List[str]:
        """Return the declared output keys: only ``"bar"``."""
        return ["bar"]

    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Return a fixed mapping, ignoring ``inputs``.

        Returns ``{"bar": "baz"}`` when ``be_correct`` is set; otherwise
        returns ``{"baz": "bar"}``, whose key does not match ``output_keys``.
        """
        if not self.be_correct:
            return {"baz": "bar"}
        return {"bar": "baz"}
def test_bad_inputs() -> None:
    """Calling the chain without its declared ``foo`` input raises ValueError."""
    fake = FakeChain()
    # "foobar" is not a declared input key, so the required "foo" is missing.
    bad_payload = {"foobar": "baz"}
    with pytest.raises(ValueError):
        fake(bad_payload)
def test_bad_outputs() -> None:
    """A chain whose ``_run`` yields undeclared output keys raises ValueError."""
    # be_correct=False makes _run return "baz" instead of the declared "bar".
    misbehaving = FakeChain(be_correct=False)
    valid_inputs = {"foo": "baz"}
    with pytest.raises(ValueError):
        misbehaving(valid_inputs)
def test_correct_call() -> None:
    """A well-formed call returns the inputs together with the chain's outputs."""
    result = FakeChain()({"foo": "bar"})
    expected = {"foo": "bar", "bar": "baz"}
    assert result == expected