DynamicPrompt class creation (#49)

Checking that this structure looks generally OK -- I am going to substitute
the real logic where the TODO comment is, and then add a test.
This commit is contained in:
Samantha Whitmore 2022-11-05 12:43:21 -07:00 committed by GitHub
parent 618611f4dd
commit c636488fe5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 219 additions and 2 deletions

View File

@ -17,7 +17,7 @@ from langchain.chains import (
from langchain.docstore import Wikipedia from langchain.docstore import Wikipedia
from langchain.faiss import FAISS from langchain.faiss import FAISS
from langchain.llms import Cohere, HuggingFaceHub, OpenAI from langchain.llms import Cohere, HuggingFaceHub, OpenAI
from langchain.prompt import BasePrompt, Prompt from langchain.prompt import BasePrompt, DynamicPrompt, Prompt
from langchain.sql_database import SQLDatabase from langchain.sql_database import SQLDatabase
__all__ = [ __all__ = [
@ -29,6 +29,7 @@ __all__ = [
"Cohere", "Cohere",
"OpenAI", "OpenAI",
"BasePrompt", "BasePrompt",
"DynamicPrompt",
"Prompt", "Prompt",
"ReActChain", "ReActChain",
"Wikipedia", "Wikipedia",

View File

@ -1,6 +1,7 @@
"""Prompt schema definition.""" """Prompt schema definition."""
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List from typing import Any, Callable, Dict, List
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
@ -126,3 +127,108 @@ class Prompt(BaseModel, BasePrompt):
example_str = example_separator.join(examples) example_str = example_separator.join(examples)
template = prefix + example_str + suffix template = prefix + example_str + suffix
return cls(input_variables=input_variables, template=template) return cls(input_variables=input_variables, template=template)
class DynamicPrompt(BaseModel, BasePrompt):
    r"""Schema to represent a dynamic prompt for an LLM.

    The prompt text is built by joining ``prefix``, the examples and
    ``suffix`` with ``example_separator``. When the formatted result is
    longer than ``max_length`` (as measured by ``get_text_length``),
    examples are dropped from the end of the list, one at a time, until the
    prompt fits or no examples are left.

    Example:
        .. code-block:: python

            from langchain import DynamicPrompt
            dynamic_prompt = DynamicPrompt(
                examples=["Say hi. Hi", "Say ho. Ho"],
                example_separator="\n\n",
                prefix="",
                suffix="\n\nSay {foo}",
                input_variables=["foo"],
                max_length=200,
                get_text_length=word_count
            )
    """

    examples: List[str]
    """A list of the examples that the prompt template expects."""

    example_separator: str = "\n\n"
    """Example separator, e.g. \n\n, for the dynamic prompt creation."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""

    prefix: str
    """Prefix for the prompt."""

    suffix: str
    """Suffix for the prompt."""

    template_format: str = "f-string"
    """The format of the prompt template. Options are: 'f-string'."""

    get_text_length: Callable[[str], int] = lambda x: len(re.split("\n| ", x))
    """Function to measure prompt length. Defaults to word count."""

    max_length: int = 2048
    """Max length for the prompt, beyond which examples are cut."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def template(self, example_list: List[str], **kwargs: Any) -> str:
        """Return the formatted prompt built from the given example list.

        Args:
            example_list: The examples to place between prefix and suffix.
            kwargs: Values for the template variables.

        Returns:
            Prefix, examples and suffix joined by ``example_separator`` and
            passed through the formatter selected by ``template_format``.
        """
        template = self.example_separator.join(
            [self.prefix, *example_list, self.suffix]
        )
        return _FORMATTER_MAPPING[self.template_format](template, **kwargs)

    def format(self, **kwargs: Any) -> str:
        """Dynamically format the prompt with the inputs.

        Examples are removed from the end of the list while the formatted
        prompt is longer than ``max_length``.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            prompt.format(variable1="foo")
        """
        curr_examples = self.examples
        template = self.template(curr_examples, **kwargs)
        # Slicing builds a new list each time, so self.examples is never
        # mutated; the loop ends once the prompt fits or no examples remain.
        while self.get_text_length(template) > self.max_length and curr_examples:
            curr_examples = curr_examples[:-1]
            template = self.template(curr_examples, **kwargs)
        return template

    @root_validator()
    def template_is_valid(cls, values: Dict) -> Dict:
        """Check that prefix, suffix and input variables are consistent.

        Args:
            values: The field values collected for this model.

        Returns:
            The unchanged ``values`` dict.

        Raises:
            ValueError: If ``template_format`` is unknown, if
                ``get_text_length`` cannot be called with a string or does
                not return an int, or if the prefix or suffix references a
                variable that is not listed in ``input_variables``.
        """
        input_variables = values["input_variables"]
        prefix = values["prefix"]
        suffix = values["suffix"]
        template_format = values["template_format"]
        if template_format not in _FORMATTER_MAPPING:
            valid_formats = list(_FORMATTER_MAPPING)
            raise ValueError(
                f"Invalid template format. Got `{template_format}`;"
                f" should be one of {valid_formats}"
            )
        # Probe the length callable with explicit checks instead of `assert`,
        # which is removed when Python runs with -O.
        try:
            result = values["get_text_length"]("foo")
        except TypeError as err:
            raise ValueError(
                "Invalid text length callable, must take string & return int;"
            ) from err
        if not isinstance(result, int):
            raise ValueError(
                "Invalid text length callable, must take string & return int;"
            )
        dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
        # Variables may appear in the prefix as well as the suffix, so both
        # are probed with the dummy inputs.
        try:
            formatter_func = _FORMATTER_MAPPING[template_format]
            formatter_func(prefix + suffix, **dummy_inputs)
        except KeyError as err:
            raise ValueError("Invalid prompt schema.") from err
        return values

View File

@ -0,0 +1,110 @@
"""Test functionality related to dynamic prompts."""
from langchain.prompt import DynamicPrompt, Prompt
# FULL TEMPLATES
# Each expected template equals PREFIX, the kept examples and SUFFIX joined
# by a blank line ("\n\n"), which is the separator the tests pass to
# DynamicPrompt.
LONGER_TEMPLATE = """Test Prompt:

Question: who are you?
Answer: foo

Question: what are you?
Answer: bar

Question: {question}
Answer:"""
SHORTER_TEMPLATE = """Test Prompt:

Question: who are you?
Answer: foo

Question: {question}
Answer:"""
SHORTEST_TEMPLATE = """Test Prompt:

Question: {question}
Answer:"""

# DYNAMIC PROMPT COMPONENTS
PREFIX = """Test Prompt:"""
SUFFIX = """Question: {question}\nAnswer:"""
EXAMPLES = [
    """Question: who are you?\nAnswer: foo""",
    """Question: what are you?\nAnswer: bar""",
]

# INPUTS
TEST_LONG_QUESTION = """I am writing a really long question,
this probably is going to affect the example right?"""
TEST_LONGEST_QUESTION = """This question is super super super,
super super super super super super super super super super super,
super super super super long, this will affect the example right?"""
TEST_SHORT_QUESTION = "Short question?"
def test_dynamic_prompt_valid() -> None:
    """Check a dynamic prompt renders the same text as the static prompt."""
    variables = ["question"]
    separator = "\n\n"
    dynamic = DynamicPrompt(
        prefix=PREFIX,
        examples=EXAMPLES,
        suffix=SUFFIX,
        example_separator=separator,
        input_variables=variables,
    )
    static = Prompt(template=LONGER_TEMPLATE, input_variables=variables)
    # Both prompts receive the same input and must produce the same text.
    rendered_dynamic = dynamic.format(question="foo?")
    rendered_static = static.format(question="foo?")
    assert rendered_dynamic == rendered_static
    assert dynamic.input_variables == static.input_variables
def test_dynamic_prompt_trims_one_example() -> None:
    """Check that a long question makes the prompt drop its last example."""
    settings = {
        "prefix": PREFIX,
        "examples": EXAMPLES,
        "suffix": SUFFIX,
        "example_separator": "\n\n",
        "input_variables": ["question"],
        "max_length": 30,
    }
    prompt = DynamicPrompt(**settings)
    expected = SHORTER_TEMPLATE.format(question=TEST_LONG_QUESTION)
    actual = prompt.format(question=TEST_LONG_QUESTION)
    assert actual == expected
def test_dynamic_prompt_trims_no_examples() -> None:
    """Check that a short question leaves every example in the prompt."""
    settings = {
        "prefix": PREFIX,
        "examples": EXAMPLES,
        "suffix": SUFFIX,
        "example_separator": "\n\n",
        "input_variables": ["question"],
        "max_length": 30,
    }
    prompt = DynamicPrompt(**settings)
    expected = LONGER_TEMPLATE.format(question=TEST_SHORT_QUESTION)
    actual = prompt.format(question=TEST_SHORT_QUESTION)
    assert actual == expected
def test_dynamic_prompt_trims_all_examples() -> None:
    """Check that a very long question makes the prompt drop all examples."""
    settings = {
        "prefix": PREFIX,
        "examples": EXAMPLES,
        "suffix": SUFFIX,
        "example_separator": "\n\n",
        "input_variables": ["question"],
        "max_length": 30,
    }
    prompt = DynamicPrompt(**settings)
    expected = SHORTEST_TEMPLATE.format(question=TEST_LONGEST_QUESTION)
    actual = prompt.format(question=TEST_LONGEST_QUESTION)
    assert actual == expected