From a5b61d59e193c4f9804c9dece7bd1a1c5c1103e0 Mon Sep 17 00:00:00 2001
From: Samantha Whitmore
Date: Sun, 6 Nov 2022 15:40:33 -0800
Subject: [PATCH] Refactor prompts into module, add example generation utils
 (#64)

---
 README.md                                 |   4 +-
 examples/generate_examples.ipynb          | 121 +++++++++
 langchain/__init__.py                     |   2 +-
 langchain/chains/llm.py                   |   2 +-
 langchain/chains/llm_math/prompt.py       |   2 +-
 langchain/chains/mapreduce.py             |   2 +-
 langchain/chains/natbot/prompt.py         |  14 +-
 langchain/chains/react/prompt.py          |   2 +-
 .../chains/self_ask_with_search/prompt.py |   2 +-
 langchain/chains/sql_database/prompt.py   |   4 +-
 langchain/example_generator.py            |  20 ++
 langchain/prompt.py                       | 234 ------------------
 langchain/prompts/__init__.py             |   6 +
 langchain/prompts/base.py                 |  33 +++
 langchain/prompts/dynamic.py              | 112 +++++++++
 langchain/prompts/prompt.py               |  99 ++++++++
 tests/unit_tests/chains/test_llm.py       |   2 +-
 tests/unit_tests/chains/test_react.py     |   2 +-
 tests/unit_tests/test_dynamic_prompt.py   |   3 +-
 tests/unit_tests/test_prompt.py           |   2 +-
 20 files changed, 414 insertions(+), 254 deletions(-)
 create mode 100644 examples/generate_examples.ipynb
 create mode 100644 langchain/example_generator.py
 delete mode 100644 langchain/prompt.py
 create mode 100644 langchain/prompts/__init__.py
 create mode 100644 langchain/prompts/base.py
 create mode 100644 langchain/prompts/dynamic.py
 create mode 100644 langchain/prompts/prompt.py

diff --git a/README.md b/README.md
index 3ae4cd78..1e3e34f8 100644
--- a/README.md
+++ b/README.md
@@ -107,7 +107,7 @@ but full API docs can be found [here](https://langchain.readthedocs.io/en/latest
 ## 🤖 Developer Guide
 
 To begin developing on this project, first clone the repo locally.
-To install requirements, run `pip install -r requirments.txt`.
+To install requirements, run `pip install -r requirements.txt`.
 This will install all requirements for running the package, examples, linting, formatting, and tests.
 
 Formatting for this project is a combination of [Black](https://black.readthedocs.io/en/stable/) and [isort](https://pycqa.github.io/isort/).
@@ -125,6 +125,8 @@ Integration tests cover logic that requires making calls to outside APIs (often
 To run integration tests, run `make integration_tests`.
 If you add support for a new external API, please add a new integration test.
 
+If you are adding a Jupyter notebook example, you can run `pip install -e .` to install the langchain package in editable mode from your local changes, so your new logic can be imported into the notebook.
+
 Docs are largely autogenerated by [sphinx](https://www.sphinx-doc.org/en/master/) from the code.
 For that reason, we ask that you add good documentation to all classes and methods.
 Similar to linting, we recognize documentation can be annoying - if you do not want to do it, please contact a project maintainer and they can help you with it.
 We do not want this to be a blocker for good code getting contributed.
diff --git a/examples/generate_examples.ipynb b/examples/generate_examples.ipynb new file mode 100644 index 00000000..2c334ba0 --- /dev/null +++ b/examples/generate_examples.ipynb @@ -0,0 +1,121 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "1685fa2f", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chains.react.prompt import EXAMPLES\n", + "from langchain.llms.openai import OpenAI\n", + "from langchain.example_generator import generate_example, generate_example_from_dynamic_prompt" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "334ef4f7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Question: What is the elevation range for the area that the eastern sector of the\\nColorado orogeny extends into?\\nThought 1: I need to search Colorado orogeny, find the area that the eastern sector\\nof the Colorado orogeny extends into, then find the elevation range of the\\narea.\\nAction 1: Search[Colorado orogeny]\\nObservation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in\\nColorado and surrounding areas.\\nThought 2: It does not mention the eastern sector. So I need to look up eastern\\nsector.\\nAction 2: Lookup[eastern sector]\\nObservation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called\\nthe Central Plains orogeny.\\nThought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I\\nneed to search High Plains and find its elevation range.\\nAction 3: Search[High Plains]\\nObservation 3: High Plains refers to one of two distinct land regions\\nThought 4: I need to instead search High Plains (United States).\\nAction 4: Search[High Plains (United States)]\\nObservation 4: The High Plains are a subregion of the Great Plains. From east to west, the\\nHigh Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130\\nm).[3]\\nThought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer\\nis 1,800 to 7,000 ft.\\nAction 5: Finish[1,800 to 7,000 ft]'" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# print initial example for visibility\n", + "EXAMPLES[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a7bd36bc", + "metadata": {}, + "outputs": [], + "source": [ + "new_example = generate_example(EXAMPLES, OpenAI())" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e1efb008", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['',\n", + " '',\n", + " 'Question: Is the Mount Everest taller than the Mount Kilimanjaro?',\n", + " '',\n", + " 'Thought 1: I need to search Mount Everest and Mount Kilimanjaro, find their',\n", + " 'heights, then compare them.',\n", + " '',\n", + " 'Action 1: Search[Mount Everest]',\n", + " '',\n", + " \"Observation 1: Mount Everest, at 8,848 metres (29,029 ft), is the world's highest mountain\",\n", + " 'and a particularly popular goal for mountaineers.',\n", + " '',\n", + " 'Thought 2: Mount Everest is 8,848 metres tall. I need to search Mount Kilimanjaro',\n", + " 'next.',\n", + " '',\n", + " 'Action 2: Search[Mount Kilimanjaro]',\n", + " '',\n", + " 'Observation 2: Mount Kilimanjaro, with its three volcanic cones, Kibo, Mawenzi, and',\n", + " 'Shira, is a freestanding mountain in Tanzania. 
It is the highest mountain in',\n", + " 'Africa, and rises approximately 4,900 metres (16,100 ft) from its base to 5,895',\n", + " 'metres (19,341 ft) above sea level.',\n", + " '',\n", + " 'Thought 3: Mount Kilimanjaro is 5,895 metres tall. 8,848 metres (Mount Everest) >',\n", + " '5,895 metres (Mount Kil']" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_example.split('\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8843d7b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/__init__.py b/langchain/__init__.py index 66e4fe03..84768170 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -18,7 +18,7 @@ from langchain.chains import ( from langchain.docstore import Wikipedia from langchain.faiss import FAISS from langchain.llms import Cohere, HuggingFaceHub, OpenAI -from langchain.prompt import BasePrompt, DynamicPrompt, Prompt +from langchain.prompts import BasePrompt, DynamicPrompt, Prompt from langchain.sql_database import SQLDatabase __all__ = [ diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index 9dadf9ef..c73db42e 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.llms.base import LLM -from langchain.prompt import BasePrompt +from langchain.prompts.base import BasePrompt class LLMChain(Chain, BaseModel): diff --git a/langchain/chains/llm_math/prompt.py b/langchain/chains/llm_math/prompt.py index a5614b3a..b389e917 100644 --- a/langchain/chains/llm_math/prompt.py +++ b/langchain/chains/llm_math/prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompt import Prompt +from langchain.prompts.prompt import Prompt _PROMPT_TEMPLATE = """You are GPT-3, and you can't do math. diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index 5ba2a1dc..26674442 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -11,7 +11,7 @@ from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.llms.base import LLM -from langchain.prompt import BasePrompt +from langchain.prompts.base import BasePrompt from langchain.text_splitter import TextSplitter diff --git a/langchain/chains/natbot/prompt.py b/langchain/chains/natbot/prompt.py index 390a532d..f67775b0 100644 --- a/langchain/chains/natbot/prompt.py +++ b/langchain/chains/natbot/prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompt import Prompt +from langchain.prompts.prompt import Prompt _PROMPT_TEMPLATE = """ You are an agent controlling a browser. You are given: @@ -30,7 +30,7 @@ Based on your given objective, issue whatever command you believe will get you c You always start on Google; you should submit a search query to Google that will take you to the best page for achieving your objective. And then interact with that page to achieve your objective. 
-If you find yourself on Google and there are no search results displayed yet, you should probably issue a command +If you find yourself on Google and there are no search results displayed yet, you should probably issue a command like "TYPESUBMIT 7 "search query"" to get to a more useful page. Then, if you find yourself on a Google search results page, you might issue the command "CLICK 24" to click @@ -66,7 +66,7 @@ CURRENT BROWSER CONTENT: ------------------ OBJECTIVE: Find a 2 bedroom house for sale in Anchorage AK for under $750k CURRENT URL: https://www.google.com/ -YOUR COMMAND: +YOUR COMMAND: TYPESUBMIT 8 "anchorage redfin" ================================================== @@ -95,7 +95,7 @@ CURRENT BROWSER CONTENT: ------------------ OBJECTIVE: Make a reservation for 4 at Dorsia at 8pm CURRENT URL: https://www.google.com/ -YOUR COMMAND: +YOUR COMMAND: TYPESUBMIT 8 "dorsia nyc opentable" ================================================== @@ -114,15 +114,15 @@ CURRENT BROWSER CONTENT: Sep 28, 2022 7:00 PM 2 people - + -It looks like you're in Peninsula. Not correct? +It looks like you're in Peninsula. Not correct? ------------------ OBJECTIVE: Make a reservation for 4 for dinner at Dorsia in New York City at 8pm CURRENT URL: https://www.opentable.com/ -YOUR COMMAND: +YOUR COMMAND: TYPESUBMIT 12 "dorsia new york city" ================================================== diff --git a/langchain/chains/react/prompt.py b/langchain/chains/react/prompt.py index 486d652c..e0e16299 100644 --- a/langchain/chains/react/prompt.py +++ b/langchain/chains/react/prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompt import Prompt +from langchain.prompts.prompt import Prompt EXAMPLES = [ """Question: What is the elevation range for the area that the eastern sector of the diff --git a/langchain/chains/self_ask_with_search/prompt.py b/langchain/chains/self_ask_with_search/prompt.py index cb52d3c8..003e68dd 100644 --- a/langchain/chains/self_ask_with_search/prompt.py +++ b/langchain/chains/self_ask_with_search/prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompt import Prompt +from langchain.prompts.prompt import Prompt _DEFAULT_TEMPLATE = """Question: Who lived longer, Muhammad Ali or Alan Turing? Are follow up questions needed here: Yes. diff --git a/langchain/chains/sql_database/prompt.py b/langchain/chains/sql_database/prompt.py index 36d48d74..c35c92e4 100644 --- a/langchain/chains/sql_database/prompt.py +++ b/langchain/chains/sql_database/prompt.py @@ -1,7 +1,7 @@ # flake8: noqa -from langchain.prompt import Prompt +from langchain.prompts.prompt import Prompt -_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. +_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Use the following format: Question: "Question here" diff --git a/langchain/example_generator.py b/langchain/example_generator.py new file mode 100644 index 00000000..818a848a --- /dev/null +++ b/langchain/example_generator.py @@ -0,0 +1,20 @@ +"""Utility functions for working with prompts.""" +from typing import List + +from langchain.chains.llm import LLMChain +from langchain.llms.base import LLM +from langchain.prompts.dynamic import DynamicPrompt + +TEST_GEN_TEMPLATE_SUFFIX = "Add another example." 
+ + +def generate_example(examples: List[str], llm: LLM) -> str: + """Return another example given a list of examples for a prompt.""" + prompt = DynamicPrompt(examples=examples, suffix=TEST_GEN_TEMPLATE_SUFFIX) + chain = LLMChain(llm=llm, prompt=prompt) + return chain.predict() + + +def generate_example_from_dynamic_prompt(prompt: DynamicPrompt, llm: LLM) -> str: + """Return another example given a DynamicPrompt object.""" + return generate_example(prompt.examples, llm) diff --git a/langchain/prompt.py b/langchain/prompt.py deleted file mode 100644 index ac63eb16..00000000 --- a/langchain/prompt.py +++ /dev/null @@ -1,234 +0,0 @@ -"""Prompt schema definition.""" -import re -from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List - -from pydantic import BaseModel, Extra, root_validator - -from langchain.formatting import formatter - -_FORMATTER_MAPPING = { - "f-string": formatter.format, -} - - -class BasePrompt(ABC): - """Base prompt should expose the format method, returning a prompt.""" - - input_variables: List[str] - """A list of the names of the variables the prompt template expects.""" - - @abstractmethod - def format(self, **kwargs: Any) -> str: - """Format the prompt with the inputs. - - Args: - kwargs: Any arguments to be passed to the prompt template. - - Returns: - A formatted string. - - Example: - - .. code-block:: python - - prompt.format(variable1="foo") - """ - - -class Prompt(BaseModel, BasePrompt): - """Schema to represent a prompt for an LLM. - - Example: - .. code-block:: python - - from langchain import Prompt - prompt = Prompt(input_variables=["foo"], template="Say {foo}") - """ - - input_variables: List[str] - """A list of the names of the variables the prompt template expects.""" - - template: str - """The prompt template.""" - - template_format: str = "f-string" - """The format of the prompt template. Options are: 'f-string'.""" - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - - def format(self, **kwargs: Any) -> str: - """Format the prompt with the inputs. - - Args: - kwargs: Any arguments to be passed to the prompt template. - - Returns: - A formatted string. - - Example: - - .. code-block:: python - - prompt.format(variable1="foo") - """ - return _FORMATTER_MAPPING[self.template_format](self.template, **kwargs) - - @root_validator() - def template_is_valid(cls, values: Dict) -> Dict: - """Check that template and input variables are consistent.""" - input_variables = values["input_variables"] - template = values["template"] - template_format = values["template_format"] - if template_format not in _FORMATTER_MAPPING: - valid_formats = list(_FORMATTER_MAPPING) - raise ValueError( - f"Invalid template format. Got `{template_format}`;" - f" should be one of {valid_formats}" - ) - dummy_inputs = {input_variable: "foo" for input_variable in input_variables} - try: - formatter_func = _FORMATTER_MAPPING[template_format] - formatter_func(template, **dummy_inputs) - except KeyError: - raise ValueError("Invalid prompt schema.") - return values - - @classmethod - def from_examples( - cls, - examples: List[str], - suffix: str, - input_variables: List[str], - example_separator: str = "\n\n", - prefix: str = "", - ) -> "Prompt": - """Take examples in list format with prefix and suffix to create a prompt. - - Intended be used as a way to dynamically create a prompt from examples. - - Args: - examples: List of examples to use in the prompt. - suffix: String to go after the list of examples. 
Should generally - set up the user's input. - input_variables: A list of variable names the final prompt template - will expect. - example_separator: The seperator to use in between examples. Defaults - to two new line characters. - prefix: String that should go before any examples. Generally includes - examples. Default to an empty string. - - Returns: - The final prompt generated. - """ - example_str = example_separator.join(examples) - template = prefix + example_str + suffix - return cls(input_variables=input_variables, template=template) - - -class DynamicPrompt(BaseModel, BasePrompt): - r"""Schema to represent a dynamic prompt for an LLM. - - Example: - .. code-block:: python - - from langchain import DynamicPrompt - dynamic_prompt = DynamicPrompt( - examples=["Say hi. Hi", "Say ho. Ho"], - example_separator="\n\n", - prefix="", - suffix="\n\nSay {foo}" - input_variables=["foo"], - max_length=200, - get_text_length=word_count - ) - """ - - examples: List[str] - """A list of the examples that the prompt template expects.""" - - example_separator: str = "\n\n" - """Example separator, e.g. \n\n, for the dynamic prompt creation.""" - - input_variables: List[str] - """A list of the names of the variables the prompt template expects.""" - - prefix: str - """Prefix for the prompt.""" - - suffix: str - """Suffix for the prompt.""" - - template_format: str = "f-string" - """The format of the prompt template. Options are: 'f-string'.""" - - get_text_length: Callable[[str], int] = lambda x: len(re.split("\n| ", x)) - """Function to measure prompt length. Defaults to word count.""" - - max_length: int = 2048 - """Max length for the prompt, beyond which examples are cut.""" - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - - def template(self, example_list: List[str], **kwargs: Any) -> str: - """Return template given example list.""" - template = self.example_separator.join( - [self.prefix, *example_list, self.suffix] - ) - return _FORMATTER_MAPPING[self.template_format](template, **kwargs) - - def format(self, **kwargs: Any) -> str: - """Dynamically format the prompt with the inputs. - - Args: - kwargs: Any arguments to be passed to the prompt template. - - Returns: - A formatted string. - - Example: - - .. code-block:: python - - prompt.format(variable1="foo") - """ - curr_examples = self.examples - template = self.template(curr_examples, **kwargs) - while self.get_text_length(template) > self.max_length and curr_examples: - curr_examples = curr_examples[:-1] - template = self.template(curr_examples, **kwargs) - return template - - @root_validator() - def template_is_valid(cls, values: Dict) -> Dict: - """Check that prefix, suffix and input variables are consistent.""" - input_variables = values["input_variables"] - suffix = values["suffix"] - template_format = values["template_format"] - if template_format not in _FORMATTER_MAPPING: - valid_formats = list(_FORMATTER_MAPPING) - raise ValueError( - f"Invalid template format. 
Got `{template_format}`;" - f" should be one of {valid_formats}" - ) - try: - result = values["get_text_length"]("foo") - assert isinstance(result, int) - except AssertionError: - raise ValueError( - "Invalid text length callable, must take string & return int;" - ) - dummy_inputs = {input_variable: "foo" for input_variable in input_variables} - # TODO variables could be in prefix or suffix - try: - formatter_func = _FORMATTER_MAPPING[template_format] - formatter_func(suffix, **dummy_inputs) - except KeyError: - raise ValueError("Invalid prompt schema.") - return values diff --git a/langchain/prompts/__init__.py b/langchain/prompts/__init__.py new file mode 100644 index 00000000..177aa155 --- /dev/null +++ b/langchain/prompts/__init__.py @@ -0,0 +1,6 @@ +"""Prompt template classes.""" +from langchain.prompts.base import BasePrompt +from langchain.prompts.dynamic import DynamicPrompt +from langchain.prompts.prompt import Prompt + +__all__ = ["BasePrompt", "Prompt", "DynamicPrompt"] diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py new file mode 100644 index 00000000..d99d9400 --- /dev/null +++ b/langchain/prompts/base.py @@ -0,0 +1,33 @@ +"""BasePrompt schema definition.""" +from abc import ABC, abstractmethod +from typing import Any, List + +from langchain.formatting import formatter + +DEFAULT_FORMATTER_MAPPING = { + "f-string": formatter.format, +} + + +class BasePrompt(ABC): + """Base prompt should expose the format method, returning a prompt.""" + + input_variables: List[str] + """A list of the names of the variables the prompt template expects.""" + + @abstractmethod + def format(self, **kwargs: Any) -> str: + """Format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Example: + + .. code-block:: python + + prompt.format(variable1="foo") + """ diff --git a/langchain/prompts/dynamic.py b/langchain/prompts/dynamic.py new file mode 100644 index 00000000..fbf0c351 --- /dev/null +++ b/langchain/prompts/dynamic.py @@ -0,0 +1,112 @@ +"""Dynamic prompt schema definition.""" +import re +from typing import Any, Callable, Dict, List + +from pydantic import BaseModel, Extra, root_validator + +from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING, BasePrompt + + +class DynamicPrompt(BaseModel, BasePrompt): + r"""Schema to represent a dynamic prompt for an LLM. + + Example: + .. code-block:: python + + from langchain import DynamicPrompt + dynamic_prompt = DynamicPrompt( + examples=["Say hi. Hi", "Say ho. Ho"], + example_separator="\n\n", + prefix="", + suffix="\n\nSay {foo}" + input_variables=["foo"], + max_length=200, + get_text_length=word_count + ) + """ + + examples: List[str] + """A list of the examples that the prompt template expects.""" + + example_separator: str = "\n\n" + """Example separator, e.g. \n\n, for the dynamic prompt creation.""" + + input_variables: List[str] = [] + """A list of the names of the variables the prompt template expects.""" + + prefix: str = "" + """Prefix for the prompt.""" + + suffix: str = "" + """Suffix for the prompt.""" + + template_format: str = "f-string" + """The format of the prompt template. Options are: 'f-string'.""" + + get_text_length: Callable[[str], int] = lambda x: len(re.split("\n| ", x)) + """Function to measure prompt length. 
Defaults to word count.""" + + max_length: int = 2048 + """Max length for the prompt, beyond which examples are cut.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def template(self, example_list: List[str], **kwargs: Any) -> str: + """Return template given example list.""" + template = self.example_separator.join( + [self.prefix, *example_list, self.suffix] + ) + return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs) + + def format(self, **kwargs: Any) -> str: + """Dynamically format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Example: + + .. code-block:: python + + prompt.format(variable1="foo") + """ + curr_examples = self.examples + template = self.template(curr_examples, **kwargs) + while self.get_text_length(template) > self.max_length and curr_examples: + curr_examples = curr_examples[:-1] + template = self.template(curr_examples, **kwargs) + return template + + @root_validator() + def template_is_valid(cls, values: Dict) -> Dict: + """Check that prefix, suffix and input variables are consistent.""" + input_variables = values["input_variables"] + prefix = values["prefix"] + suffix = values["suffix"] + template_format = values["template_format"] + if template_format not in DEFAULT_FORMATTER_MAPPING: + valid_formats = list(DEFAULT_FORMATTER_MAPPING) + raise ValueError( + f"Invalid template format. Got `{template_format}`;" + f" should be one of {valid_formats}" + ) + try: + result = values["get_text_length"]("foo") + assert isinstance(result, int) + except AssertionError: + raise ValueError( + "Invalid text length callable, must take string & return int;" + ) + dummy_inputs = {input_variable: "foo" for input_variable in input_variables} + try: + formatter_func = DEFAULT_FORMATTER_MAPPING[template_format] + formatter_func(prefix + suffix, **dummy_inputs) + except KeyError: + raise ValueError("Invalid prompt schema.") + return values diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py new file mode 100644 index 00000000..02f87b77 --- /dev/null +++ b/langchain/prompts/prompt.py @@ -0,0 +1,99 @@ +"""Prompt schema definition.""" +from typing import Any, Dict, List + +from pydantic import BaseModel, Extra, root_validator + +from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING, BasePrompt + + +class Prompt(BaseModel, BasePrompt): + """Schema to represent a prompt for an LLM. + + Example: + .. code-block:: python + + from langchain import Prompt + prompt = Prompt(input_variables=["foo"], template="Say {foo}") + """ + + input_variables: List[str] + """A list of the names of the variables the prompt template expects.""" + + template: str + """The prompt template.""" + + template_format: str = "f-string" + """The format of the prompt template. Options are: 'f-string'.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def format(self, **kwargs: Any) -> str: + """Format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Example: + + .. 
code-block:: python + + prompt.format(variable1="foo") + """ + return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) + + @root_validator() + def template_is_valid(cls, values: Dict) -> Dict: + """Check that template and input variables are consistent.""" + input_variables = values["input_variables"] + template = values["template"] + template_format = values["template_format"] + if template_format not in DEFAULT_FORMATTER_MAPPING: + valid_formats = list(DEFAULT_FORMATTER_MAPPING) + raise ValueError( + f"Invalid template format. Got `{template_format}`;" + f" should be one of {valid_formats}" + ) + dummy_inputs = {input_variable: "foo" for input_variable in input_variables} + try: + formatter_func = DEFAULT_FORMATTER_MAPPING[template_format] + formatter_func(template, **dummy_inputs) + except KeyError: + raise ValueError("Invalid prompt schema.") + return values + + @classmethod + def from_examples( + cls, + examples: List[str], + suffix: str, + input_variables: List[str], + example_separator: str = "\n\n", + prefix: str = "", + ) -> "Prompt": + """Take examples in list format with prefix and suffix to create a prompt. + + Intended be used as a way to dynamically create a prompt from examples. + + Args: + examples: List of examples to use in the prompt. + suffix: String to go after the list of examples. Should generally + set up the user's input. + input_variables: A list of variable names the final prompt template + will expect. + example_separator: The seperator to use in between examples. Defaults + to two new line characters. + prefix: String that should go before any examples. Generally includes + examples. Default to an empty string. + + Returns: + The final prompt generated. + """ + example_str = example_separator.join(examples) + template = prefix + example_str + suffix + return cls(input_variables=input_variables, template=template) diff --git a/tests/unit_tests/chains/test_llm.py b/tests/unit_tests/chains/test_llm.py index 4c350637..0077df86 100644 --- a/tests/unit_tests/chains/test_llm.py +++ b/tests/unit_tests/chains/test_llm.py @@ -2,7 +2,7 @@ import pytest from langchain.chains.llm import LLMChain -from langchain.prompt import Prompt +from langchain.prompts.prompt import Prompt from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/tests/unit_tests/chains/test_react.py b/tests/unit_tests/chains/test_react.py index 7ca3cd84..e5c22dd4 100644 --- a/tests/unit_tests/chains/test_react.py +++ b/tests/unit_tests/chains/test_react.py @@ -9,7 +9,7 @@ from langchain.chains.react.base import ReActChain, predict_until_observation from langchain.docstore.base import Docstore from langchain.docstore.document import Document from langchain.llms.base import LLM -from langchain.prompt import Prompt +from langchain.prompts.prompt import Prompt _PAGE_CONTENT = """This is a page about LangChain. 
diff --git a/tests/unit_tests/test_dynamic_prompt.py b/tests/unit_tests/test_dynamic_prompt.py index e4b4443b..72f56eea 100644 --- a/tests/unit_tests/test_dynamic_prompt.py +++ b/tests/unit_tests/test_dynamic_prompt.py @@ -1,5 +1,6 @@ """Test functionality related to dynamic prompts.""" -from langchain.prompt import DynamicPrompt, Prompt +from langchain.prompts.dynamic import DynamicPrompt +from langchain.prompts.prompt import Prompt # FULL TEMPLATES LONGER_TEMPLATE = """Test Prompt: diff --git a/tests/unit_tests/test_prompt.py b/tests/unit_tests/test_prompt.py index 280620cf..45040793 100644 --- a/tests/unit_tests/test_prompt.py +++ b/tests/unit_tests/test_prompt.py @@ -1,7 +1,7 @@ """Test functionality related to prompts.""" import pytest -from langchain.prompt import Prompt +from langchain.prompts.prompt import Prompt def test_prompt_valid() -> None:
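
A minimal usage sketch of the refactored prompts module and the new example-generation utility, assuming this patch is applied and langchain is installed from local source via `pip install -e .`. The example strings and the `max_length` value are illustrative assumptions, and the final call assumes an OpenAI API key is configured in the environment:

# Demonstrates the new langchain.prompts import path and DynamicPrompt's
# length-based truncation: format() drops trailing examples until the
# rendered prompt fits within max_length.
from langchain.example_generator import generate_example
from langchain.llms.openai import OpenAI
from langchain.prompts import DynamicPrompt

prompt = DynamicPrompt(
    examples=["Say hi. Hi", "Say ho. Ho"],
    suffix="\n\nSay {foo}",
    input_variables=["foo"],
    max_length=10,  # counted in words by the default get_text_length
)

# The two examples plus the suffix exceed 10 words, so the last example
# ("Say ho. Ho") is dropped before the template is formatted.
print(prompt.format(foo="bye"))

# The new utility wraps the same machinery: it builds a DynamicPrompt whose
# suffix is "Add another example." and asks the LLM to continue the pattern.
# (Assumes OPENAI_API_KEY is set; output will vary between runs, as the
# notebook above shows.)
new_example = generate_example(["Say hi. Hi", "Say ho. Ho"], OpenAI())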