Compare commits

...

1 Commit

Author          SHA1        Message                        Date
Harrison Chase  6dcbb74582  sequential chain from prompts  1 year ago

langchain/chains/sequential.py

@@ -1,10 +1,11 @@
 """Chain pipeline where the outputs of one step feed directly into the next."""
-from typing import Dict, List
+from typing import Dict, List, Tuple

 from pydantic import BaseModel, Extra, root_validator

 from langchain.chains.base import Chain
+from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping, print_text
@@ -135,3 +136,18 @@ class SimpleSequentialChain(Chain, BaseModel):
             if self.verbose:
                 print_text(_input, color=color_mapping[str(i)], end="\n")
         return {self.output_key: _input}
+
+
+def construct_sequential_llm_chain(
+    llm_chain: LLMChain, add_ons: List[Tuple[str, List[str], str]]
+) -> SequentialChain:
+    base_prompt = llm_chain.prompt
+    chains = [llm_chain]
+    for template, input_vars, output_key in add_ons:
+        new_prompt = base_prompt.extend_prompt(template, input_vars)
+        new_llm_chain = LLMChain(
+            llm=llm_chain.llm, prompt=new_prompt, output_key=output_key
+        )
+        chains.append(new_llm_chain)
+    return SequentialChain(chains=chains, input_variables=llm_chain.input_keys)
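
Note: the function above builds a SequentialChain out of one seed LLMChain plus a
list of (template, input_variables, output_key) add-ons. A minimal usage sketch,
assuming any LLM wrapper for `llm` (the OpenAI import and the prompt values are
illustrative, mirroring the unit test at the bottom of this compare):

    from langchain.chains.llm import LLMChain
    from langchain.chains.sequential import construct_sequential_llm_chain
    from langchain.llms.openai import OpenAI  # illustrative choice of LLM
    from langchain.prompts.prompt import PromptTemplate

    llm = OpenAI()
    prompt = PromptTemplate(template="what is {foo}?", input_variables=["foo"])
    llm_chain = LLMChain(llm=llm, prompt=prompt, output_key="bar")
    # Each add-on appends its template to the base prompt and exposes the
    # previous step's output_key ("bar") as an input variable of the new step.
    add_ons = [("{bar} and what does it do?", ["bar"], "baz")]
    chain = construct_sequential_llm_chain(llm_chain, add_ons)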

langchain/prompts/base.py

@@ -1,4 +1,6 @@
 """BasePrompt schema definition."""
+from __future__ import annotations
+
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -62,6 +64,12 @@ class BasePromptTemplate(BaseModel, ABC):
         extra = Extra.forbid
         arbitrary_types_allowed = True

+    @abstractmethod
+    def extend_prompt(
+        self, template: str, input_variables: List[str]
+    ) -> BasePromptTemplate:
+        """Extend the prompt with another template/input variables."""
+
     @root_validator()
     def validate_variable_names(cls, values: Dict) -> Dict:
         """Validate variable names do not include restricted names."""

langchain/prompts/few_shot.py

@@ -1,4 +1,6 @@
 """Prompt template that contains few shot examples."""
+from __future__ import annotations
+
 from typing import Any, Dict, List, Optional

 from pydantic import BaseModel, Extra, root_validator
@@ -41,6 +43,20 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
     template_format: str = "f-string"
     """The format of the prompt template. Options are: 'f-string'."""

+    def extend_prompt(
+        self, template: str, input_variables: List[str]
+    ) -> FewShotPromptTemplate:
+        """Append to template and input variables."""
+        copied_prompt = self.copy(deep=True)
+        copied_prompt.suffix += template
+        copied_prompt.input_variables += input_variables
+        check_valid_template(
+            copied_prompt.prefix + copied_prompt.suffix,
+            copied_prompt.template_format,
+            copied_prompt.input_variables,
+        )
+        return copied_prompt
+
     @root_validator(pre=True)
     def check_examples_and_selector(cls, values: Dict) -> Dict:
         """Check that one and only one of examples/example_selector are provided."""

langchain/prompts/prompt.py

@@ -36,6 +36,20 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
         extra = Extra.forbid

+    def extend_prompt(
+        self, template: str, input_variables: List[str]
+    ) -> PromptTemplate:
+        """Append to template and input variables."""
+        copied_prompt = self.copy(deep=True)
+        copied_prompt.template += template
+        copied_prompt.input_variables += input_variables
+        check_valid_template(
+            copied_prompt.template,
+            copied_prompt.template_format,
+            copied_prompt.input_variables,
+        )
+        return copied_prompt
+
     def format(self, **kwargs: Any) -> str:
         """Format the prompt with the inputs.

tests/unit_tests/chains/test_sequential.py

@@ -5,7 +5,14 @@ import pytest
 from pydantic import BaseModel

 from langchain.chains.base import Chain
-from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
+from langchain.chains.llm import LLMChain
+from langchain.chains.sequential import (
+    SequentialChain,
+    SimpleSequentialChain,
+    construct_sequential_llm_chain,
+)
+from langchain.prompts.prompt import PromptTemplate
 from tests.unit_tests.llms.fake_llm import FakeLLM


 class FakeChain(Chain, BaseModel):
@@ -138,3 +145,21 @@ def test_multi_output_errors() -> None:
     chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
     with pytest.raises(ValueError):
         SimpleSequentialChain(chains=[chain_1, chain_2])
+
+
+def test_construct_sequential_llm_chain() -> None:
+    """Test constructing a sequential chain from prompts."""
+    prompt = PromptTemplate(template="what is {foo}?", input_variables=["foo"])
+    llm_chain = LLMChain(llm=FakeLLM(), prompt=prompt, output_key="bar")
+    add_ons = [("{bar} and what does it do?", ["bar"], "baz")]
+    chain = construct_sequential_llm_chain(llm_chain, add_ons)
+    expected_new_prompt = PromptTemplate(
+        template="what is {foo}?{bar} and what does it do?",
+        input_variables=["foo", "bar"],
+    )
+    expected_new_chain = LLMChain(
+        llm=FakeLLM(), prompt=expected_new_prompt, output_key="baz"
+    )
+    expected_chains = [llm_chain, expected_new_chain]
+    assert chain.chains == expected_chains
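
Note: the final equality assertion presumably passes because LLMChain,
PromptTemplate, and FakeLLM are all pydantic models, so == compares field
values rather than object identity.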
