You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/tests/unit_tests/chains/test_simple_pipeline.py

60 lines
2.0 KiB
Python

"""Test functionality around the simple pipeline chain."""
from typing import Dict, List
import pytest
from pydantic import BaseModel
from langchain.chains.base import Chain
from langchain.chains.simple_pipeline import SimplePipeline
class FakeChain(Chain, BaseModel):
    """Fake chain for testing purposes.

    For every name in ``output_variables`` the chain emits the values of the
    ``input_variables`` (read in that order), joined by a single space, with
    the literal ``"foo"`` appended.
    """

    # Names of the keys this chain reads from its ``inputs`` mapping.
    input_variables: List[str]
    # Names of the keys this chain writes to its output mapping.
    output_variables: List[str]

    @property
    def input_keys(self) -> List[str]:
        """Input keys this chain expects."""
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Output keys this chain returns."""
        return self.output_variables

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Return a mapping of each output key to the joined inputs plus ``"foo"``.

        Args:
            inputs: Mapping that must contain every name in ``input_variables``.

        Returns:
            A dict with one entry per name in ``output_variables``; all entries
            share the same value.

        Raises:
            KeyError: If an expected input variable is missing from ``inputs``.
        """
        # The joined value does not depend on the output key, so compute it
        # once rather than re-reading the inputs for every output variable.
        joined = " ".join(inputs[k] for k in self.input_variables) + "foo"
        return {var: joined for var in self.output_variables}
def test_pipeline_functionality() -> None:
    """Check that a two-chain pipeline applies both chains to the input value."""
    first = FakeChain(input_variables=["foo"], output_variables=["bar"])
    second = FakeChain(input_variables=["bar"], output_variables=["baz"])
    pipeline = SimplePipeline(chains=[first, second])
    result = pipeline({"input": "123"})
    # Each FakeChain appends "foo", so two chains append it twice.
    assert result == {"output": "123foofoo", "input": "123"}
def test_multi_input_errors() -> None:
    """Ensure SimplePipeline raises ValueError when a chain expects several inputs."""
    chains = [
        FakeChain(input_variables=["foo"], output_variables=["bar"]),
        # Second chain declares two input variables, which should be rejected.
        FakeChain(input_variables=["bar", "foo"], output_variables=["baz"]),
    ]
    with pytest.raises(ValueError):
        SimplePipeline(chains=chains)
def test_multi_output_errors() -> None:
    """Ensure SimplePipeline raises ValueError when a chain emits several outputs."""
    chains = [
        # First chain declares two output variables, which should be rejected.
        FakeChain(input_variables=["foo"], output_variables=["bar", "grok"]),
        FakeChain(input_variables=["bar"], output_variables=["baz"]),
    ]
    with pytest.raises(ValueError):
        SimplePipeline(chains=chains)