This commit is contained in:
Harrison Chase 2022-11-19 09:36:22 -08:00
parent 3fcc803880
commit 620484f3ea
4 changed files with 25 additions and 19 deletions

View File

@ -1,9 +1,11 @@
"""Chain that generates a list and then maps each output to another chain."""
from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from pydantic import BaseModel, Extra, root_validator
from typing import List, Dict
class MapChain(Chain, BaseModel):
@ -70,4 +72,3 @@ class MapChain(Chain, BaseModel):
)
outputs = {self.map_chain.run(text) for text in new_inputs}
return outputs

View File

@ -1,8 +1,10 @@
"""Simple chain pipeline where the outputs of one step feed directly into next."""
"""Chain pipeline where the outputs of one step feed directly into next."""
from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
from pydantic import BaseModel, Extra, root_validator
from typing import List, Dict
class Pipeline(Chain, BaseModel):
@ -36,7 +38,7 @@ class Pipeline(Chain, BaseModel):
@root_validator(pre=True)
def validate_chains(cls, values: Dict) -> Dict:
"""Validate that chains are all single input/output."""
"""Validate that the correct inputs exist for all chains."""
chains = values["chains"]
input_variables = values["input_variables"]
known_variables = set(input_variables)
@ -46,7 +48,9 @@ class Pipeline(Chain, BaseModel):
raise ValueError(f"Missing required input keys: {missing_vars}")
overlapping_keys = known_variables.intersection(chain.output_keys)
if overlapping_keys:
raise ValueError(f"Chain returned keys that already exist: {overlapping_keys}")
raise ValueError(
f"Chain returned keys that already exist: {overlapping_keys}"
)
known_variables |= set(chain.output_keys)
if "output_variables" not in values:
@ -54,7 +58,9 @@ class Pipeline(Chain, BaseModel):
else:
missing_vars = known_variables.difference(values["output_variables"])
if missing_vars:
raise ValueError(f"Expected output variables that were not found: {missing_vars}.")
raise ValueError(
f"Expected output variables that were not found: {missing_vars}."
)
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
@ -63,5 +69,3 @@ class Pipeline(Chain, BaseModel):
outputs = chain(known_values)
known_values.update(outputs)
return {k: known_values[k] for k in self.output_variables}

View File

@ -1,8 +1,10 @@
"""Simple chain pipeline where the outputs of one step feed directly into next."""
from langchain.chains.base import Chain
from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator
from typing import List, Dict
from langchain.chains.base import Chain
class SimplePipeline(Chain, BaseModel):
@ -55,7 +57,5 @@ class SimplePipeline(Chain, BaseModel):
for chain in self.chains:
_input = chain.run(_input)
# Clean the input
_input = _input.strip(' \t\n\r')
_input = _input.strip()
return {self.output_key: _input}

View File

@ -1,9 +1,10 @@
from typing import Dict, List
from langchain.chains.pipeline import Pipeline
from langchain.chains.base import Chain
from pydantic import BaseModel
from langchain.chains.base import Chain
from langchain.chains.pipeline import Pipeline
class FakeChain(Chain, BaseModel):