forked from Archives/langchain
Merge branch 'ankush/retry-openai' into ankush/async-llm
commit
738bf977ab
@ -0,0 +1,70 @@
|
||||
"""Wrapper around TensorflowHub embedding models."""
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
DEFAULT_MODEL_URL = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
|
||||
|
||||
|
||||
class TensorflowHubEmbeddings(BaseModel, Embeddings):
    """Wrapper around tensorflow_hub embedding models.

    To use, you should have the ``tensorflow_text`` python package installed.

    Example:
        .. code-block:: python

            from langchain.embeddings import TensorflowHubEmbeddings

            url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
            tf = TensorflowHubEmbeddings(model_url=url)
    """

    embed: Any  #: :meta private:
    model_url: str = DEFAULT_MODEL_URL
    """Model name to use."""

    def __init__(self, **kwargs: Any):
        """Initialize the tensorflow_hub and tensorflow_text."""
        super().__init__(**kwargs)
        try:
            import tensorflow_hub
            import tensorflow_text  # noqa

            self.embed = tensorflow_hub.load(self.model_url)
        except ImportError as e:
            # Fix: the two literals were previously concatenated without a
            # separating space ("packages.Please").
            raise ValueError(
                "Could not import some python packages. Please install them."
            ) from e

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a TensorflowHub embedding model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # Newlines can skew sentence-encoder results; flatten to spaces.
        texts = list(map(lambda x: x.replace("\n", " "), texts))
        embeddings = self.embed(texts).numpy()
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a TensorflowHub embedding model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        text = text.replace("\n", " ")
        # Fix: the TF-Hub encoder expects a batch (list) of strings, not a
        # bare string; wrap the query and take the first (only) row.
        embedding = self.embed([text]).numpy()[0]
        return embedding.tolist()
|
@ -0,0 +1,112 @@
|
||||
"""Select and order examples based on ngram overlap score (sentence_bleu score).
|
||||
|
||||
https://www.nltk.org/_modules/nltk/translate/bleu_score.html
|
||||
https://aclanthology.org/P02-1040.pdf
|
||||
"""
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.prompts.example_selector.base import BaseExampleSelector
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
def ngram_overlap_score(source: List[str], example: List[str]) -> float:
    """Compute ngram overlap score of source and example as sentence_bleu score.

    Use sentence_bleu with method1 smoothing function and auto reweighting.
    Return float value between 0.0 and 1.0 inclusive.
    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf
    """
    from nltk.translate.bleu_score import (  # type: ignore
        SmoothingFunction,
        sentence_bleu,
    )

    # Whitespace-tokenize: the first source string is the hypothesis and each
    # example string is a reference sentence.
    hypothesis_tokens = source[0].split()
    reference_tokens = [sentence.split() for sentence in example]

    bleu = sentence_bleu(
        reference_tokens,
        hypothesis_tokens,
        smoothing_function=SmoothingFunction().method1,
        auto_reweigh=True,
    )
    return float(bleu)
|
||||
|
||||
|
||||
class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel):
    """Select and order examples based on ngram overlap score (sentence_bleu score).

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf
    """

    examples: List[dict]
    """A list of the examples that the prompt template expects."""

    example_prompt: PromptTemplate
    """Prompt template used to format the examples."""

    threshold: float = -1.0
    """Threshold at which algorithm stops. Set to -1.0 by default.

    For negative threshold:
        select_examples sorts examples by ngram_overlap_score, but excludes none.
    For threshold greater than 1.0:
        select_examples excludes all examples, and returns an empty list.
    For threshold equal to 0.0:
        select_examples sorts examples by ngram_overlap_score,
        and excludes examples with no ngram overlap with input.
    """

    @root_validator(pre=True)
    def check_dependencies(cls, values: Dict) -> Dict:
        """Check that valid dependencies exist."""
        try:
            from nltk.translate.bleu_score import (  # noqa: disable=F401
                SmoothingFunction,
                sentence_bleu,
            )
        except ImportError as e:
            # Fix: message previously read "ExampleSelect".
            raise ValueError(
                "Not all the correct dependencies for this ExampleSelector exist"
            ) from e

        return values

    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to list."""
        self.examples.append(example)

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Return list of examples sorted by ngram_overlap_score with input.

        Descending order.
        Excludes any examples with ngram_overlap_score less than or equal to threshold.
        """
        inputs = list(input_variables.values())
        examples: List[dict] = []
        k = len(self.examples)
        # Fix: np.argmax on an empty score list raises ValueError; with no
        # examples there is nothing to select.
        if k == 0:
            return examples
        score = [0.0] * k
        first_prompt_template_key = self.example_prompt.input_variables[0]

        for i in range(k):
            score[i] = ngram_overlap_score(
                inputs, [self.examples[i][first_prompt_template_key]]
            )

        while True:
            arg_max = np.argmax(score)
            # Stop once the best remaining score is below, or within 1e-9 of,
            # the threshold (float-tolerant "<=").
            if (score[arg_max] < self.threshold) or abs(
                score[arg_max] - self.threshold
            ) < 1e-9:
                break

            examples.append(self.examples[arg_max])
            # Sentinel below threshold so argmax never picks this entry again.
            score[arg_max] = self.threshold - 1.0

        return examples
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,19 @@
|
||||
"""Test TensorflowHub embeddings."""
|
||||
from langchain.embeddings import TensorflowHubEmbeddings
|
||||
|
||||
|
||||
def test_tensorflowhub_embedding_documents() -> None:
    """Test tensorflowhub embeddings."""
    docs = ["foo bar"]
    embedder = TensorflowHubEmbeddings()
    result = embedder.embed_documents(docs)
    # Expect one 512-dimensional vector per input document.
    assert len(result) == 1
    assert len(result[0]) == 512
|
||||
|
||||
|
||||
def test_tensorflowhub_embedding_query() -> None:
    """Test tensorflowhub embeddings."""
    query = "foo bar"
    embedder = TensorflowHubEmbeddings()
    result = embedder.embed_query(query)
    # A single query embeds to one 512-dimensional vector.
    assert len(result) == 512
|
@ -0,0 +1,73 @@
|
||||
"""Test functionality related to ngram overlap based selector."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.prompts.example_selector.ngram_overlap import (
|
||||
NGramOverlapExampleSelector,
|
||||
ngram_overlap_score,
|
||||
)
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
EXAMPLES = [
|
||||
{"input": "See Spot run.", "output": "foo1"},
|
||||
{"input": "My dog barks.", "output": "foo2"},
|
||||
{"input": "Spot can run.", "output": "foo3"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
def selector() -> NGramOverlapExampleSelector:
    """Get ngram overlap based selector to use in tests."""
    example_prompt = PromptTemplate(
        input_variables=["input", "output"],
        template="Input: {input}\nOutput: {output}",
    )
    return NGramOverlapExampleSelector(
        examples=EXAMPLES,
        example_prompt=example_prompt,
    )
|
||||
|
||||
|
||||
def test_selector_valid(selector: NGramOverlapExampleSelector) -> None:
    """Test NGramOverlapExampleSelector can select examples."""
    result = selector.select_examples({"input": "Spot can run."})
    # Sorted by descending overlap with the input sentence.
    expected = [EXAMPLES[2], EXAMPLES[0], EXAMPLES[1]]
    assert result == expected
|
||||
|
||||
|
||||
def test_selector_add_example(selector: NGramOverlapExampleSelector) -> None:
    """Test NGramOverlapExampleSelector can add an example."""
    extra = {"input": "Spot plays fetch.", "output": "foo4"}
    selector.add_example(extra)
    result = selector.select_examples({"input": "Spot can run."})
    # New example slots in by its overlap score, ahead of the weakest match.
    assert result == [EXAMPLES[2], EXAMPLES[0], extra, EXAMPLES[1]]
|
||||
|
||||
|
||||
def test_selector_threshold_zero(selector: NGramOverlapExampleSelector) -> None:
    """Tests NGramOverlapExampleSelector threshold set to 0.0."""
    selector.threshold = 0.0
    result = selector.select_examples({"input": "Spot can run."})
    # Examples with zero ngram overlap are excluded at threshold 0.0.
    assert result == [EXAMPLES[2], EXAMPLES[0]]
|
||||
|
||||
|
||||
def test_selector_threshold_more_than_one(
    selector: NGramOverlapExampleSelector,
) -> None:
    """Tests NGramOverlapExampleSelector threshold greater than 1.0."""
    selector.threshold = 1.0 + 1e-9
    result = selector.select_examples({"input": "Spot can run."})
    # No score can exceed 1.0, so every example is excluded.
    assert result == []
|
||||
|
||||
|
||||
def test_ngram_overlap_score(selector: NGramOverlapExampleSelector) -> None:
    """Tests that ngram_overlap_score returns correct values."""
    # Fix: dropped the `selector.threshold = 1.0 + 1e-9` mutation copied from
    # the previous test — ngram_overlap_score never consults the selector, so
    # that setup was dead code. (Fixture parameter kept for signature
    # compatibility with the shared fixture.)
    none = ngram_overlap_score(["Spot can run."], ["My dog barks."])
    some = ngram_overlap_score(["Spot can run."], ["See Spot run."])
    complete = ngram_overlap_score(["Spot can run."], ["Spot can run."])

    # No overlap -> 0.0; partial overlap -> strictly between 0 and 1;
    # identical sentences -> 1.0 (float-tolerant comparisons).
    check = [abs(none - 0.0) < 1e-9, 0.0 < some < 1.0, abs(complete - 1.0) < 1e-9]
    assert check == [True, True, True]
|
Loading…
Reference in New Issue