diff --git a/langchain/chains/sequential.py b/langchain/chains/sequential.py index 1bcc723d..338ca319 100644 --- a/langchain/chains/sequential.py +++ b/langchain/chains/sequential.py @@ -1,10 +1,11 @@ """Chain pipeline where the outputs of one step feed directly into next.""" -from typing import Dict, List +from typing import Dict, List, Tuple from pydantic import BaseModel, Extra, root_validator from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain from langchain.input import get_color_mapping, print_text @@ -135,3 +136,18 @@ class SimpleSequentialChain(Chain, BaseModel): if self.verbose: print_text(_input, color=color_mapping[str(i)], end="\n") return {self.output_key: _input} + + +def construct_sequential_llm_chain( + llm_chain: LLMChain, add_ons: List[Tuple[str, List[str], str]] +) -> SequentialChain: + base_prompt = llm_chain.prompt + chains = [llm_chain] + for template, input_vars, output_key in add_ons: + new_prompt = base_prompt.extend_prompt(template, input_vars) + new_llm_chain = LLMChain( + llm=llm_chain.llm, prompt=new_prompt, output_key=output_key + ) + chains.append(new_llm_chain) + + return SequentialChain(chains=chains, input_variables=llm_chain.input_keys) diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index b5ba37eb..2582557c 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -1,4 +1,6 @@ """BasePrompt schema definition.""" +from __future__ import annotations + import json from abc import ABC, abstractmethod from pathlib import Path @@ -62,6 +64,12 @@ class BasePromptTemplate(BaseModel, ABC): extra = Extra.forbid arbitrary_types_allowed = True + @abstractmethod + def extend_prompt( + self, template: str, input_variables: List[str] + ) -> BasePromptTemplate: + """Extend the prompt with another template/input variables.""" + @root_validator() def validate_variable_names(cls, values: Dict) -> Dict: """Validate variable names do not restricted names.""" diff --git a/langchain/prompts/few_shot.py b/langchain/prompts/few_shot.py index 98dffd65..6219ea6e 100644 --- a/langchain/prompts/few_shot.py +++ b/langchain/prompts/few_shot.py @@ -1,4 +1,6 @@ """Prompt template that contains few shot examples.""" +from __future__ import annotations + from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator @@ -41,6 +43,20 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel): template_format: str = "f-string" """The format of the prompt template. Options are: 'f-string'.""" + def extend_prompt( + self, template: str, input_variables: List[str] + ) -> FewShotPromptTemplate: + """Append to template and input variables.""" + copied_prompt = self.copy(deep=True) + copied_prompt.suffix += template + copied_prompt.input_variables += input_variables + check_valid_template( + copied_prompt.prefix + copied_prompt.suffix, + copied_prompt.template_format, + copied_prompt.input_variables, + ) + return copied_prompt + @root_validator(pre=True) def check_examples_and_selector(cls, values: Dict) -> Dict: """Check that one and only one of examples/example_selector are provided.""" diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index 4a24c2de..e4b0b918 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -36,6 +36,20 @@ class PromptTemplate(BasePromptTemplate, BaseModel): extra = Extra.forbid + def extend_prompt( + self, template: str, input_variables: List[str] + ) -> PromptTemplate: + """Append to template and input variables.""" + copied_prompt = self.copy(deep=True) + copied_prompt.template += template + copied_prompt.input_variables += input_variables + check_valid_template( + copied_prompt.template, + copied_prompt.template_format, + copied_prompt.input_variables, + ) + return copied_prompt + def format(self, **kwargs: Any) -> str: """Format the prompt with the inputs. diff --git a/tests/unit_tests/chains/test_sequential.py b/tests/unit_tests/chains/test_sequential.py index f231a740..f2a6cd02 100644 --- a/tests/unit_tests/chains/test_sequential.py +++ b/tests/unit_tests/chains/test_sequential.py @@ -5,7 +5,14 @@ import pytest from pydantic import BaseModel from langchain.chains.base import Chain -from langchain.chains.sequential import SequentialChain, SimpleSequentialChain +from langchain.chains.llm import LLMChain +from langchain.chains.sequential import ( + SequentialChain, + SimpleSequentialChain, + construct_sequential_llm_chain, +) +from langchain.prompts.prompt import PromptTemplate +from tests.unit_tests.llms.fake_llm import FakeLLM class FakeChain(Chain, BaseModel): @@ -138,3 +145,21 @@ def test_multi_output_errors() -> None: chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) with pytest.raises(ValueError): SimpleSequentialChain(chains=[chain_1, chain_2]) + + +def test_construct_sequential_llm_chain() -> None: + """Test constructing simple sequential chain.""" + prompt = PromptTemplate(template="what is {foo}?", input_variables=["foo"]) + llm_chain = LLMChain(llm=FakeLLM(), prompt=prompt, output_key="bar") + add_ons = [("{bar} and what does it do?", ["bar"], "baz")] + chain = construct_sequential_llm_chain(llm_chain, add_ons) + + expected_new_prompt = PromptTemplate( + template="what is {foo}?{bar} and what does it do?", + input_variables=["foo", "bar"], + ) + expected_new_chain = LLMChain( + llm=FakeLLM(), prompt=expected_new_prompt, output_key="baz" + ) + expected_chains = [llm_chain, expected_new_chain] + assert chain.chains == expected_chains