From a0780cc930bed8e23b78090388529c904800addd Mon Sep 17 00:00:00 2001 From: Samantha Whitmore Date: Wed, 9 Nov 2022 21:15:42 -0800 Subject: [PATCH] OptimizedPrompt -- k-shot example choice backed by semantic search (#91) --- examples/prompt_optimization.py.ipynb | 199 ++++++++++++++++++ langchain/prompts/optimized.py | 171 +++++++++++++++ .../test_nlp_text_splitters.py | 4 +- 3 files changed, 371 insertions(+), 3 deletions(-) create mode 100644 examples/prompt_optimization.py.ipynb create mode 100644 langchain/prompts/optimized.py diff --git a/examples/prompt_optimization.py.ipynb b/examples/prompt_optimization.py.ipynb new file mode 100644 index 0000000000..2c8d135829 --- /dev/null +++ b/examples/prompt_optimization.py.ipynb @@ -0,0 +1,199 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "e9e2b50b", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chains.react.prompt import EXAMPLES, SUFFIX\n", + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.example_generator import generate_example, generate_example_from_dynamic_prompt\n", + "from langchain.llms.openai import OpenAI\n", + "from langchain.prompts.optimized import OptimizedPrompt\n", + "from langchain.vectorstores.elastic_vector_search import ElasticVectorSearch\n", + "from langchain.vectorstores.faiss import FAISS" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "cb069606", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Question: What is the elevation range for the area that the eastern sector of the\\nColorado orogeny extends into?\\nThought 1: I need to search Colorado orogeny, find the area that the eastern sector\\nof the Colorado orogeny extends into, then find the elevation range of the\\narea.\\nAction 1: Search[Colorado orogeny]\\nObservation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in\\nColorado and surrounding areas.\\nThought 2: It does not mention 
the eastern sector. So I need to look up eastern\\nsector.\\nAction 2: Lookup[eastern sector]\\nObservation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called\\nthe Central Plains orogeny.\\nThought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I\\nneed to search High Plains and find its elevation range.\\nAction 3: Search[High Plains]\\nObservation 3: High Plains refers to one of two distinct land regions\\nThought 4: I need to instead search High Plains (United States).\\nAction 4: Search[High Plains (United States)]\\nObservation 4: The High Plains are a subregion of the Great Plains. From east to west, the\\nHigh Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130\\nm).[3]\\nThought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer\\nis 1,800 to 7,000 ft.\\nAction 5: Finish[1,800 to 7,000 ft]'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "EXAMPLES[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5fda75a4", + "metadata": {}, + "outputs": [], + "source": [ + "prompt = OptimizedPrompt.from_examples(\n", + " examples=EXAMPLES, \n", + " suffix=SUFFIX, \n", + " input_variables=[\"input\"],\n", + " embeddings=OpenAIEmbeddings(),\n", + " vectorstore_cls=FAISS\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7a601df8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Question: What is the elevation range for the area that the eastern sector of the\n", + "Colorado orogeny extends into?\n", + "Thought 1: I need to search Colorado orogeny, find the area that the eastern sector\n", + "of the Colorado orogeny extends into, then find the elevation range of the\n", + "area.\n", + "Action 1: Search[Colorado orogeny]\n", + "Observation 1: The Colorado orogeny was an episode of mountain building 
(an orogeny) in\n", + "Colorado and surrounding areas.\n", + "Thought 2: It does not mention the eastern sector. So I need to look up eastern\n", + "sector.\n", + "Action 2: Lookup[eastern sector]\n", + "Observation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called\n", + "the Central Plains orogeny.\n", + "Thought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I\n", + "need to search High Plains and find its elevation range.\n", + "Action 3: Search[High Plains]\n", + "Observation 3: High Plains refers to one of two distinct land regions\n", + "Thought 4: I need to instead search High Plains (United States).\n", + "Action 4: Search[High Plains (United States)]\n", + "Observation 4: The High Plains are a subregion of the Great Plains. From east to west, the\n", + "High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130\n", + "m).[3]\n", + "Thought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer\n", + "is 1,800 to 7,000 ft.\n", + "Action 5: Finish[1,800 to 7,000 ft]\n", + "\n", + "\n", + "\n", + "Question: What is the highest mountain peak in Asia?\n" + ] + } + ], + "source": [ + "print(prompt.format(k=1, input=\"What is the highest mountain peak in Asia?\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f7f06820", + "metadata": {}, + "outputs": [], + "source": [ + "prompt = OptimizedPrompt.from_examples(\n", + " examples=EXAMPLES, \n", + " suffix=SUFFIX, \n", + " input_variables=[\"input\"],\n", + " embeddings=OpenAIEmbeddings(),\n", + " vectorstore_cls=ElasticVectorSearch,\n", + " elasticsearch_url=\"http://localhost:9200\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "bd91f408", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Question: What is the elevation range for the area that the eastern sector of the\n", + "Colorado orogeny extends 
into?\n", + "Thought 1: I need to search Colorado orogeny, find the area that the eastern sector\n", + "of the Colorado orogeny extends into, then find the elevation range of the\n", + "area.\n", + "Action 1: Search[Colorado orogeny]\n", + "Observation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in\n", + "Colorado and surrounding areas.\n", + "Thought 2: It does not mention the eastern sector. So I need to look up eastern\n", + "sector.\n", + "Action 2: Lookup[eastern sector]\n", + "Observation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called\n", + "the Central Plains orogeny.\n", + "Thought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I\n", + "need to search High Plains and find its elevation range.\n", + "Action 3: Search[High Plains]\n", + "Observation 3: High Plains refers to one of two distinct land regions\n", + "Thought 4: I need to instead search High Plains (United States).\n", + "Action 4: Search[High Plains (United States)]\n", + "Observation 4: The High Plains are a subregion of the Great Plains. 
From east to west, the\n", + "High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130\n", + "m).[3]\n", + "Thought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer\n", + "is 1,800 to 7,000 ft.\n", + "Action 5: Finish[1,800 to 7,000 ft]\n", + "\n", + "\n", + "\n", + "Question: What is the highest mountain peak in Asia?\n" + ] + } + ], + "source": [ + "print(prompt.format(k=1, input=\"What is the highest mountain peak in Asia?\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "716165c2", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/prompts/optimized.py b/langchain/prompts/optimized.py new file mode 100644 index 0000000000..56b5dd4e8d --- /dev/null +++ b/langchain/prompts/optimized.py @@ -0,0 +1,171 @@ +"""Optimized prompt schema definition.""" +import re +from typing import Any, Callable, Dict, List + +from pydantic import BaseModel, Extra, root_validator + +from langchain.embeddings.base import Embeddings +from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING +from langchain.vectorstores.base import VectorStore + + +class OptimizedPrompt(BaseModel): + r"""Schema to represent an optimized prompt for an LLM. + + Example: + .. code-block:: python + + from langchain.prompts.optimized import OptimizedPrompt + vectorstore = FAISS.from_texts(examples, OpenAIEmbeddings()) + optimized_prompt = OptimizedPrompt( + examples=["Say hi. Hi", "Say ho. 
Ho"], + example_separator="\n\n", + prefix="", + suffix="\n\nSay {foo}", + input_variables=["foo"], + max_length=200, + get_text_length=word_count, + vectorstore=vectorstore, + ) + """ + + examples: List[str] + """A list of the examples that the prompt template expects.""" + + example_separator: str = "\n\n" + """Example separator, e.g. \n\n, for the dynamic prompt creation.""" + + input_variables: List[str] = [] + """A list of the names of the variables the prompt template expects.""" + + prefix: str = "" + """Prefix for the prompt.""" + + suffix: str = "" + """Suffix for the prompt.""" + + template_format: str = "f-string" + """The format of the prompt template. Options are: 'f-string'.""" + + get_text_length: Callable[[str], int] = lambda x: len(re.split("\n| ", x)) + """Function to measure prompt length. Defaults to word count.""" + + max_length: int = 2048 + """Max length for the prompt, beyond which examples are cut.""" + + vectorstore: VectorStore + """Vectorstore to use for storing the embeddings.""" + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + extra = Extra.forbid + + def template(self, example_list: List[str], **kwargs: Any) -> str: + """Return template given full example list.""" + template = self.example_separator.join( + [self.prefix, *example_list, self.suffix] + ) + return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs) + + def format(self, k: int = 4, **kwargs: Any) -> str: + """Optimize the examples in the prompt for the given inputs. + + Args: + k: Number of examples to aim for (may be trimmed by optimizer afterwards) + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Example: + + .. 
code-block:: python + + prompt.format(variable1="foo") + """ + query = " ".join([v for k, v in kwargs.items()]) + example_docs = self.vectorstore.similarity_search(query, k=k) + curr_examples = [str(e.page_content) for e in example_docs] + template = self.template(curr_examples, **kwargs) + while self.get_text_length(template) > self.max_length and curr_examples: + curr_examples = curr_examples[:-1] + template = self.template(curr_examples, **kwargs) + return template + + @root_validator() + def template_is_valid(cls, values: Dict) -> Dict: + """Check that prefix, suffix and input variables are consistent.""" + input_variables = values["input_variables"] + if len(input_variables) > 1: + raise ValueError("Only one input variable allowed for optimized prompt;") + prefix = values["prefix"] + suffix = values["suffix"] + template_format = values["template_format"] + if template_format not in DEFAULT_FORMATTER_MAPPING: + valid_formats = list(DEFAULT_FORMATTER_MAPPING) + raise ValueError( + f"Invalid template format. Got `{template_format}`;" + f" should be one of {valid_formats}" + ) + try: + result = values["get_text_length"]("foo") + assert isinstance(result, int) + except AssertionError: + raise ValueError( + "Invalid text length callable, must take string & return int;" + ) + dummy_inputs = {input_variable: "foo" for input_variable in input_variables} + try: + formatter_func = DEFAULT_FORMATTER_MAPPING[template_format] + formatter_func(prefix + suffix, **dummy_inputs) + except KeyError: + raise ValueError("Invalid prompt schema.") + return values + + @classmethod + def from_examples( + cls, + examples: List[str], + suffix: str, + input_variables: List[str], + embeddings: Embeddings, + vectorstore_cls: VectorStore, + example_separator: str = "\n\n", + prefix: str = "", + **vectorstore_cls_kwargs: Any, + ) -> "OptimizedPrompt": + """Create k-shot prompt optimizer using example list and embeddings. 
+ + Reshuffles examples for the prompt dynamically based on query similarity. + + Args: + examples: List of examples to use in the prompt. + suffix: String to go after the list of examples. Should generally + set up the user's input. + input_variables: A list of variable names the final prompt template + will expect. + embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings(). + vectorstore_cls: A vector store DB interface class, e.g. FAISS. + example_separator: The separator to use in between examples. Defaults + to two new line characters. + prefix: String that should go before any examples. Generally includes + examples. Defaults to an empty string. + vectorstore_cls_kwargs: optional kwargs containing url for vector store + + Returns: + The OptimizedPrompt instantiated, backed by a vector store. + """ + vectorstore = vectorstore_cls.from_texts( + examples, embeddings, **vectorstore_cls_kwargs + ) + return cls( + examples=examples, + suffix=suffix, + input_variables=input_variables, + example_separator=example_separator, + prefix=prefix, + vectorstore=vectorstore, + ) diff --git a/tests/integration_tests/test_nlp_text_splitters.py b/tests/integration_tests/test_nlp_text_splitters.py index 734092275e..4837fe20ad 100644 --- a/tests/integration_tests/test_nlp_text_splitters.py +++ b/tests/integration_tests/test_nlp_text_splitters.py @@ -1,6 +1,4 @@ -""" -Test text splitting functionality using NLTK and Spacy based sentence splitters. -""" +"""Test text splitting functionality using NLTK and Spacy based sentence splitters.""" import pytest from langchain.text_splitter import NLTKTextSplitter, SpacyTextSplitter