class DynamicPrompt(BaseModel, BasePrompt):
    r"""Schema to represent a dynamic prompt for an LLM.

    The prompt text is ``prefix``, the examples and ``suffix`` joined by
    ``example_separator``. When the formatted text measures longer than
    ``max_length`` (as reported by ``get_text_length``), examples are dropped
    from the end of the list, one at a time, until the text fits or no
    examples remain.

    Example:
        .. code-block:: python

            from langchain import DynamicPrompt
            dynamic_prompt = DynamicPrompt(
                examples=["Say hi. Hi", "Say ho. Ho"],
                example_separator="\n\n",
                prefix="",
                suffix="\n\nSay {foo}",
                input_variables=["foo"],
                max_length=200,
                get_text_length=len,
            )
    """

    examples: List[str]
    """A list of the examples that the prompt template expects."""

    example_separator: str = "\n\n"
    r"""Example separator, e.g. \n\n, for the dynamic prompt creation."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""

    prefix: str
    """Prefix for the prompt."""

    suffix: str
    """Suffix for the prompt."""

    template_format: str = "f-string"
    """The format of the prompt template. Options are: 'f-string'."""

    get_text_length: Callable[[str], int] = lambda x: len(re.split("\n| ", x))
    """Function to measure prompt length.

    Defaults to the number of pieces obtained by splitting on newlines and
    single spaces (an approximate word count).
    """

    max_length: int = 2048
    """Max length for the prompt, beyond which examples are cut."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def template(self, example_list: List[str], **kwargs: Any) -> str:
        """Return the formatted prompt built from the given examples.

        Args:
            example_list: Examples to place between the prefix and the suffix.
            kwargs: Values passed on to the formatter selected by
                ``template_format``.

        Returns:
            The prefix, examples and suffix joined by ``example_separator``
            and run through the formatter.
        """
        template = self.example_separator.join(
            [self.prefix, *example_list, self.suffix]
        )
        return _FORMATTER_MAPPING[self.template_format](template, **kwargs)

    def format(self, **kwargs: Any) -> str:
        """Dynamically format the prompt with the inputs.

        Starts from the full example list and removes the last example until
        the formatted prompt is no longer than ``max_length`` or no examples
        are left.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            prompt.format(variable1="foo")
        """
        curr_examples = self.examples
        rendered = self.template(curr_examples, **kwargs)
        # Length is measured on the formatted text, so long inputs can push
        # examples out even when the bare template would fit.
        while self.get_text_length(rendered) > self.max_length and curr_examples:
            # Slicing builds a new list, so ``self.examples`` is never mutated.
            curr_examples = curr_examples[:-1]
            rendered = self.template(curr_examples, **kwargs)
        return rendered

    @root_validator()
    def template_is_valid(cls, values: Dict) -> Dict:
        """Check that prefix, suffix and input variables are consistent.

        Args:
            values: Field values collected so far for the model.

        Returns:
            ``values`` unchanged when every check passes.

        Raises:
            ValueError: If ``template_format`` is unknown, if
                ``get_text_length`` cannot be called with a string or does
                not return an int, or if ``prefix``/``suffix`` reference a
                variable missing from ``input_variables``.
        """
        needed = (
            "input_variables",
            "prefix",
            "suffix",
            "template_format",
            "get_text_length",
        )
        # A field that failed its own validation may be absent from `values`;
        # return early so that error is reported instead of a KeyError here.
        if any(key not in values for key in needed):
            return values

        template_format = values["template_format"]
        if template_format not in _FORMATTER_MAPPING:
            valid_formats = list(_FORMATTER_MAPPING)
            raise ValueError(
                f"Invalid template format. Got `{template_format}`;"
                f" should be one of {valid_formats}"
            )

        length_error = "Invalid text length callable, must take string & return int;"
        try:
            result = values["get_text_length"]("foo")
        except TypeError as err:
            raise ValueError(length_error) from err
        # Explicit check rather than `assert`, which is removed under `-O`.
        if not isinstance(result, int):
            raise ValueError(length_error)

        dummy_inputs = {name: "foo" for name in values["input_variables"]}
        formatter_func = _FORMATTER_MAPPING[template_format]
        try:
            # Variables may appear in either the prefix or the suffix.
            for piece in (values["prefix"], values["suffix"]):
                formatter_func(piece, **dummy_inputs)
        except KeyError as err:
            raise ValueError("Invalid prompt schema.") from err
        return values
def _build_prompt(**overrides: int) -> DynamicPrompt:
    """Build a DynamicPrompt from the shared fixtures, plus any overrides."""
    return DynamicPrompt(
        prefix=PREFIX,
        examples=EXAMPLES,
        suffix=SUFFIX,
        example_separator="\n\n",
        input_variables=["question"],
        **overrides,
    )


def test_dynamic_prompt_valid() -> None:
    """A dynamic prompt built from examples matches the static prompt."""
    dynamic = _build_prompt()
    static = Prompt(input_variables=["question"], template=LONGER_TEMPLATE)
    dynamic_text = dynamic.format(question="foo?")
    static_text = static.format(question="foo?")
    assert dynamic_text == static_text
    assert dynamic.input_variables == static.input_variables


def test_dynamic_prompt_trims_one_example() -> None:
    """A long question makes the dynamic prompt drop one example."""
    actual = _build_prompt(max_length=30).format(question=TEST_LONG_QUESTION)
    expected = SHORTER_TEMPLATE.format(question=TEST_LONG_QUESTION)
    assert actual == expected


def test_dynamic_prompt_trims_no_examples() -> None:
    """A short question leaves every example in place."""
    actual = _build_prompt(max_length=30).format(question=TEST_SHORT_QUESTION)
    expected = LONGER_TEMPLATE.format(question=TEST_SHORT_QUESTION)
    assert actual == expected


def test_dynamic_prompt_trims_all_examples() -> None:
    """The longest question makes the dynamic prompt drop all examples."""
    actual = _build_prompt(max_length=30).format(question=TEST_LONGEST_QUESTION)
    expected = SHORTEST_TEMPLATE.format(question=TEST_LONGEST_QUESTION)
    assert actual == expected