|
|
|
@ -5,7 +5,14 @@ import pytest
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
|
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
|
from langchain.chains.sequential import (
|
|
|
|
|
SequentialChain,
|
|
|
|
|
SimpleSequentialChain,
|
|
|
|
|
construct_sequential_llm_chain,
|
|
|
|
|
)
|
|
|
|
|
from langchain.prompts.prompt import PromptTemplate
|
|
|
|
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FakeChain(Chain, BaseModel):
|
|
|
|
@ -138,3 +145,21 @@ def test_multi_output_errors() -> None:
|
|
|
|
|
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
|
SimpleSequentialChain(chains=[chain_1, chain_2])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_construct_sequential_llm_chain() -> None:
|
|
|
|
|
"""Test constructing simple sequential chain."""
|
|
|
|
|
prompt = PromptTemplate(template="what is {foo}?", input_variables=["foo"])
|
|
|
|
|
llm_chain = LLMChain(llm=FakeLLM(), prompt=prompt, output_key="bar")
|
|
|
|
|
add_ons = [("{bar} and what does it do?", ["bar"], "baz")]
|
|
|
|
|
chain = construct_sequential_llm_chain(llm_chain, add_ons)
|
|
|
|
|
|
|
|
|
|
expected_new_prompt = PromptTemplate(
|
|
|
|
|
template="what is {foo}?{bar} and what does it do?",
|
|
|
|
|
input_variables=["foo", "bar"],
|
|
|
|
|
)
|
|
|
|
|
expected_new_chain = LLMChain(
|
|
|
|
|
llm=FakeLLM(), prompt=expected_new_prompt, output_key="baz"
|
|
|
|
|
)
|
|
|
|
|
expected_chains = [llm_chain, expected_new_chain]
|
|
|
|
|
assert chain.chains == expected_chains
|
|
|
|
|