You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/chains/simple_pipeline.py

60 lines
1.8 KiB
Python

"""Simple chain pipeline where the outputs of one step feed directly into next."""
from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
class SimplePipeline(Chain, BaseModel):
    """Simple chain pipeline where the outputs of one step feed directly into the next.

    Every chain in ``chains`` must have exactly one input key and exactly one
    output key (enforced by ``validate_chains``). When called, the value stored
    under ``input_key`` is passed to the first chain's ``run``. Each chain's
    result becomes the sole argument of the next chain's ``run``. The final
    result is returned under ``output_key``.
    """

    chains: List[Chain]
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    @root_validator()
    def validate_chains(cls, values: Dict) -> Dict:
        """Validate that chains are all single input/output.

        Args:
            values: Field values collected by pydantic for this model.

        Returns:
            ``values`` unchanged.

        Raises:
            ValueError: If any chain does not have exactly one input key or
                does not have exactly one output key.
        """
        # A (post) root validator still runs when field validation of
        # ``chains`` has already failed, and in that case the key is missing
        # from ``values``. Indexing it directly would raise a bare KeyError,
        # which pydantic does not turn into a ValidationError. Falling back to
        # an empty list lets pydantic report the original field error.
        for chain in values.get("chains", []):
            if len(chain.input_keys) != 1:
                raise ValueError(
                    "Chains used in SimplePipeline should all have one input, got "
                    f"{chain} with {len(chain.input_keys)} inputs."
                )
            if len(chain.output_keys) != 1:
                raise ValueError(
                    "Chains used in SimplePipeline should all have one output, got "
                    f"{chain} with {len(chain.output_keys)} outputs."
                )
        return values

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Pass the input value through every chain in order.

        Args:
            inputs: Mapping that must contain ``self.input_key``.

        Returns:
            A single-entry dict mapping ``self.output_key`` to the last chain's
            result. If ``chains`` is empty, the input value is returned unchanged.
        """
        _input = inputs[self.input_key]
        for chain in self.chains:
            # Each chain's output is the sole input of the next one.
            _input = chain.run(_input)
        return {self.output_key: _input}