Compare commits

...

1 Commits

Author SHA1 Message Date
Harrison Chase 6dcbb74582 sequential chain from prompts 1 year ago

@ -1,10 +1,11 @@
"""Chain pipeline where the outputs of one step feed directly into next.""" """Chain pipeline where the outputs of one step feed directly into next."""
from typing import Dict, List from typing import Dict, List, Tuple
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping, print_text from langchain.input import get_color_mapping, print_text
@ -135,3 +136,18 @@ class SimpleSequentialChain(Chain, BaseModel):
if self.verbose: if self.verbose:
print_text(_input, color=color_mapping[str(i)], end="\n") print_text(_input, color=color_mapping[str(i)], end="\n")
return {self.output_key: _input} return {self.output_key: _input}
def construct_sequential_llm_chain(
llm_chain: LLMChain, add_ons: List[Tuple[str, List[str], str]]
) -> SequentialChain:
base_prompt = llm_chain.prompt
chains = [llm_chain]
for template, input_vars, output_key in add_ons:
new_prompt = base_prompt.extend_prompt(template, input_vars)
new_llm_chain = LLMChain(
llm=llm_chain.llm, prompt=new_prompt, output_key=output_key
)
chains.append(new_llm_chain)
return SequentialChain(chains=chains, input_variables=llm_chain.input_keys)

@ -1,4 +1,6 @@
"""BasePrompt schema definition.""" """BasePrompt schema definition."""
from __future__ import annotations
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
@ -62,6 +64,12 @@ class BasePromptTemplate(BaseModel, ABC):
extra = Extra.forbid extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
@abstractmethod
def extend_prompt(
self, template: str, input_variables: List[str]
) -> BasePromptTemplate:
"""Extend the prompt with another template/input variables."""
@root_validator() @root_validator()
def validate_variable_names(cls, values: Dict) -> Dict: def validate_variable_names(cls, values: Dict) -> Dict:
"""Validate variable names do not restricted names.""" """Validate variable names do not restricted names."""

@ -1,4 +1,6 @@
"""Prompt template that contains few shot examples.""" """Prompt template that contains few shot examples."""
from __future__ import annotations
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
@ -41,6 +43,20 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
template_format: str = "f-string" template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string'.""" """The format of the prompt template. Options are: 'f-string'."""
def extend_prompt(
self, template: str, input_variables: List[str]
) -> FewShotPromptTemplate:
"""Append to template and input variables."""
copied_prompt = self.copy(deep=True)
copied_prompt.suffix += template
copied_prompt.input_variables += input_variables
check_valid_template(
copied_prompt.prefix + copied_prompt.suffix,
copied_prompt.template_format,
copied_prompt.input_variables,
)
return copied_prompt
@root_validator(pre=True) @root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict: def check_examples_and_selector(cls, values: Dict) -> Dict:
"""Check that one and only one of examples/example_selector are provided.""" """Check that one and only one of examples/example_selector are provided."""

@ -36,6 +36,20 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
extra = Extra.forbid extra = Extra.forbid
def extend_prompt(
self, template: str, input_variables: List[str]
) -> PromptTemplate:
"""Append to template and input variables."""
copied_prompt = self.copy(deep=True)
copied_prompt.template += template
copied_prompt.input_variables += input_variables
check_valid_template(
copied_prompt.template,
copied_prompt.template_format,
copied_prompt.input_variables,
)
return copied_prompt
def format(self, **kwargs: Any) -> str: def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs. """Format the prompt with the inputs.

@ -5,7 +5,14 @@ import pytest
from pydantic import BaseModel from pydantic import BaseModel
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain from langchain.chains.llm import LLMChain
from langchain.chains.sequential import (
SequentialChain,
SimpleSequentialChain,
construct_sequential_llm_chain,
)
from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM
class FakeChain(Chain, BaseModel): class FakeChain(Chain, BaseModel):
@ -138,3 +145,21 @@ def test_multi_output_errors() -> None:
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
with pytest.raises(ValueError): with pytest.raises(ValueError):
SimpleSequentialChain(chains=[chain_1, chain_2]) SimpleSequentialChain(chains=[chain_1, chain_2])
def test_construct_sequential_llm_chain() -> None:
"""Test constructing simple sequential chain."""
prompt = PromptTemplate(template="what is {foo}?", input_variables=["foo"])
llm_chain = LLMChain(llm=FakeLLM(), prompt=prompt, output_key="bar")
add_ons = [("{bar} and what does it do?", ["bar"], "baz")]
chain = construct_sequential_llm_chain(llm_chain, add_ons)
expected_new_prompt = PromptTemplate(
template="what is {foo}?{bar} and what does it do?",
input_variables=["foo", "bar"],
)
expected_new_chain = LLMChain(
llm=FakeLLM(), prompt=expected_new_prompt, output_key="baz"
)
expected_chains = [llm_chain, expected_new_chain]
assert chain.chains == expected_chains

Loading…
Cancel
Save