forked from Archives/langchain
cr
This commit is contained in:
parent
3fcc803880
commit
620484f3ea
@ -1,9 +1,11 @@
|
|||||||
"""Chain that generates a list and then maps each output to another chain."""
|
"""Chain that generates a list and then maps each output to another chain."""
|
||||||
|
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
|
||||||
from typing import List, Dict
|
|
||||||
|
|
||||||
|
|
||||||
class MapChain(Chain, BaseModel):
|
class MapChain(Chain, BaseModel):
|
||||||
@ -70,4 +72,3 @@ class MapChain(Chain, BaseModel):
|
|||||||
)
|
)
|
||||||
outputs = {self.map_chain.run(text) for text in new_inputs}
|
outputs = {self.map_chain.run(text) for text in new_inputs}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
"""Simple chain pipeline where the outputs of one step feed directly into next."""
|
"""Chain pipeline where the outputs of one step feed directly into next."""
|
||||||
|
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
|
||||||
from typing import List, Dict
|
|
||||||
|
|
||||||
|
|
||||||
class Pipeline(Chain, BaseModel):
|
class Pipeline(Chain, BaseModel):
|
||||||
@ -36,7 +38,7 @@ class Pipeline(Chain, BaseModel):
|
|||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def validate_chains(cls, values: Dict) -> Dict:
|
def validate_chains(cls, values: Dict) -> Dict:
|
||||||
"""Validate that chains are all single input/output."""
|
"""Validate that the correct inputs exist for all chains."""
|
||||||
chains = values["chains"]
|
chains = values["chains"]
|
||||||
input_variables = values["input_variables"]
|
input_variables = values["input_variables"]
|
||||||
known_variables = set(input_variables)
|
known_variables = set(input_variables)
|
||||||
@ -46,7 +48,9 @@ class Pipeline(Chain, BaseModel):
|
|||||||
raise ValueError(f"Missing required input keys: {missing_vars}")
|
raise ValueError(f"Missing required input keys: {missing_vars}")
|
||||||
overlapping_keys = known_variables.intersection(chain.output_keys)
|
overlapping_keys = known_variables.intersection(chain.output_keys)
|
||||||
if overlapping_keys:
|
if overlapping_keys:
|
||||||
raise ValueError(f"Chain returned keys that already exist: {overlapping_keys}")
|
raise ValueError(
|
||||||
|
f"Chain returned keys that already exist: {overlapping_keys}"
|
||||||
|
)
|
||||||
known_variables |= set(chain.output_keys)
|
known_variables |= set(chain.output_keys)
|
||||||
|
|
||||||
if "output_variables" not in values:
|
if "output_variables" not in values:
|
||||||
@ -54,7 +58,9 @@ class Pipeline(Chain, BaseModel):
|
|||||||
else:
|
else:
|
||||||
missing_vars = known_variables.difference(values["output_variables"])
|
missing_vars = known_variables.difference(values["output_variables"])
|
||||||
if missing_vars:
|
if missing_vars:
|
||||||
raise ValueError(f"Expected output variables that were not found: {missing_vars}.")
|
raise ValueError(
|
||||||
|
f"Expected output variables that were not found: {missing_vars}."
|
||||||
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
@ -63,5 +69,3 @@ class Pipeline(Chain, BaseModel):
|
|||||||
outputs = chain(known_values)
|
outputs = chain(known_values)
|
||||||
known_values.update(outputs)
|
known_values.update(outputs)
|
||||||
return {k: known_values[k] for k in self.output_variables}
|
return {k: known_values[k] for k in self.output_variables}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
"""Simple chain pipeline where the outputs of one step feed directly into next."""
|
"""Simple chain pipeline where the outputs of one step feed directly into next."""
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from typing import Dict, List
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
from typing import List, Dict
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
|
||||||
|
|
||||||
class SimplePipeline(Chain, BaseModel):
|
class SimplePipeline(Chain, BaseModel):
|
||||||
@ -55,7 +57,5 @@ class SimplePipeline(Chain, BaseModel):
|
|||||||
for chain in self.chains:
|
for chain in self.chains:
|
||||||
_input = chain.run(_input)
|
_input = chain.run(_input)
|
||||||
# Clean the input
|
# Clean the input
|
||||||
_input = _input.strip(' \t\n\r')
|
_input = _input.strip()
|
||||||
return {self.output_key: _input}
|
return {self.output_key: _input}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from langchain.chains.pipeline import Pipeline
|
|
||||||
from langchain.chains.base import Chain
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.pipeline import Pipeline
|
||||||
|
|
||||||
|
|
||||||
class FakeChain(Chain, BaseModel):
|
class FakeChain(Chain, BaseModel):
|
||||||
|
|
||||||
@ -30,4 +31,4 @@ def test_pipeline_usage() -> None:
|
|||||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||||
pipeline = Pipeline(chains=[chain_1, chain_2], input_variables=["foo"])
|
pipeline = Pipeline(chains=[chain_1, chain_2], input_variables=["foo"])
|
||||||
output = pipeline({"foo": "123"})
|
output = pipeline({"foo": "123"})
|
||||||
breakpoint()
|
breakpoint()
|
||||||
|
Loading…
Reference in New Issue
Block a user