From 23d5f64bda5e2973a6afb99b811c02eaf9a9d7ea Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 2 Feb 2023 09:44:42 -0800 Subject: [PATCH] Harrison/ngram example (#846) Co-authored-by: Sean Spriggens --- .../prompts/examples/example_selectors.ipynb | 238 +++++++++++++++++- .../prompts/example_selector/ngram_overlap.py | 112 +++++++++ .../test_ngram_overlap_example_selector.py | 73 ++++++ 3 files changed, 420 insertions(+), 3 deletions(-) create mode 100644 langchain/prompts/example_selector/ngram_overlap.py create mode 100644 tests/integration_tests/test_ngram_overlap_example_selector.py diff --git a/docs/modules/prompts/examples/example_selectors.ipynb b/docs/modules/prompts/examples/example_selectors.ipynb index f755920d..a0ffe4a4 100644 --- a/docs/modules/prompts/examples/example_selectors.ipynb +++ b/docs/modules/prompts/examples/example_selectors.ipynb @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "8244ff60", "metadata": {}, "outputs": [], @@ -81,7 +81,7 @@ " template=\"Input: {input}\\nOutput: {output}\",\n", ")\n", "example_selector = LengthBasedExampleSelector(\n", - " # These are the examples is has available to choose from.\n", + " # These are the examples it has available to choose from.\n", " examples=examples, \n", " # This is the PromptTemplate being used to format the examples.\n", " example_prompt=example_prompt, \n", @@ -439,10 +439,242 @@ "print(similar_prompt.format(adjective=\"worried\"))" ] }, + { + "cell_type": "markdown", + "id": "4aaeed2f", + "metadata": {}, + "source": [ + "## NGram Overlap ExampleSelector\n", + "\n", + "The NGramOverlapExampleSelector selects and orders examples based on which examples are most similar to the input, according to an ngram overlap score. The ngram overlap score is a float between 0.0 and 1.0, inclusive. \n", + "\n", + "The selector allows for a threshold score to be set. 
The threshold is set to -1.0 by default, so by default the selector will not exclude any examples, only reorder them. Setting the threshold to 0.0 will exclude examples that have no ngram overlaps with the input.
those with no ngram overlap with input.\n", + ")\n", + "dynamic_prompt = FewShotPromptTemplate(\n", + " # We provide an ExampleSelector instead of examples.\n", + " example_selector=example_selector,\n", + " example_prompt=example_prompt,\n", + " prefix=\"Give the Spanish translation of every input\",\n", + " suffix=\"Input: {sentence}\\nOutput:\", \n", + " input_variables=[\"sentence\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "83fb218a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Give the Spanish translation of every input\n", + "\n", + "Input: Spot can run.\n", + "Output: Spot puede correr.\n", + "\n", + "Input: See Spot run.\n", + "Output: Ver correr a Spot.\n", + "\n", + "Input: My dog barks.\n", + "Output: Mi perro ladra.\n", + "\n", + "Input: Spot can run fast.\n", + "Output:\n" + ] + } + ], + "source": [ + "# An example input with large ngram overlap with \"Spot can run.\"\n", + "# and no overlap with \"My dog barks.\"\n", + "print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "485f5307", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Give the Spanish translation of every input\n", + "\n", + "Input: Spot can run.\n", + "Output: Spot puede correr.\n", + "\n", + "Input: See Spot run.\n", + "Output: Ver correr a Spot.\n", + "\n", + "Input: Spot plays fetch.\n", + "Output: Spot juega a buscar.\n", + "\n", + "Input: My dog barks.\n", + "Output: Mi perro ladra.\n", + "\n", + "Input: Spot can run fast.\n", + "Output:\n" + ] + } + ], + "source": [ + "# You can add examples to NGramOverlapExampleSelector as well.\n", + "new_example = {\"input\": \"Spot plays fetch.\", \"output\": \"Spot juega a buscar.\"}\n", + "\n", + "example_selector.add_example(new_example)\n", + "print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))" + ] + }, + { + 
"cell_type": "code", + "execution_count": 7, + "id": "606ce697", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Give the Spanish translation of every input\n", + "\n", + "Input: Spot can run.\n", + "Output: Spot puede correr.\n", + "\n", + "Input: See Spot run.\n", + "Output: Ver correr a Spot.\n", + "\n", + "Input: Spot plays fetch.\n", + "Output: Spot juega a buscar.\n", + "\n", + "Input: Spot can run fast.\n", + "Output:\n" + ] + } + ], + "source": [ + "# You can set a threshold at which examples are excluded.\n", + "# For example, setting threshold equal to 0.0\n", + "# excludes examples with no ngram overlaps with input.\n", + "# Since \"My dog barks.\" has no ngram overlaps with \"Spot can run fast.\"\n", + "# it is excluded.\n", + "example_selector.threshold=0.0\n", + "print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "id": "7f8d72f7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Give the Spanish translation of every input\n", + "\n", + "Input: Spot can run.\n", + "Output: Spot puede correr.\n", + "\n", + "Input: Spot plays fetch.\n", + "Output: Spot juega a buscar.\n", + "\n", + "Input: Spot can play fetch.\n", + "Output:\n" + ] + } + ], + "source": [ + "# Setting small nonzero threshold\n", + "example_selector.threshold=0.09\n", + "print(dynamic_prompt.format(sentence=\"Spot can play fetch.\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "id": "09633aa8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Give the Spanish translation of every input\n", + "\n", + "Input: Spot can play fetch.\n", + "Output:\n" + ] + } + ], + "source": [ + "# Setting threshold greater than 1.0\n", + "example_selector.threshold=1.0+1e-9\n", + "print(dynamic_prompt.format(sentence=\"Spot can play fetch.\"))" + ] + }, { "cell_type": 
"code", "execution_count": null, - "id": "c746d6f4", + "id": "39f30097", "metadata": {}, "outputs": [], "source": [] diff --git a/langchain/prompts/example_selector/ngram_overlap.py b/langchain/prompts/example_selector/ngram_overlap.py new file mode 100644 index 00000000..335331ec --- /dev/null +++ b/langchain/prompts/example_selector/ngram_overlap.py @@ -0,0 +1,112 @@ +"""Select and order examples based on ngram overlap score (sentence_bleu score). + +https://www.nltk.org/_modules/nltk/translate/bleu_score.html +https://aclanthology.org/P02-1040.pdf +""" +from typing import Dict, List + +import numpy as np +from pydantic import BaseModel, root_validator + +from langchain.prompts.example_selector.base import BaseExampleSelector +from langchain.prompts.prompt import PromptTemplate + + +def ngram_overlap_score(source: List[str], example: List[str]) -> float: + """Compute ngram overlap score of source and example as sentence_bleu score. + + Use sentence_bleu with method1 smoothing function and auto reweighting. + Return float value between 0.0 and 1.0 inclusive. + https://www.nltk.org/_modules/nltk/translate/bleu_score.html + https://aclanthology.org/P02-1040.pdf + """ + from nltk.translate.bleu_score import ( # type: ignore + SmoothingFunction, + sentence_bleu, + ) + + hypotheses = source[0].split() + references = [s.split() for s in example] + + return float( + sentence_bleu( + references, + hypotheses, + smoothing_function=SmoothingFunction().method1, + auto_reweigh=True, + ) + ) + + +class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel): + """Select and order examples based on ngram overlap score (sentence_bleu score). 
"Not all the correct dependencies for this ExampleSelector exist"
+ """ + inputs = list(input_variables.values()) + examples = [] + k = len(self.examples) + score = [0.0] * k + first_prompt_template_key = self.example_prompt.input_variables[0] + + for i in range(k): + score[i] = ngram_overlap_score( + inputs, [self.examples[i][first_prompt_template_key]] + ) + + while True: + arg_max = np.argmax(score) + if (score[arg_max] < self.threshold) or abs( + score[arg_max] - self.threshold + ) < 1e-9: + break + + examples.append(self.examples[arg_max]) + score[arg_max] = self.threshold - 1.0 + + return examples diff --git a/tests/integration_tests/test_ngram_overlap_example_selector.py b/tests/integration_tests/test_ngram_overlap_example_selector.py new file mode 100644 index 00000000..5c7bd4b1 --- /dev/null +++ b/tests/integration_tests/test_ngram_overlap_example_selector.py @@ -0,0 +1,73 @@ +"""Test functionality related to ngram overlap based selector.""" + +import pytest + +from langchain.prompts.example_selector.ngram_overlap import ( + NGramOverlapExampleSelector, + ngram_overlap_score, +) +from langchain.prompts.prompt import PromptTemplate + +EXAMPLES = [ + {"input": "See Spot run.", "output": "foo1"}, + {"input": "My dog barks.", "output": "foo2"}, + {"input": "Spot can run.", "output": "foo3"}, +] + + +@pytest.fixture +def selector() -> NGramOverlapExampleSelector: + """Get ngram overlap based selector to use in tests.""" + prompts = PromptTemplate( + input_variables=["input", "output"], template="Input: {input}\nOutput: {output}" + ) + selector = NGramOverlapExampleSelector( + examples=EXAMPLES, + example_prompt=prompts, + ) + return selector + + +def test_selector_valid(selector: NGramOverlapExampleSelector) -> None: + """Test NGramOverlapExampleSelector can select examples.""" + sentence = "Spot can run." 
+ output = selector.select_examples({"input": sentence}) + assert output == [EXAMPLES[2], EXAMPLES[0], EXAMPLES[1]] + + +def test_selector_add_example(selector: NGramOverlapExampleSelector) -> None: + """Test NGramOverlapExampleSelector can add an example.""" + new_example = {"input": "Spot plays fetch.", "output": "foo4"} + selector.add_example(new_example) + sentence = "Spot can run." + output = selector.select_examples({"input": sentence}) + assert output == [EXAMPLES[2], EXAMPLES[0]] + [new_example] + [EXAMPLES[1]] + + +def test_selector_threshold_zero(selector: NGramOverlapExampleSelector) -> None: + """Tests NGramOverlapExampleSelector threshold set to 0.0.""" + selector.threshold = 0.0 + sentence = "Spot can run." + output = selector.select_examples({"input": sentence}) + assert output == [EXAMPLES[2], EXAMPLES[0]] + + +def test_selector_threshold_more_than_one( + selector: NGramOverlapExampleSelector, +) -> None: + """Tests NGramOverlapExampleSelector threshold greater than 1.0.""" + selector.threshold = 1.0 + 1e-9 + sentence = "Spot can run." + output = selector.select_examples({"input": sentence}) + assert output == [] + + +def test_ngram_overlap_score(selector: NGramOverlapExampleSelector) -> None: + """Tests that ngram_overlap_score returns correct values.""" + selector.threshold = 1.0 + 1e-9 + none = ngram_overlap_score(["Spot can run."], ["My dog barks."]) + some = ngram_overlap_score(["Spot can run."], ["See Spot run."]) + complete = ngram_overlap_score(["Spot can run."], ["Spot can run."]) + + check = [abs(none - 0.0) < 1e-9, 0.0 < some < 1.0, abs(complete - 1.0) < 1e-9] + assert check == [True, True, True]