forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines
1.8 KiB
Python
60 lines
1.8 KiB
Python
"""Simple chain pipeline where the outputs of one step feed directly into next."""
|
|
|
|
from typing import Dict, List
|
|
|
|
from pydantic import BaseModel, Extra, root_validator
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
|
|
class SimplePipeline(Chain, BaseModel):
    """Chain pipeline where the output of each step feeds directly into the next.

    Every sub-chain must take exactly one input and produce exactly one output
    (enforced by ``validate_chains``). At run time the value stored under
    ``input_key`` is passed to the first chain, each chain's result becomes the
    next chain's input, and the final result is returned under ``output_key``.
    With an empty ``chains`` list the input value is returned unchanged.
    """

    # Sub-chains executed in list order; each must be single-input/single-output.
    chains: List[Chain]
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    @root_validator()
    def validate_chains(cls, values: Dict) -> Dict:
        """Validate that chains are all single input/output.

        Raises:
            ValueError: if any chain does not have exactly one input key or
                exactly one output key.
        """
        # If the ``chains`` field itself failed validation, pydantic (v1) still
        # runs this post root validator but leaves the key out of ``values``.
        # Indexing it directly would raise a KeyError that escapes pydantic's
        # error collection and hides the real field error, so fall back to an
        # empty list and let the field's ValidationError surface instead.
        for chain in values.get("chains", []):
            if len(chain.input_keys) != 1:
                raise ValueError(
                    "Chains used in SimplePipeline should all have one input, got "
                    f"{chain} with {len(chain.input_keys)} inputs."
                )
            if len(chain.output_keys) != 1:
                raise ValueError(
                    "Chains used in SimplePipeline should all have one output, got "
                    f"{chain} with {len(chain.output_keys)} outputs."
                )
        return values

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Feed ``inputs[self.input_key]`` through every chain in order.

        Each chain is invoked via ``chain.run`` with the previous chain's
        result; the last result is returned under ``self.output_key``.
        """
        _input = inputs[self.input_key]
        for chain in self.chains:
            _input = chain.run(_input)
        return {self.output_key: _input}
|