forked from Archives/langchain
Harrison/sequential chains (#168)
add support for basic sequential chains
This commit is contained in:
parent
15c19fcc60
commit
4a4dfbfbed
265
docs/examples/demos/sequential_chains.ipynb
Normal file
265
docs/examples/demos/sequential_chains.ipynb
Normal file
@ -0,0 +1,265 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "4f73605d",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Sequential Chains"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "3b235f7a",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"The next step after calling a language model is make a series of calls to a language model. This is particularly useful when you want to take the output from one call and use it as the input to another.\n",
|
||||||
|
"\n",
|
||||||
|
"In this notebook we will walk through some examples for how to do this, using sequential chains. Sequential chains are defined as a series of chains, called in deterministic order. There are two types of sequential chains:\n",
|
||||||
|
"\n",
|
||||||
|
"- `SimpleSequentialChain`: The simplest form of sequential chains, where each step has a singular input/output, and the output of one step is the input to the next.\n",
|
||||||
|
"- `SequentialChain`: A more general form of sequential chains, allowing for multiple inputs/outputs."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "5162794e",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## SimpleSequentialChain\n",
|
||||||
|
"\n",
|
||||||
|
"In this series of chains, each individual chain has a single input and a single output, and the output of one step is used as input to the next.\n",
|
||||||
|
"\n",
|
||||||
|
"Let's walk through a toy example of doing this, where the first chain takes in the title of an imaginary play and then generates a synopsis for that title, and the second chain takes in the synopsis of that play and generates an imaginary review for that play."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "3f2f9b8c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.llms import OpenAI\n",
|
||||||
|
"from langchain.chains import LLMChain\n",
|
||||||
|
"from langchain.prompts import PromptTemplate"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "b8237d1a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# This is an LLMChain to write a synopsis given a title of a play.\n",
|
||||||
|
"llm = OpenAI(temperature=.7)\n",
|
||||||
|
"template = \"\"\"You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n",
|
||||||
|
"\n",
|
||||||
|
"Title: {title}\n",
|
||||||
|
"Playwright: This is a synopsis for the above play:\"\"\"\n",
|
||||||
|
"prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
|
||||||
|
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "4a391730",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# This is an LLMChain to write a review of a play given a synopsis.\n",
|
||||||
|
"llm = OpenAI(temperature=.7)\n",
|
||||||
|
"template = \"\"\"You are a play critic from the New York Times. Given the synopsis of play, it is your job to write a review for that play.\n",
|
||||||
|
"\n",
|
||||||
|
"Play Synopsis:\n",
|
||||||
|
"{synopsis}\n",
|
||||||
|
"Review from a New York Times play critic of the above play:\"\"\"\n",
|
||||||
|
"prompt_template = PromptTemplate(input_variables=[\"synopsis\"], template=template)\n",
|
||||||
|
"review_chain = LLMChain(llm=llm, prompt=prompt_template)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "9368bd63",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# This is the overall chain where we run these two chains in sequence.\n",
|
||||||
|
"from langchain.chains import SimpleSequentialChain\n",
|
||||||
|
"overall_chain = SimpleSequentialChain(chains=[synopsis_chain, review_chain], verbose=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "d39e15f5",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
|
"\u001b[36;1m\u001b[1;3m\n",
|
||||||
|
"\n",
|
||||||
|
"A young couple, John and Mary, are enjoying a day at the beach. As the sun sets, they share a romantic moment. However, their happiness is short-lived, as a tragic accident claims John's life. Mary is left devastated by the loss of her husband.\u001b[0m\n",
|
||||||
|
"\u001b[33;1m\u001b[1;3m\n",
|
||||||
|
"\n",
|
||||||
|
"\"A young couple's happiness is cut short by tragedy in this moving play. Mary is left devastated by the loss of her husband, John, in a freak accident. The play captures the pain and grief of loss, as well as the strength of love. A must-see for fans of theater.\"\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"review = overall_chain.run(\"Tragedy at sunset on the beach\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "c6649a01",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\"A young couple's happiness is cut short by tragedy in this moving play. Mary is left devastated by the loss of her husband, John, in a freak accident. The play captures the pain and grief of loss, as well as the strength of love. A must-see for fans of theater.\"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(review)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "c3f1549a",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Sequential Chain\n",
|
||||||
|
"Of course, not all sequential chains will be as simple as passing a single string as an argument and getting a single string as output for all steps in the chain. In this next example, we will experiment with more complex chains that involve multiple inputs, and where there also multiple final outputs. \n",
|
||||||
|
"\n",
|
||||||
|
"Of particular importance is how we name the input/output variable names. In the above example we didn't have to think about that because we were just passing the output of one chain directly as input to the next, but here we do have worry about that because we have multiple inputs."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "02016a51",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# This is an LLMChain to write a synopsis given a title of a play and the era it is set in.\n",
|
||||||
|
"llm = OpenAI(temperature=.7)\n",
|
||||||
|
"template = \"\"\"You are a playwright. Given the title of play and the era it is set in, it is your job to write a synopsis for that title.\n",
|
||||||
|
"\n",
|
||||||
|
"Title: {title}\n",
|
||||||
|
"Era: {era}\n",
|
||||||
|
"Playwright: This is a synopsis for the above play:\"\"\"\n",
|
||||||
|
"prompt_template = PromptTemplate(input_variables=[\"title\", 'era'], template=template)\n",
|
||||||
|
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, output_key=\"synopsis\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "8bd38cc2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# This is an LLMChain to write a review of a play given a synopsis.\n",
|
||||||
|
"llm = OpenAI(temperature=.7)\n",
|
||||||
|
"template = \"\"\"You are a play critic from the New York Times. Given the synopsis of play, it is your job to write a review for that play.\n",
|
||||||
|
"\n",
|
||||||
|
"Play Synopsis:\n",
|
||||||
|
"{synopsis}\n",
|
||||||
|
"Review from a New York Times play critic of the above play:\"\"\"\n",
|
||||||
|
"prompt_template = PromptTemplate(input_variables=[\"synopsis\"], template=template)\n",
|
||||||
|
"review_chain = LLMChain(llm=llm, prompt=prompt_template, output_key=\"review\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "524523af",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# This is the overall chain where we run these two chains in sequence.\n",
|
||||||
|
"from langchain.chains import SequentialChain\n",
|
||||||
|
"overall_chain = SequentialChain(\n",
|
||||||
|
" chains=[synopsis_chain, review_chain],\n",
|
||||||
|
" input_variables=[\"era\", \"title\"],\n",
|
||||||
|
" # Here we return multiple variables\n",
|
||||||
|
" output_variables=[\"synopsis\", \"review\"],\n",
|
||||||
|
" verbose=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"id": "3fd3a7be",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
|
"\u001b[1mChain 0\u001b[0m:\n",
|
||||||
|
"{'synopsis': \"\\n\\nThe play is set in Victorian England and follows the tragic story of a young woman who drowns while swimming at sunset on the beach. Her body is found the next morning by a fisherman who raises the alarm. The young woman's family and friends are devastated by her death and the play ends with their mourning her loss.\"}\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1mChain 1\u001b[0m:\n",
|
||||||
|
"{'review': '\\n\\n\"The play is a tragedy, pure and simple. It is the story of a young woman\\'s death, told through the eyes of those who loved her. It is a sad, beautiful play that will stay with you long after you\\'ve seen it. The acting is superb, and the writing is exquisite. If you are looking for a play that will touch your heart and make you think, this is it.\"'}\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"review = overall_chain({\"title\":\"Tragedy at sunset on the beach\", \"era\": \"Victorian England\"})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6be70d27",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -5,6 +5,7 @@ from langchain.chains.mrkl.base import MRKLChain
|
|||||||
from langchain.chains.python import PythonChain
|
from langchain.chains.python import PythonChain
|
||||||
from langchain.chains.react.base import ReActChain
|
from langchain.chains.react.base import ReActChain
|
||||||
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
|
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
|
||||||
|
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
||||||
from langchain.chains.serpapi import SerpAPIChain
|
from langchain.chains.serpapi import SerpAPIChain
|
||||||
from langchain.chains.sql_database.base import SQLDatabaseChain
|
from langchain.chains.sql_database.base import SQLDatabaseChain
|
||||||
from langchain.chains.vector_db_qa.base import VectorDBQA
|
from langchain.chains.vector_db_qa.base import VectorDBQA
|
||||||
@ -19,4 +20,6 @@ __all__ = [
|
|||||||
"SQLDatabaseChain",
|
"SQLDatabaseChain",
|
||||||
"MRKLChain",
|
"MRKLChain",
|
||||||
"VectorDBQA",
|
"VectorDBQA",
|
||||||
|
"SequentialChain",
|
||||||
|
"SimpleSequentialChain",
|
||||||
]
|
]
|
||||||
|
@ -38,8 +38,19 @@ class Chain(BaseModel, ABC):
|
|||||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
"""Run the logic of this chain and return the output."""
|
"""Run the logic of this chain and return the output."""
|
||||||
|
|
||||||
def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def __call__(
|
||||||
"""Run the logic of this chain and add to output."""
|
self, inputs: Dict[str, Any], return_only_outputs: bool = False
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""Run the logic of this chain and add to output if desired.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Dictionary of inputs.
|
||||||
|
return_only_outputs: boolean for whether to return only outputs in the
|
||||||
|
response. If True, only new keys generated by this chain will be
|
||||||
|
returned. If False, both input keys and new keys generated by this
|
||||||
|
chain will be returned. Defaults to False.
|
||||||
|
|
||||||
|
"""
|
||||||
self._validate_inputs(inputs)
|
self._validate_inputs(inputs)
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("\n\n\033[1m> Entering new chain...\033[0m")
|
print("\n\n\033[1m> Entering new chain...\033[0m")
|
||||||
@ -47,7 +58,10 @@ class Chain(BaseModel, ABC):
|
|||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("\n\033[1m> Finished chain.\033[0m")
|
print("\n\033[1m> Finished chain.\033[0m")
|
||||||
self._validate_outputs(outputs)
|
self._validate_outputs(outputs)
|
||||||
return {**inputs, **outputs}
|
if return_only_outputs:
|
||||||
|
return outputs
|
||||||
|
else:
|
||||||
|
return {**inputs, **outputs}
|
||||||
|
|
||||||
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
||||||
"""Call the chain on all inputs in the list."""
|
"""Call the chain on all inputs in the list."""
|
||||||
|
137
langchain/chains/sequential.py
Normal file
137
langchain/chains/sequential.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
"""Chain pipeline where the outputs of one step feed directly into next."""
|
||||||
|
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.input import get_color_mapping, print_text
|
||||||
|
|
||||||
|
|
||||||
|
class SequentialChain(Chain, BaseModel):
|
||||||
|
"""Chain where the outputs of one step feed directly into next."""
|
||||||
|
|
||||||
|
chains: List[Chain]
|
||||||
|
input_variables: List[str]
|
||||||
|
output_variables: List[str] #: :meta private:
|
||||||
|
return_all: bool = False
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Expect input key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return self.input_variables
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Return output key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return self.output_variables
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def validate_chains(cls, values: Dict) -> Dict:
|
||||||
|
"""Validate that the correct inputs exist for all chains."""
|
||||||
|
chains = values["chains"]
|
||||||
|
input_variables = values["input_variables"]
|
||||||
|
known_variables = set(input_variables)
|
||||||
|
for chain in chains:
|
||||||
|
missing_vars = set(chain.input_keys).difference(known_variables)
|
||||||
|
if missing_vars:
|
||||||
|
raise ValueError(f"Missing required input keys: {missing_vars}")
|
||||||
|
overlapping_keys = known_variables.intersection(chain.output_keys)
|
||||||
|
if overlapping_keys:
|
||||||
|
raise ValueError(
|
||||||
|
f"Chain returned keys that already exist: {overlapping_keys}"
|
||||||
|
)
|
||||||
|
known_variables |= set(chain.output_keys)
|
||||||
|
|
||||||
|
if "output_variables" not in values:
|
||||||
|
if values.get("return_all", False):
|
||||||
|
output_keys = known_variables.difference(input_variables)
|
||||||
|
else:
|
||||||
|
output_keys = chains[-1].output_keys
|
||||||
|
values["output_variables"] = output_keys
|
||||||
|
else:
|
||||||
|
missing_vars = set(values["output_variables"]).difference(known_variables)
|
||||||
|
if missing_vars:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected output variables that were not found: {missing_vars}."
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
|
known_values = inputs.copy()
|
||||||
|
for i, chain in enumerate(self.chains):
|
||||||
|
outputs = chain(known_values, return_only_outputs=True)
|
||||||
|
if self.verbose:
|
||||||
|
print(f"\033[1mChain {i}\033[0m:\n{outputs}\n")
|
||||||
|
known_values.update(outputs)
|
||||||
|
return {k: known_values[k] for k in self.output_variables}
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleSequentialChain(Chain, BaseModel):
|
||||||
|
"""Simple chain where the outputs of one step feed directly into next."""
|
||||||
|
|
||||||
|
chains: List[Chain]
|
||||||
|
strip_outputs: bool = False
|
||||||
|
input_key: str = "input" #: :meta private:
|
||||||
|
output_key: str = "output" #: :meta private:
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Expect input key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.input_key]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Return output key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.output_key]
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_chains(cls, values: Dict) -> Dict:
|
||||||
|
"""Validate that chains are all single input/output."""
|
||||||
|
for chain in values["chains"]:
|
||||||
|
if len(chain.input_keys) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Chains used in SimplePipeline should all have one input, got "
|
||||||
|
f"{chain} with {len(chain.input_keys)} inputs."
|
||||||
|
)
|
||||||
|
if len(chain.output_keys) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Chains used in SimplePipeline should all have one output, got "
|
||||||
|
f"{chain} with {len(chain.output_keys)} outputs."
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
|
_input = inputs[self.input_key]
|
||||||
|
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
|
||||||
|
for i, chain in enumerate(self.chains):
|
||||||
|
_input = chain.run(_input)
|
||||||
|
if self.strip_outputs:
|
||||||
|
_input = _input.strip()
|
||||||
|
if self.verbose:
|
||||||
|
print_text(_input, color=color_mapping[str(i)], end="\n")
|
||||||
|
return {self.output_key: _input}
|
140
tests/unit_tests/chains/test_sequential.py
Normal file
140
tests/unit_tests/chains/test_sequential.py
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
"""Test pipeline functionality."""
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
||||||
|
|
||||||
|
|
||||||
|
class FakeChain(Chain, BaseModel):
|
||||||
|
"""Fake Chain for testing purposes."""
|
||||||
|
|
||||||
|
input_variables: List[str]
|
||||||
|
output_variables: List[str]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Input keys this chain returns."""
|
||||||
|
return self.input_variables
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Input keys this chain returns."""
|
||||||
|
return self.output_variables
|
||||||
|
|
||||||
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
|
outputs = {}
|
||||||
|
for var in self.output_variables:
|
||||||
|
variables = [inputs[k] for k in self.input_variables]
|
||||||
|
outputs[var] = " ".join(variables) + "foo"
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_usage_single_inputs() -> None:
|
||||||
|
"""Test sequential on single input chains."""
|
||||||
|
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||||
|
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||||
|
chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"])
|
||||||
|
output = chain({"foo": "123"})
|
||||||
|
expected_output = {"baz": "123foofoo", "foo": "123"}
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_usage_multiple_inputs() -> None:
|
||||||
|
"""Test sequential on multiple input chains."""
|
||||||
|
chain_1 = FakeChain(input_variables=["foo", "test"], output_variables=["bar"])
|
||||||
|
chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
|
||||||
|
chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo", "test"])
|
||||||
|
output = chain({"foo": "123", "test": "456"})
|
||||||
|
expected_output = {
|
||||||
|
"baz": "123 456foo 123foo",
|
||||||
|
"foo": "123",
|
||||||
|
"test": "456",
|
||||||
|
}
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_usage_multiple_outputs() -> None:
|
||||||
|
"""Test sequential usage on multiple output chains."""
|
||||||
|
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
|
||||||
|
chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
|
||||||
|
chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"])
|
||||||
|
output = chain({"foo": "123"})
|
||||||
|
expected_output = {
|
||||||
|
"baz": "123foo 123foo",
|
||||||
|
"foo": "123",
|
||||||
|
}
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_missing_inputs() -> None:
|
||||||
|
"""Test error is raised when input variables are missing."""
|
||||||
|
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||||
|
chain_2 = FakeChain(input_variables=["bar", "test"], output_variables=["baz"])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# Also needs "test" as an input
|
||||||
|
SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_bad_outputs() -> None:
|
||||||
|
"""Test error is raised when bad outputs are specified."""
|
||||||
|
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||||
|
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# "test" is not present as an output variable.
|
||||||
|
SequentialChain(
|
||||||
|
chains=[chain_1, chain_2],
|
||||||
|
input_variables=["foo"],
|
||||||
|
output_variables=["test"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_valid_outputs() -> None:
|
||||||
|
"""Test chain runs when valid outputs are specified."""
|
||||||
|
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||||
|
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||||
|
chain = SequentialChain(
|
||||||
|
chains=[chain_1, chain_2],
|
||||||
|
input_variables=["foo"],
|
||||||
|
output_variables=["bar", "baz"],
|
||||||
|
)
|
||||||
|
output = chain({"foo": "123"}, return_only_outputs=True)
|
||||||
|
expected_output = {"baz": "123foofoo", "bar": "123foo"}
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_overlapping_inputs() -> None:
|
||||||
|
"""Test error is raised when input variables are overlapping."""
|
||||||
|
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
|
||||||
|
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# "test" is specified as an input, but also is an output of one step
|
||||||
|
SequentialChain(chains=[chain_1, chain_2], input_variables=["foo", "test"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_sequential_functionality() -> None:
|
||||||
|
"""Test simple sequential functionality."""
|
||||||
|
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||||
|
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||||
|
chain = SimpleSequentialChain(chains=[chain_1, chain_2])
|
||||||
|
output = chain({"input": "123"})
|
||||||
|
expected_output = {"output": "123foofoo", "input": "123"}
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_input_errors() -> None:
|
||||||
|
"""Test simple sequential errors if multiple input variables are expected."""
|
||||||
|
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||||
|
chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
SimpleSequentialChain(chains=[chain_1, chain_2])
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_output_errors() -> None:
|
||||||
|
"""Test simple sequential errors if multiple output variables are expected."""
|
||||||
|
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "grok"])
|
||||||
|
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
SimpleSequentialChain(chains=[chain_1, chain_2])
|
Loading…
Reference in New Issue
Block a user