Harrison/ngram example (#846)

Co-authored-by: Sean Spriggens <ssprigge@syr.edu>
ankush/retry-openai
Harrison Chase 1 year ago committed by GitHub
parent 0de55048b7
commit 23d5f64bda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -23,7 +23,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "8244ff60",
"metadata": {},
"outputs": [],
@ -81,7 +81,7 @@
" template=\"Input: {input}\\nOutput: {output}\",\n",
")\n",
"example_selector = LengthBasedExampleSelector(\n",
" # These are the examples is has available to choose from.\n",
" # These are the examples it has available to choose from.\n",
" examples=examples, \n",
" # This is the PromptTemplate being used to format the examples.\n",
" example_prompt=example_prompt, \n",
@ -439,10 +439,242 @@
"print(similar_prompt.format(adjective=\"worried\"))"
]
},
{
"cell_type": "markdown",
"id": "4aaeed2f",
"metadata": {},
"source": [
"## NGram Overlap ExampleSelector\n",
"\n",
"The NGramOverlapExampleSelector selects and orders examples based on which examples are most similar to the input, according to an ngram overlap score. The ngram overlap score is a float between 0.0 and 1.0, inclusive. \n",
"\n",
"The selector allows for a threshold score to be set. Examples with an ngram overlap score less than or equal to the threshold are excluded. By default the threshold is -1.0, so no examples are excluded; they are only reordered. Setting the threshold to 0.0 excludes examples that have no ngram overlap with the input.\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9cbc0acc",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"from langchain.prompts.example_selector.ngram_overlap import NGramOverlapExampleSelector"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4f318f4b",
"metadata": {},
"outputs": [],
"source": [
"# These are examples of a fictional translation task.\n",
"examples = [\n",
" {\"input\": \"See Spot run.\", \"output\": \"Ver correr a Spot.\"},\n",
" {\"input\": \"My dog barks.\", \"output\": \"Mi perro ladra.\"},\n",
" {\"input\": \"Spot can run.\", \"output\": \"Spot puede correr.\"},\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bf75e0fe",
"metadata": {},
"outputs": [],
"source": [
"example_prompt = PromptTemplate(\n",
" input_variables=[\"input\", \"output\"],\n",
" template=\"Input: {input}\\nOutput: {output}\",\n",
")\n",
"example_selector = NGramOverlapExampleSelector(\n",
" # These are the examples it has available to choose from.\n",
" examples=examples, \n",
" # This is the PromptTemplate being used to format the examples.\n",
" example_prompt=example_prompt, \n",
" # This is the threshold, at which selector stops.\n",
" # It is set to -1.0 by default.\n",
" threshold=-1.0,\n",
" # For negative threshold:\n",
" # Selector sorts examples by ngram overlap score, and excludes none.\n",
" # For threshold greater than 1.0:\n",
" # Selector excludes all examples, and returns an empty list.\n",
" # For threshold equal to 0.0:\n",
" # Selector sorts examples by ngram overlap score,\n",
" # and excludes those with no ngram overlap with input.\n",
")\n",
"dynamic_prompt = FewShotPromptTemplate(\n",
" # We provide an ExampleSelector instead of examples.\n",
" example_selector=example_selector,\n",
" example_prompt=example_prompt,\n",
" prefix=\"Give the Spanish translation of every input\",\n",
" suffix=\"Input: {sentence}\\nOutput:\", \n",
" input_variables=[\"sentence\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "83fb218a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Give the Spanish translation of every input\n",
"\n",
"Input: Spot can run.\n",
"Output: Spot puede correr.\n",
"\n",
"Input: See Spot run.\n",
"Output: Ver correr a Spot.\n",
"\n",
"Input: My dog barks.\n",
"Output: Mi perro ladra.\n",
"\n",
"Input: Spot can run fast.\n",
"Output:\n"
]
}
],
"source": [
"# An example input with large ngram overlap with \"Spot can run.\"\n",
"# and no overlap with \"My dog barks.\"\n",
"print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "485f5307",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Give the Spanish translation of every input\n",
"\n",
"Input: Spot can run.\n",
"Output: Spot puede correr.\n",
"\n",
"Input: See Spot run.\n",
"Output: Ver correr a Spot.\n",
"\n",
"Input: Spot plays fetch.\n",
"Output: Spot juega a buscar.\n",
"\n",
"Input: My dog barks.\n",
"Output: Mi perro ladra.\n",
"\n",
"Input: Spot can run fast.\n",
"Output:\n"
]
}
],
"source": [
"# You can add examples to NGramOverlapExampleSelector as well.\n",
"new_example = {\"input\": \"Spot plays fetch.\", \"output\": \"Spot juega a buscar.\"}\n",
"\n",
"example_selector.add_example(new_example)\n",
"print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "606ce697",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Give the Spanish translation of every input\n",
"\n",
"Input: Spot can run.\n",
"Output: Spot puede correr.\n",
"\n",
"Input: See Spot run.\n",
"Output: Ver correr a Spot.\n",
"\n",
"Input: Spot plays fetch.\n",
"Output: Spot juega a buscar.\n",
"\n",
"Input: Spot can run fast.\n",
"Output:\n"
]
}
],
"source": [
"# You can set a threshold at which examples are excluded.\n",
"# For example, setting threshold equal to 0.0\n",
"# excludes examples with no ngram overlaps with input.\n",
"# Since \"My dog barks.\" has no ngram overlaps with \"Spot can run fast.\"\n",
"# it is excluded.\n",
"example_selector.threshold=0.0\n",
"print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))"
]
},
{
"cell_type": "code",
"execution_count": 87,
"id": "7f8d72f7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Give the Spanish translation of every input\n",
"\n",
"Input: Spot can run.\n",
"Output: Spot puede correr.\n",
"\n",
"Input: Spot plays fetch.\n",
"Output: Spot juega a buscar.\n",
"\n",
"Input: Spot can play fetch.\n",
"Output:\n"
]
}
],
"source": [
"# Setting small nonzero threshold\n",
"example_selector.threshold=0.09\n",
"print(dynamic_prompt.format(sentence=\"Spot can play fetch.\"))"
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "09633aa8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Give the Spanish translation of every input\n",
"\n",
"Input: Spot can play fetch.\n",
"Output:\n"
]
}
],
"source": [
"# Setting threshold greater than 1.0\n",
"example_selector.threshold=1.0+1e-9\n",
"print(dynamic_prompt.format(sentence=\"Spot can play fetch.\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c746d6f4",
"id": "39f30097",
"metadata": {},
"outputs": [],
"source": []

@ -0,0 +1,112 @@
"""Select and order examples based on ngram overlap score (sentence_bleu score).
https://www.nltk.org/_modules/nltk/translate/bleu_score.html
https://aclanthology.org/P02-1040.pdf
"""
from typing import Dict, List
import numpy as np
from pydantic import BaseModel, root_validator
from langchain.prompts.example_selector.base import BaseExampleSelector
from langchain.prompts.prompt import PromptTemplate
def ngram_overlap_score(source: List[str], example: List[str]) -> float:
    """Score the ngram overlap between a source sentence and example sentences.

    The score is the NLTK ``sentence_bleu`` value, computed with the
    ``method1`` smoothing function and ``auto_reweigh`` enabled.

    Args:
        source: List of strings; only the first entry is scored. It is split
            on whitespace and used as the BLEU hypothesis.
        example: List of strings; each entry is split on whitespace and used
            as one BLEU reference.

    Returns:
        The BLEU score converted to a builtin ``float``.

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf
    """
    # Imported lazily so the module can be imported without nltk installed.
    from nltk.translate.bleu_score import (  # type: ignore
        SmoothingFunction,
        sentence_bleu,
    )

    candidate_tokens = source[0].split()
    reference_token_lists = [sentence.split() for sentence in example]

    bleu = sentence_bleu(
        reference_token_lists,
        candidate_tokens,
        smoothing_function=SmoothingFunction().method1,
        auto_reweigh=True,
    )
    return float(bleu)
class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel):
    """Select and order examples based on ngram overlap score (sentence_bleu score).

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf
    """

    examples: List[dict]
    """A list of the examples that the prompt template expects."""

    example_prompt: PromptTemplate
    """Prompt template used to format the examples."""

    threshold: float = -1.0
    """Threshold at which algorithm stops. Set to -1.0 by default.
    For negative threshold:
    select_examples sorts examples by ngram_overlap_score, but excludes none.
    For threshold greater than 1.0:
    select_examples excludes all examples, and returns an empty list.
    For threshold equal to 0.0:
    select_examples sorts examples by ngram_overlap_score,
    and excludes examples with no ngram overlap with input.
    """

    @root_validator(pre=True)
    def check_dependencies(cls, values: Dict) -> Dict:
        """Check that valid dependencies exist.

        Raises:
            ValueError: If the nltk BLEU helpers cannot be imported.
        """
        try:
            from nltk.translate.bleu_score import (  # noqa: F401
                SmoothingFunction,
                sentence_bleu,
            )
        except ImportError as e:
            raise ValueError(
                "Not all the correct dependencies for this ExampleSelect exist"
            ) from e

        return values

    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to list."""
        self.examples.append(example)

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Return list of examples sorted by ngram_overlap_score with input.

        Descending order; examples with equal scores keep their original
        relative order.
        Excludes any examples with ngram_overlap_score less than or equal to
        threshold (compared with a 1e-9 tolerance).

        Args:
            input_variables: Mapping of prompt inputs. Its values are passed
                to ngram_overlap_score, which scores the first one.

        Returns:
            The selected examples, best match first. Empty when there are no
            stored examples or none scores above the threshold.
        """
        inputs = list(input_variables.values())
        # Only the first input variable of the example prompt is compared.
        first_prompt_template_key = self.example_prompt.input_variables[0]

        scores = [
            ngram_overlap_score(inputs, [example[first_prompt_template_key]])
            for example in self.examples
        ]

        selected: List[dict] = []
        # A stable sort on the negated scores yields descending order while
        # keeping ties in insertion order, and copes with an empty list.
        ranking = np.argsort(-np.asarray(scores, dtype=float), kind="stable")
        for index in ranking:
            score = scores[index]
            if score < self.threshold or abs(score - self.threshold) < 1e-9:
                # Scores are descending, so every remaining one fails as well.
                break
            selected.append(self.examples[index])

        return selected

@ -0,0 +1,73 @@
"""Test functionality related to ngram overlap based selector."""
import pytest
from langchain.prompts.example_selector.ngram_overlap import (
NGramOverlapExampleSelector,
ngram_overlap_score,
)
from langchain.prompts.prompt import PromptTemplate
# Shared example pool for the selector tests below; the "output" values are
# placeholder strings that the tests only use to tell the examples apart.
EXAMPLES = [
{"input": "See Spot run.", "output": "foo1"},
{"input": "My dog barks.", "output": "foo2"},
{"input": "Spot can run.", "output": "foo3"},
]
@pytest.fixture
def selector() -> NGramOverlapExampleSelector:
    """Build an NGramOverlapExampleSelector over EXAMPLES with no explicit threshold."""
    example_prompt = PromptTemplate(
        input_variables=["input", "output"],
        template="Input: {input}\nOutput: {output}",
    )
    return NGramOverlapExampleSelector(
        examples=EXAMPLES,
        example_prompt=example_prompt,
    )
def test_selector_valid(selector: NGramOverlapExampleSelector) -> None:
    """Check that select_examples returns all examples in the expected order."""
    query = {"input": "Spot can run."}
    expected = [EXAMPLES[2], EXAMPLES[0], EXAMPLES[1]]
    assert selector.select_examples(query) == expected
def test_selector_add_example(selector: NGramOverlapExampleSelector) -> None:
    """Check that an example added via add_example takes part in selection."""
    extra = {"input": "Spot plays fetch.", "output": "foo4"}
    selector.add_example(extra)

    ranked = selector.select_examples({"input": "Spot can run."})

    expected = [EXAMPLES[2], EXAMPLES[0]] + [extra] + [EXAMPLES[1]]
    assert ranked == expected
def test_selector_threshold_zero(selector: NGramOverlapExampleSelector) -> None:
    """Check that a 0.0 threshold leaves only EXAMPLES[2] and EXAMPLES[0]."""
    selector.threshold = 0.0
    query = {"input": "Spot can run."}
    ranked = selector.select_examples(query)
    assert ranked == [EXAMPLES[2], EXAMPLES[0]]
def test_selector_threshold_more_than_one(
    selector: NGramOverlapExampleSelector,
) -> None:
    """Check that a threshold just above 1.0 yields an empty selection."""
    selector.threshold = 1.0 + 1e-9
    query = {"input": "Spot can run."}
    ranked = selector.select_examples(query)
    assert ranked == []
def test_ngram_overlap_score(selector: NGramOverlapExampleSelector) -> None:
    """Tests that ngram_overlap_score returns correct values.

    Each expectation is asserted separately so a failure reports which case
    broke. The selector argument is kept for signature compatibility; the
    score function is called directly and does not use it.
    """
    none = ngram_overlap_score(["Spot can run."], ["My dog barks."])
    some = ngram_overlap_score(["Spot can run."], ["See Spot run."])
    complete = ngram_overlap_score(["Spot can run."], ["Spot can run."])

    # Sentences with no shared ngrams score (numerically) zero.
    assert abs(none - 0.0) < 1e-9
    # Partial overlap lands strictly between the two extremes.
    assert 0.0 < some < 1.0
    # An identical sentence scores (numerically) one.
    assert abs(complete - 1.0) < 1e-9
Loading…
Cancel
Save