Add add_example method to all ExampleSelector classes, with tests (#178)

Also updated docs, and noticed an issue with the add_texts method on
VectorStores that I had missed before -- the metadatas arg should be
required to match the classmethod which initializes the VectorStores
(the add_example methods break otherwise in the ExampleSelectors)
harrison/flexible_model_args
Samantha Whitmore 2 years ago committed by GitHub
parent 780ef84cf0
commit 09f301cd38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -310,7 +310,7 @@
" example_prompt=example_prompt, \n",
" # This is the maximum length that the formatted examples should be.\n",
" # Length is measured by the get_text_length function below.\n",
" max_length=18,\n",
" max_length=25,\n",
" # This is the function used to get the length of a string, which is used\n",
" # to determine which examples to include. It is commented out because\n",
" # it is provided as a default value if none is specified.\n",
@ -378,17 +378,59 @@
"Input: happy\n",
"Output: sad\n",
"\n",
"Input: big and huge and massive and large and gigantic and tall and bigger than everything else\n",
"Input: big and huge and massive and large and gigantic and tall and much much much much much bigger than everything else\n",
"Output:\n"
]
}
],
"source": [
"# An example with long input, so it selects only one example.\n",
"long_string = \"big and huge and massive and large and gigantic and tall and bigger than everything else\"\n",
"long_string = \"big and huge and massive and large and gigantic and tall and much much much much much bigger than everything else\"\n",
"print(dynamic_prompt.format(adjective=long_string))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e4bebcd9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Give the antonym of every input\n",
"\n",
"Input: happy\n",
"Output: sad\n",
"\n",
"Input: tall\n",
"Output: short\n",
"\n",
"Input: energetic\n",
"Output: lethargic\n",
"\n",
"Input: sunny\n",
"Output: gloomy\n",
"\n",
"Input: windy\n",
"Output: calm\n",
"\n",
"Input: big\n",
"Output: small\n",
"\n",
"Input: enthusiastic\n",
"Output:\n"
]
}
],
"source": [
"# You can add an example to an example selector as well.\n",
"new_example = {\"input\": \"big\", \"output\": \"small\"}\n",
"dynamic_prompt.example_selector.add_example(new_example)\n",
"print(dynamic_prompt.format(adjective=\"enthusiastic\"))"
]
},
{
"cell_type": "markdown",
"id": "2d007b0a",
@ -401,7 +443,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "241bfe80",
"metadata": {},
"outputs": [],
@ -413,7 +455,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"id": "50d0a701",
"metadata": {},
"outputs": [],
@ -440,7 +482,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"id": "4c8fdf45",
"metadata": {},
"outputs": [
@ -465,9 +507,11 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"id": "829af21a",
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
@ -484,10 +528,36 @@
}
],
"source": [
"# Input is a measurment, so should select the tall/short example\n",
"# Input is a measurement, so should select the tall/short example\n",
"print(similar_prompt.format(adjective=\"fat\"))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "3c16fe23",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Give the antonym of every input\n",
"\n",
"Input: enthusiastic\n",
"Output: apathetic\n",
"\n",
"Input: joyful\n",
"Output:\n"
]
}
],
"source": [
"# You can add new examples to the SemanticSimilarityExampleSelector as well\n",
"similar_prompt.example_selector.add_example({\"input\": \"enthusiastic\", \"output\": \"apathetic\"})\n",
"print(similar_prompt.format(adjective=\"joyful\"))"
]
},
{
"cell_type": "markdown",
"id": "dbc32551",
@ -532,7 +602,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.10.4"
}
},
"nbformat": 4,

@ -6,6 +6,10 @@ from typing import Dict, List
class BaseExampleSelector(ABC):
    """Interface for selecting examples to include in prompts.

    Concrete selectors must support both mutating the example store
    (``add_example``) and querying it (``select_examples``).
    """

    @abstractmethod
    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to store for a key."""

    @abstractmethod
    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Select which examples to use based on the inputs."""

@ -25,6 +25,12 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
example_text_lengths: List[int] = [] #: :meta private:
def add_example(self, example: Dict[str, str]) -> None:
    """Append a new example and cache its formatted text length.

    The length is computed once here (via ``self.get_text_length`` on the
    prompt-formatted example) so selection never has to re-measure it.
    """
    formatted = self.example_prompt.format(**example)
    self.examples.append(example)
    self.example_text_lengths.append(self.get_text_length(formatted))
@validator("example_text_lengths", always=True)
def calculate_example_text_lengths(cls, v: List[int], values: Dict) -> List[int]:
"""Calculate text lengths if they don't exist."""

@ -24,6 +24,11 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
extra = Extra.forbid
arbitrary_types_allowed = True
def add_example(self, example: Dict[str, str]) -> None:
    """Embed a new example into the backing vectorstore.

    The example's values are joined into a single space-separated string
    for embedding, and the full example dict is stored as that text's
    metadata so it can be recovered on retrieval.
    """
    text = " ".join(example.values())
    self.vectorstore.add_texts([text], metadatas=[example])
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on semantic similarity."""
# Get the docs with the highest similarity.

@ -10,7 +10,9 @@ class VectorStore(ABC):
"""Interface for vector stores."""
@abstractmethod
def add_texts(self, texts: Iterable[str]) -> None:
def add_texts(
self, texts: Iterable[str], metadatas: Optional[List[dict]] = None
) -> None:
"""Run more texts through the embeddings and add to the vectorstore."""
@abstractmethod

@ -65,7 +65,9 @@ class ElasticVectorSearch(VectorStore):
)
self.client = es_client
def add_texts(self, texts: Iterable[str]) -> None:
def add_texts(
self, texts: Iterable[str], metadatas: Optional[List[dict]] = None
) -> None:
"""Run more texts through the embeddings and add to the vectorstore."""
try:
from elasticsearch.helpers import bulk
@ -76,11 +78,13 @@ class ElasticVectorSearch(VectorStore):
)
requests = []
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
request = {
"_op_type": "index",
"_index": self.index_name,
"vector": self.embedding_function(text),
"text": text,
"metadata": metadata,
}
requests.append(request)
bulk(self.client, requests)

@ -37,7 +37,9 @@ class FAISS(VectorStore):
self.docstore = docstore
self.index_to_docstore_id = index_to_docstore_id
def add_texts(self, texts: Iterable[str]) -> None:
def add_texts(
self, texts: Iterable[str], metadatas: Optional[List[dict]] = None
) -> None:
"""Run more texts through the embeddings and add to the vectorstore."""
if not isinstance(self.docstore, AddableMixin):
raise ValueError(
@ -46,7 +48,10 @@ class FAISS(VectorStore):
)
# Embed and create the documents.
embeddings = [self.embedding_function(text) for text in texts]
documents = [Document(page_content=text) for text in texts]
documents = []
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
documents.append(Document(page_content=text, metadata=metadata))
# Add to the index, the index_to_id mapping, and the docstore.
starting_len = len(self.index_to_docstore_id)
self.index.add(np.array(embeddings, dtype=np.float32))

@ -29,6 +29,15 @@ def test_dynamic_prompt_valid(selector: LengthBasedExampleSelector) -> None:
assert output == EXAMPLES
def test_dynamic_prompt_add_example(selector: LengthBasedExampleSelector) -> None:
    """Test dynamic prompt can add an example."""
    new_example = {"question": "Question: what are you?\nAnswer: bar"}
    selector.add_example(new_example)
    # A short input leaves room for everything, so the newly added example
    # must appear after all the originals.
    output = selector.select_examples({"question": "Short question?"})
    assert output == EXAMPLES + [new_example]
def test_dynamic_prompt_trims_one_example(selector: LengthBasedExampleSelector) -> None:
"""Test dynamic prompt can trim one example."""
long_question = """I am writing a really long question,

Loading…
Cancel
Save