From 620484f3eac563d8ed0ba295a8df47aa99689267 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 19 Nov 2022 09:36:22 -0800 Subject: [PATCH] cr --- langchain/chains/map/base.py | 7 ++++--- langchain/chains/pipeline.py | 20 ++++++++++++-------- langchain/chains/simple_pipeline.py | 10 +++++----- tests/unit_tests/chains/test_pipeline.py | 7 ++++--- 4 files changed, 25 insertions(+), 19 deletions(-) diff --git a/langchain/chains/map/base.py b/langchain/chains/map/base.py index d2c3beda..234abdde 100644 --- a/langchain/chains/map/base.py +++ b/langchain/chains/map/base.py @@ -1,9 +1,11 @@ """Chain that generates a list and then maps each output to another chain.""" +from typing import Dict, List + +from pydantic import BaseModel, Extra, root_validator + from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from pydantic import BaseModel, Extra, root_validator -from typing import List, Dict class MapChain(Chain, BaseModel): @@ -70,4 +72,3 @@ class MapChain(Chain, BaseModel): ) outputs = {self.map_chain.run(text) for text in new_inputs} return outputs - diff --git a/langchain/chains/pipeline.py b/langchain/chains/pipeline.py index f2a92635..2d0fcc77 100644 --- a/langchain/chains/pipeline.py +++ b/langchain/chains/pipeline.py @@ -1,8 +1,10 @@ -"""Simple chain pipeline where the outputs of one step feed directly into next.""" +"""Chain pipeline where the outputs of one step feed directly into next.""" + +from typing import Dict, List + +from pydantic import BaseModel, Extra, root_validator from langchain.chains.base import Chain -from pydantic import BaseModel, Extra, root_validator -from typing import List, Dict class Pipeline(Chain, BaseModel): @@ -36,7 +38,7 @@ class Pipeline(Chain, BaseModel): @root_validator(pre=True) def validate_chains(cls, values: Dict) -> Dict: - """Validate that chains are all single input/output.""" + """Validate that the correct inputs exist for all chains.""" chains = values["chains"] input_variables = values["input_variables"] known_variables = set(input_variables) @@ -46,7 +48,9 @@ class Pipeline(Chain, BaseModel): raise ValueError(f"Missing required input keys: {missing_vars}") overlapping_keys = known_variables.intersection(chain.output_keys) if overlapping_keys: - raise ValueError(f"Chain returned keys that already exist: {overlapping_keys}") + raise ValueError( + f"Chain returned keys that already exist: {overlapping_keys}" + ) known_variables |= set(chain.output_keys) if "output_variables" not in values: @@ -54,7 +58,9 @@ class Pipeline(Chain, BaseModel): else: missing_vars = known_variables.difference(values["output_variables"]) if missing_vars: - raise ValueError(f"Expected output variables that were not found: {missing_vars}.") + raise ValueError( + f"Expected output variables that were not found: {missing_vars}." + ) return values def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: @@ -63,5 +69,3 @@ class Pipeline(Chain, BaseModel): outputs = chain(known_values) known_values.update(outputs) return {k: known_values[k] for k in self.output_variables} - - diff --git a/langchain/chains/simple_pipeline.py b/langchain/chains/simple_pipeline.py index a36b1933..0ff34e7b 100644 --- a/langchain/chains/simple_pipeline.py +++ b/langchain/chains/simple_pipeline.py @@ -1,8 +1,10 @@ """Simple chain pipeline where the outputs of one step feed directly into next.""" -from langchain.chains.base import Chain +from typing import Dict, List + from pydantic import BaseModel, Extra, root_validator -from typing import List, Dict + +from langchain.chains.base import Chain class SimplePipeline(Chain, BaseModel): @@ -55,7 +57,5 @@ class SimplePipeline(Chain, BaseModel): for chain in self.chains: _input = chain.run(_input) # Clean the input - _input = _input.strip(' \t\n\r') + _input = _input.strip() return {self.output_key: _input} - - diff --git a/tests/unit_tests/chains/test_pipeline.py b/tests/unit_tests/chains/test_pipeline.py index b4c27674..70b10d6e 100644 --- a/tests/unit_tests/chains/test_pipeline.py +++ b/tests/unit_tests/chains/test_pipeline.py @@ -1,9 +1,10 @@ from typing import Dict, List -from langchain.chains.pipeline import Pipeline -from langchain.chains.base import Chain from pydantic import BaseModel +from langchain.chains.base import Chain +from langchain.chains.pipeline import Pipeline + class FakeChain(Chain, BaseModel): @@ -30,4 +31,4 @@ def test_pipeline_usage() -> None: chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) pipeline = Pipeline(chains=[chain_1, chain_2], input_variables=["foo"]) output = pipeline({"foo": "123"}) - breakpoint() \ No newline at end of file + breakpoint()