forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines
2.0 KiB
Python
60 lines
2.0 KiB
Python
"""Test functionality around the simple pipeline chain."""
|
|
|
|
from typing import Dict, List
|
|
|
|
import pytest
|
|
from pydantic import BaseModel
|
|
|
|
from langchain.chains.base import Chain
|
|
from langchain.chains.simple_pipeline import SimplePipeline
|
|
|
|
|
|
class FakeChain(Chain, BaseModel):
    """Fake chain for testing purposes.

    Every output variable receives the same value: the input values
    (taken in ``input_variables`` order) joined by single spaces, with
    the literal suffix ``"foo"`` appended.
    """

    # Names of the keys read from the ``inputs`` mapping in ``_call``.
    input_variables: List[str]
    # Names of the keys written to the mapping returned by ``_call``.
    output_variables: List[str]

    @property
    def input_keys(self) -> List[str]:
        """Input keys this chain expects."""
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Output keys this chain returns."""
        return self.output_variables

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Build the fake outputs from ``inputs``.

        Args:
            inputs: Mapping that must contain every name listed in
                ``input_variables``.

        Returns:
            A dict with one entry per name in ``output_variables``, each
            mapped to the space-joined input values followed by ``"foo"``.

        Raises:
            KeyError: If a name from ``input_variables`` is absent from
                ``inputs``.
        """
        # The value does not depend on the output name, so compute it once
        # instead of rebuilding it for every output variable.
        joined = " ".join(inputs[key] for key in self.input_variables) + "foo"
        return {var: joined for var in self.output_variables}
|
|
|
|
|
|
def test_pipeline_functionality() -> None:
    """Run a two-chain pipeline end to end and check the combined result."""
    first = FakeChain(input_variables=["foo"], output_variables=["bar"])
    second = FakeChain(input_variables=["bar"], output_variables=["baz"])
    stages = [first, second]
    pipeline = SimplePipeline(chains=stages)

    result = pipeline({"input": "123"})

    # Each FakeChain appends "foo", so two stages yield "123foofoo"; the
    # expected dict also carries the original "input" entry.
    assert result == {"output": "123foofoo", "input": "123"}
|
|
|
|
|
|
def test_multi_input_errors() -> None:
    """Building a pipeline must fail when a chain declares two inputs."""
    single_input = FakeChain(input_variables=["foo"], output_variables=["bar"])
    double_input = FakeChain(
        input_variables=["bar", "foo"],
        output_variables=["baz"],
    )
    stages = [single_input, double_input]

    # The error is expected at construction time, before anything is run.
    with pytest.raises(ValueError):
        SimplePipeline(chains=stages)
|
|
|
|
|
|
def test_multi_output_errors() -> None:
    """Building a pipeline must fail when a chain declares two outputs."""
    double_output = FakeChain(
        input_variables=["foo"],
        output_variables=["bar", "grok"],
    )
    single_output = FakeChain(input_variables=["bar"], output_variables=["baz"])
    stages = [double_output, single_output]

    # The error is expected at construction time, before anything is run.
    with pytest.raises(ValueError):
        SimplePipeline(chains=stages)
|