diff --git a/libs/community/langchain_community/example_selectors/__init__.py b/libs/community/langchain_community/example_selectors/__init__.py new file mode 100644 index 0000000000..70654d689b --- /dev/null +++ b/libs/community/langchain_community/example_selectors/__init__.py @@ -0,0 +1,10 @@ +"""Logic for selecting examples to include in prompts.""" +from langchain_community.example_selectors.ngram_overlap import ( + NGramOverlapExampleSelector, + ngram_overlap_score, +) + +__all__ = [ + "NGramOverlapExampleSelector", + "ngram_overlap_score", +] diff --git a/libs/community/langchain_community/example_selectors/ngram_overlap.py b/libs/community/langchain_community/example_selectors/ngram_overlap.py new file mode 100644 index 0000000000..c8d662f8e1 --- /dev/null +++ b/libs/community/langchain_community/example_selectors/ngram_overlap.py @@ -0,0 +1,114 @@ +"""Select and order examples based on ngram overlap score (sentence_bleu score). + +https://www.nltk.org/_modules/nltk/translate/bleu_score.html +https://aclanthology.org/P02-1040.pdf +""" +from typing import Dict, List + +import numpy as np +from langchain_core.example_selectors import BaseExampleSelector +from langchain_core.prompts import PromptTemplate +from langchain_core.pydantic_v1 import BaseModel, root_validator + + +def ngram_overlap_score(source: List[str], example: List[str]) -> float: + """Compute ngram overlap score of source and example as sentence_bleu score + from NLTK package. + + Use sentence_bleu with method1 smoothing function and auto reweighting. + Return float value between 0.0 and 1.0 inclusive. + https://www.nltk.org/_modules/nltk/translate/bleu_score.html + https://aclanthology.org/P02-1040.pdf + """ + from nltk.translate.bleu_score import ( + SmoothingFunction, # type: ignore + sentence_bleu, + ) + + hypotheses = source[0].split() + references = [s.split() for s in example] + + return float( + sentence_bleu( + references, + hypotheses, + smoothing_function=SmoothingFunction().method1, + auto_reweigh=True, + ) + ) + + +class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel): + """Select and order examples based on ngram overlap score (sentence_bleu score + from NLTK package). + + https://www.nltk.org/_modules/nltk/translate/bleu_score.html + https://aclanthology.org/P02-1040.pdf + """ + + examples: List[dict] + """A list of the examples that the prompt template expects.""" + + example_prompt: PromptTemplate + """Prompt template used to format the examples.""" + + threshold: float = -1.0 + """Threshold at which algorithm stops. Set to -1.0 by default. + + For negative threshold: + select_examples sorts examples by ngram_overlap_score, but excludes none. + For threshold greater than 1.0: + select_examples excludes all examples, and returns an empty list. + For threshold equal to 0.0: + select_examples sorts examples by ngram_overlap_score, + and excludes examples with no ngram overlap with input. + """ + + @root_validator(pre=True) + def check_dependencies(cls, values: Dict) -> Dict: + """Check that valid dependencies exist.""" + try: + from nltk.translate.bleu_score import ( # noqa: F401 + SmoothingFunction, + sentence_bleu, + ) + except ImportError as e: + raise ImportError( + "Not all the correct dependencies for this ExampleSelect exist." + "Please install nltk with `pip install nltk`." + ) from e + + return values + + def add_example(self, example: Dict[str, str]) -> None: + """Add new example to list.""" + self.examples.append(example) + + def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: + """Return list of examples sorted by ngram_overlap_score with input. + + Descending order. + Excludes any examples with ngram_overlap_score less than or equal to threshold. + """ + inputs = list(input_variables.values()) + examples = [] + k = len(self.examples) + score = [0.0] * k + first_prompt_template_key = self.example_prompt.input_variables[0] + + for i in range(k): + score[i] = ngram_overlap_score( + inputs, [self.examples[i][first_prompt_template_key]] + ) + + while True: + arg_max = np.argmax(score) + if (score[arg_max] < self.threshold) or abs( + score[arg_max] - self.threshold + ) < 1e-9: + break + + examples.append(self.examples[arg_max]) + score[arg_max] = self.threshold - 1.0 + + return examples diff --git a/libs/langchain/langchain/prompts/example_selector/ngram_overlap.py b/libs/langchain/langchain/prompts/example_selector/ngram_overlap.py index c439c946ed..db1be277e4 100644 --- a/libs/langchain/langchain/prompts/example_selector/ngram_overlap.py +++ b/libs/langchain/langchain/prompts/example_selector/ngram_overlap.py @@ -1,112 +1,9 @@ -"""Select and order examples based on ngram overlap score (sentence_bleu score). - -https://www.nltk.org/_modules/nltk/translate/bleu_score.html -https://aclanthology.org/P02-1040.pdf -""" -from typing import Dict, List - -import numpy as np -from langchain_core.example_selectors.base import BaseExampleSelector -from langchain_core.prompts import PromptTemplate -from langchain_core.pydantic_v1 import BaseModel, root_validator - - -def ngram_overlap_score(source: List[str], example: List[str]) -> float: - """Compute ngram overlap score of source and example as sentence_bleu score. - - Use sentence_bleu with method1 smoothing function and auto reweighting. - Return float value between 0.0 and 1.0 inclusive. - https://www.nltk.org/_modules/nltk/translate/bleu_score.html - https://aclanthology.org/P02-1040.pdf - """ - from nltk.translate.bleu_score import ( - SmoothingFunction, # type: ignore - sentence_bleu, - ) - - hypotheses = source[0].split() - references = [s.split() for s in example] - - return float( - sentence_bleu( - references, - hypotheses, - smoothing_function=SmoothingFunction().method1, - auto_reweigh=True, - ) - ) - - -class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel): - """Select and order examples based on ngram overlap score (sentence_bleu score). - - https://www.nltk.org/_modules/nltk/translate/bleu_score.html - https://aclanthology.org/P02-1040.pdf - """ - - examples: List[dict] - """A list of the examples that the prompt template expects.""" - - example_prompt: PromptTemplate - """Prompt template used to format the examples.""" - - threshold: float = -1.0 - """Threshold at which algorithm stops. Set to -1.0 by default. - - For negative threshold: - select_examples sorts examples by ngram_overlap_score, but excludes none. - For threshold greater than 1.0: - select_examples excludes all examples, and returns an empty list. - For threshold equal to 0.0: - select_examples sorts examples by ngram_overlap_score, - and excludes examples with no ngram overlap with input. - """ - - @root_validator(pre=True) - def check_dependencies(cls, values: Dict) -> Dict: - """Check that valid dependencies exist.""" - try: - from nltk.translate.bleu_score import ( # noqa: F401 - SmoothingFunction, - sentence_bleu, - ) - except ImportError as e: - raise ImportError( - "Not all the correct dependencies for this ExampleSelect exist." - "Please install nltk with `pip install nltk`." - ) from e - - return values - - def add_example(self, example: Dict[str, str]) -> None: - """Add new example to list.""" - self.examples.append(example) - - def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: - """Return list of examples sorted by ngram_overlap_score with input. - - Descending order. - Excludes any examples with ngram_overlap_score less than or equal to threshold. - """ - inputs = list(input_variables.values()) - examples = [] - k = len(self.examples) - score = [0.0] * k - first_prompt_template_key = self.example_prompt.input_variables[0] - - for i in range(k): - score[i] = ngram_overlap_score( - inputs, [self.examples[i][first_prompt_template_key]] - ) - - while True: - arg_max = np.argmax(score) - if (score[arg_max] < self.threshold) or abs( - score[arg_max] - self.threshold - ) < 1e-9: - break - - examples.append(self.examples[arg_max]) - score[arg_max] = self.threshold - 1.0 - - return examples +from langchain_community.example_selectors.ngram_overlap import ( + NGramOverlapExampleSelector, + ngram_overlap_score, +) + +__all__ = [ + "NGramOverlapExampleSelector", + "ngram_overlap_score", +]