harrison/chain_pipeline
Harrison Chase 2 years ago
parent 3fcc803880
commit 620484f3ea

@ -1,9 +1,11 @@
"""Chain that generates a list and then maps each output to another chain."""
from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from pydantic import BaseModel, Extra, root_validator
from typing import List, Dict
class MapChain(Chain, BaseModel):
@ -70,4 +72,3 @@ class MapChain(Chain, BaseModel):
)
outputs = {self.map_chain.run(text) for text in new_inputs}
return outputs

@ -1,8 +1,10 @@
"""Simple chain pipeline where the outputs of one step feed directly into next."""
"""Chain pipeline where the outputs of one step feed directly into next."""
from typing import Dict, List
from langchain.chains.base import Chain
from pydantic import BaseModel, Extra, root_validator
from typing import List, Dict
from langchain.chains.base import Chain
class Pipeline(Chain, BaseModel):
@ -36,7 +38,7 @@ class Pipeline(Chain, BaseModel):
@root_validator(pre=True)
def validate_chains(cls, values: Dict) -> Dict:
"""Validate that chains are all single input/output."""
"""Validate that the correct inputs exist for all chains."""
chains = values["chains"]
input_variables = values["input_variables"]
known_variables = set(input_variables)
@ -46,7 +48,9 @@ class Pipeline(Chain, BaseModel):
raise ValueError(f"Missing required input keys: {missing_vars}")
overlapping_keys = known_variables.intersection(chain.output_keys)
if overlapping_keys:
raise ValueError(f"Chain returned keys that already exist: {overlapping_keys}")
raise ValueError(
f"Chain returned keys that already exist: {overlapping_keys}"
)
known_variables |= set(chain.output_keys)
if "output_variables" not in values:
@ -54,7 +58,9 @@ class Pipeline(Chain, BaseModel):
else:
missing_vars = known_variables.difference(values["output_variables"])
if missing_vars:
raise ValueError(f"Expected output variables that were not found: {missing_vars}.")
raise ValueError(
f"Expected output variables that were not found: {missing_vars}."
)
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
@ -63,5 +69,3 @@ class Pipeline(Chain, BaseModel):
outputs = chain(known_values)
known_values.update(outputs)
return {k: known_values[k] for k in self.output_variables}

@ -1,8 +1,10 @@
"""Simple chain pipeline where the outputs of one step feed directly into next."""
from langchain.chains.base import Chain
from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator
from typing import List, Dict
from langchain.chains.base import Chain
class SimplePipeline(Chain, BaseModel):
@ -55,7 +57,5 @@ class SimplePipeline(Chain, BaseModel):
for chain in self.chains:
_input = chain.run(_input)
# Clean the input
_input = _input.strip(' \t\n\r')
_input = _input.strip()
return {self.output_key: _input}

@ -1,9 +1,10 @@
from typing import Dict, List
from langchain.chains.pipeline import Pipeline
from langchain.chains.base import Chain
from pydantic import BaseModel
from langchain.chains.base import Chain
from langchain.chains.pipeline import Pipeline
class FakeChain(Chain, BaseModel):
@ -30,4 +31,4 @@ def test_pipeline_usage() -> None:
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
pipeline = Pipeline(chains=[chain_1, chain_2], input_variables=["foo"])
output = pipeline({"foo": "123"})
breakpoint()
breakpoint()

Loading…
Cancel
Save