langchain/libs/community/langchain_community/example_selectors/ngram_overlap.py
Leonid Ganeline c2ca6612fe
refactor langchain.prompts.example_selector (#15369)
The `langchain.prompts.example_selector` [still holds several
artifacts](https://api.python.langchain.com/en/latest/langchain_api_reference.html#module-langchain.prompts)
that belong to `community`. If they moved to
`langchain_community.example_selectors`, the `langchain.prompts`
namespace would be effectively removed which is great.
- moved a class and a function to `langchain_community`

Note:
- Previously, the `langchain.prompts.example_selector` artifacts were
moved into the `langchain_core.example_selectors`. See the flattened
namespace (`.prompts` was removed)!
Similar flattening was implemented for the `langchain_core` as the
`langchain_core.example_selectors`.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2024-02-01 12:05:57 -08:00

115 lines
3.8 KiB
Python

"""Select and order examples based on ngram overlap score (sentence_bleu score).
https://www.nltk.org/_modules/nltk/translate/bleu_score.html
https://aclanthology.org/P02-1040.pdf
"""
from typing import Dict, List
import numpy as np
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, root_validator
def ngram_overlap_score(source: List[str], example: List[str]) -> float:
    """Compute ngram overlap score of source and example as sentence_bleu score
    from NLTK package.

    Use sentence_bleu with method1 smoothing function and auto reweighting.
    Only the first element of ``source`` is scored (as the hypothesis); every
    element of ``example`` is used as a reference. Both sides are tokenized by
    splitting on whitespace.

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf

    Args:
        source: Input strings; only ``source[0]`` is compared.
        example: Reference strings the input is compared against.

    Returns:
        Float value between 0.0 and 1.0 inclusive. 0.0 is returned when
        ``source`` or ``example`` is empty, since there is nothing to overlap.
    """
    # Guard before touching source[0]: an empty source would raise IndexError,
    # and an empty reference list cannot be scored by sentence_bleu.
    if not source or not example:
        return 0.0

    from nltk.translate.bleu_score import (
        SmoothingFunction,  # type: ignore
        sentence_bleu,
    )

    hypotheses = source[0].split()
    references = [s.split() for s in example]

    return float(
        sentence_bleu(
            references,
            hypotheses,
            smoothing_function=SmoothingFunction().method1,
            auto_reweigh=True,
        )
    )
class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel):
    """Select and order examples based on ngram overlap score (sentence_bleu score
    from NLTK package).

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf
    """

    examples: List[dict]
    """A list of the examples that the prompt template expects."""

    example_prompt: PromptTemplate
    """Prompt template used to format the examples."""

    threshold: float = -1.0
    """Threshold at which algorithm stops. Set to -1.0 by default.

    For negative threshold:
    select_examples sorts examples by ngram_overlap_score, but excludes none.
    For threshold greater than or equal to 1.0:
    select_examples excludes all examples, and returns an empty list.
    For threshold equal to 0.0:
    select_examples sorts examples by ngram_overlap_score,
    and excludes examples with no ngram overlap with input.
    """

    @root_validator(pre=True)
    def check_dependencies(cls, values: Dict) -> Dict:
        """Check that valid dependencies exist.

        Raises:
            ImportError: If the NLTK BLEU utilities cannot be imported.
        """
        try:
            from nltk.translate.bleu_score import (  # noqa: F401
                SmoothingFunction,
                sentence_bleu,
            )
        except ImportError as e:
            # Trailing space keeps the two sentences from running together.
            raise ImportError(
                "Not all the correct dependencies for this ExampleSelect exist. "
                "Please install nltk with `pip install nltk`."
            ) from e

        return values

    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to list."""
        self.examples.append(example)

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Return list of examples sorted by ngram_overlap_score with input.

        Descending order; examples with equal scores keep their original
        relative order.
        Excludes any examples with ngram_overlap_score less than or equal to
        threshold (equality is checked with an absolute tolerance of 1e-9).
        Each example is scored on the value stored under the first input
        variable of ``example_prompt``.

        Args:
            input_variables: Input values; they are passed to
                ngram_overlap_score in dict order.

        Returns:
            The selected examples, best match first. Empty when there are no
            examples or none scores above the threshold.
        """
        # Nothing to rank; also avoids taking an arg-max of an empty sequence.
        if not self.examples:
            return []

        inputs = list(input_variables.values())
        first_prompt_template_key = self.example_prompt.input_variables[0]
        tolerance = 1e-9

        scored = [
            (
                ngram_overlap_score(inputs, [example[first_prompt_template_key]]),
                example,
            )
            for example in self.examples
        ]
        # list.sort is stable even with reverse=True, so ties stay in their
        # original order.
        scored.sort(key=lambda pair: pair[0], reverse=True)

        selected = []
        for score, example in scored:
            # Scores are descending, so the first one at or below the
            # threshold means every remaining one is excluded as well.
            if score < self.threshold or abs(score - self.threshold) < tolerance:
                break
            selected.append(example)
        return selected