forked from Archives/langchain
chain pipelines
parent
8869b0ab0e
commit
b5325c212b
@ -0,0 +1,71 @@
|
||||
"""Chain pipeline where the outputs of one step feed directly into next."""
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
|
||||
class Pipeline(Chain, BaseModel):
    """Chain pipeline where the outputs of one step feed directly into the next.

    The chains in ``chains`` are run in order. Each chain is called with every
    value known so far (the pipeline inputs plus the outputs of the chains
    that ran before it), and its outputs are added to that pool.
    """

    chains: List[Chain]
    input_variables: List[str]
    output_variables: List[str]  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return self.output_variables

    @root_validator(pre=True)
    def validate_chains(cls, values: Dict) -> Dict:
        """Validate that the correct inputs exist for all chains.

        Walks the chains in order, tracking which variables are available at
        each step. When ``output_variables`` is not supplied it defaults to
        every key produced by the chains, in the order they are produced.

        Raises:
            ValueError: If a chain needs a key that is not available yet, if a
                chain produces a key that already exists, or if a requested
                output variable is never made available.
        """
        chains = values["chains"]
        input_variables = values["input_variables"]
        known_variables = set(input_variables)
        # Keys produced by the chains, kept in production order so that the
        # default ``output_variables`` is deterministic (a set is not).
        produced_variables: List[str] = []
        for chain in chains:
            missing_vars = set(chain.input_keys).difference(known_variables)
            if missing_vars:
                raise ValueError(f"Missing required input keys: {missing_vars}")
            overlapping_keys = known_variables.intersection(chain.output_keys)
            if overlapping_keys:
                raise ValueError(
                    f"Chain returned keys that already exist: {overlapping_keys}"
                )
            for key in chain.output_keys:
                # Guard against a chain listing the same output key twice.
                if key not in known_variables:
                    known_variables.add(key)
                    produced_variables.append(key)

        if "output_variables" not in values:
            values["output_variables"] = produced_variables
        else:
            # Every requested output must be a known variable. The difference
            # is taken from the requested side: the reverse direction would
            # reject any request that does not name every known variable.
            missing_vars = set(values["output_variables"]).difference(known_variables)
            if missing_vars:
                raise ValueError(
                    f"Expected output variables that were not found: {missing_vars}."
                )
        return values

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Run every chain in order and return the selected output variables."""
        # Copy so the caller's dict is not mutated while values accumulate.
        known_values = inputs.copy()
        for chain in self.chains:
            outputs = chain(known_values)
            known_values.update(outputs)
        return {k: known_values[k] for k in self.output_variables}
|
@ -0,0 +1,59 @@
|
||||
"""Simple chain pipeline where the outputs of one step feed directly into next."""
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
|
||||
class SimplePipeline(Chain, BaseModel):
    """Linear pipeline of chains that each take one input and give one output.

    The value found under ``input_key`` is passed to the first chain, each
    chain's result is passed to the next chain, and the final result is
    returned under ``output_key``.
    """

    chains: List[Chain]
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:

    class Config:
        """Pydantic settings for this model."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the single expected input key as a one-element list.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the single produced output key as a one-element list.

        :meta private:
        """
        return [self.output_key]

    @root_validator()
    def validate_chains(cls, values: Dict) -> Dict:
        """Check that every chain has exactly one input and one output key.

        Raises:
            ValueError: If any chain declares a number of input keys or
                output keys other than one.
        """
        for step in values["chains"]:
            n_inputs = len(step.input_keys)
            if n_inputs != 1:
                raise ValueError(
                    "Chains used in SimplePipeline should all have one input, got "
                    f"{step} with {n_inputs} inputs."
                )
            n_outputs = len(step.output_keys)
            if n_outputs != 1:
                raise ValueError(
                    "Chains used in SimplePipeline should all have one output, got "
                    f"{step} with {n_outputs} outputs."
                )
        return values

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Thread the input value through each chain's ``run`` in order."""
        current = inputs[self.input_key]
        for step in self.chains:
            current = step.run(current)
        return {self.output_key: current}
|
@ -0,0 +1,103 @@
|
||||
"""Test pipeline functionality."""
|
||||
from typing import Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.pipeline import Pipeline
|
||||
|
||||
|
||||
class FakeChain(Chain, BaseModel):
    """Stub chain for tests whose outputs are derived from its inputs."""

    input_variables: List[str]
    output_variables: List[str]

    @property
    def input_keys(self) -> List[str]:
        """Keys this chain reads from its inputs."""
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Keys this chain writes to its outputs."""
        return self.output_variables

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Map each output key to the space-joined input values plus "foo"."""
        return {
            name: " ".join([inputs[key] for key in self.input_variables]) + "foo"
            for name in self.output_variables
        }
|
||||
|
||||
|
||||
def test_pipeline_usage_single_inputs() -> None:
    """Run a two-step pipeline in which every chain takes a single input."""
    first = FakeChain(input_variables=["foo"], output_variables=["bar"])
    second = FakeChain(input_variables=["bar"], output_variables=["baz"])
    pipeline = Pipeline(chains=[first, second], input_variables=["foo"])
    assert pipeline({"foo": "123"}) == {
        "bar": "123foo",
        "baz": "123foofoo",
        "foo": "123",
    }
|
||||
|
||||
|
||||
def test_pipeline_usage_multiple_inputs() -> None:
    """Run a two-step pipeline in which the chains take several inputs."""
    first = FakeChain(input_variables=["foo", "test"], output_variables=["bar"])
    second = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
    pipeline = Pipeline(chains=[first, second], input_variables=["foo", "test"])
    result = pipeline({"foo": "123", "test": "456"})
    assert result == {
        "bar": "123 456foo",
        "baz": "123 456foo 123foo",
        "foo": "123",
        "test": "456",
    }
|
||||
|
||||
|
||||
def test_pipeline_usage_multiple_outputs() -> None:
    """Run a pipeline whose first chain produces more than one output."""
    first = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
    second = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
    pipeline = Pipeline(chains=[first, second], input_variables=["foo"])
    result = pipeline({"foo": "123"})
    assert result == {
        "bar": "123foo",
        "baz": "123foo 123foo",
        "foo": "123",
        "test": "123foo",
    }
|
||||
|
||||
|
||||
def test_pipeline_missing_inputs() -> None:
    """Constructing a pipeline must fail when a chain's input is unavailable."""
    first = FakeChain(input_variables=["foo"], output_variables=["bar"])
    second = FakeChain(input_variables=["bar", "test"], output_variables=["baz"])
    steps = [first, second]
    # The second chain also needs "test", which nothing provides.
    with pytest.raises(ValueError):
        Pipeline(chains=steps, input_variables=["foo"])
|
||||
|
||||
|
||||
def test_pipeline_bad_outputs() -> None:
    """Constructing a pipeline must fail for an unknown output variable."""
    first = FakeChain(input_variables=["foo"], output_variables=["bar"])
    second = FakeChain(input_variables=["bar"], output_variables=["baz"])
    steps = [first, second]
    # Neither chain produces "test", so it cannot be requested as an output.
    with pytest.raises(ValueError):
        Pipeline(chains=steps, input_variables=["foo"], output_variables=["test"])
|
||||
|
||||
|
||||
def test_pipeline_overlapping_inputs() -> None:
    """Constructing a pipeline must fail when an input is also a chain output."""
    first = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
    second = FakeChain(input_variables=["bar"], output_variables=["baz"])
    steps = [first, second]
    # "test" is declared as a pipeline input but is also produced by `first`.
    with pytest.raises(ValueError):
        Pipeline(chains=steps, input_variables=["foo", "test"])
|
@ -0,0 +1,59 @@
|
||||
"""Test functionality around the simple pipeline chain."""
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.simple_pipeline import SimplePipeline
|
||||
|
||||
|
||||
class FakeChain(Chain, BaseModel):
    """Stub chain for tests whose outputs are derived from its inputs."""

    input_variables: List[str]
    output_variables: List[str]

    @property
    def input_keys(self) -> List[str]:
        """Keys this chain reads from its inputs."""
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Keys this chain writes to its outputs."""
        return self.output_variables

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Map each output key to the space-joined input values plus "foo"."""
        return {
            name: " ".join([inputs[key] for key in self.input_variables]) + "foo"
            for name in self.output_variables
        }
|
||||
|
||||
|
||||
def test_pipeline_functionality() -> None:
    """Run a two-step simple pipeline using the default input/output keys."""
    first = FakeChain(input_variables=["foo"], output_variables=["bar"])
    second = FakeChain(input_variables=["bar"], output_variables=["baz"])
    pipeline = SimplePipeline(chains=[first, second])
    result = pipeline({"input": "123"})
    assert result == {"output": "123foofoo", "input": "123"}
|
||||
|
||||
|
||||
def test_multi_input_errors() -> None:
    """Constructing a simple pipeline must fail for a multi-input chain."""
    single_input = FakeChain(input_variables=["foo"], output_variables=["bar"])
    multi_input = FakeChain(
        input_variables=["bar", "foo"], output_variables=["baz"]
    )
    steps = [single_input, multi_input]
    with pytest.raises(ValueError):
        SimplePipeline(chains=steps)
|
||||
|
||||
|
||||
def test_multi_output_errors() -> None:
    """Constructing a simple pipeline must fail for a multi-output chain."""
    multi_output = FakeChain(
        input_variables=["foo"], output_variables=["bar", "grok"]
    )
    single_output = FakeChain(input_variables=["bar"], output_variables=["baz"])
    steps = [multi_output, single_output]
    with pytest.raises(ValueError):
        SimplePipeline(chains=steps)
|
Loading…
Reference in New Issue