diff --git a/docs/examples/prompts/prompt_management.ipynb b/docs/examples/prompts/prompt_management.ipynb index 96c20e20df..6d11f6df5e 100644 --- a/docs/examples/prompts/prompt_management.ipynb +++ b/docs/examples/prompts/prompt_management.ipynb @@ -310,7 +310,7 @@ " example_prompt=example_prompt, \n", " # This is the maximum length that the formatted examples should be.\n", " # Length is measured by the get_text_length function below.\n", - " max_length=18,\n", + " max_length=25,\n", " # This is the function used to get the length of a string, which is used\n", " # to determine which examples to include. It is commented out because\n", " # it is provided as a default value if none is specified.\n", @@ -378,17 +378,59 @@ "Input: happy\n", "Output: sad\n", "\n", - "Input: big and huge and massive and large and gigantic and tall and bigger than everything else\n", + "Input: big and huge and massive and large and gigantic and tall and much much much much much bigger than everything else\n", "Output:\n" ] } ], "source": [ "# An example with long input, so it selects only one example.\n", - "long_string = \"big and huge and massive and large and gigantic and tall and bigger than everything else\"\n", + "long_string = \"big and huge and massive and large and gigantic and tall and much much much much much bigger than everything else\"\n", "print(dynamic_prompt.format(adjective=long_string))" ] }, + { + "cell_type": "code", + "execution_count": 13, + "id": "e4bebcd9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Give the antonym of every input\n", + "\n", + "Input: happy\n", + "Output: sad\n", + "\n", + "Input: tall\n", + "Output: short\n", + "\n", + "Input: energetic\n", + "Output: lethargic\n", + "\n", + "Input: sunny\n", + "Output: gloomy\n", + "\n", + "Input: windy\n", + "Output: calm\n", + "\n", + "Input: big\n", + "Output: small\n", + "\n", + "Input: enthusiastic\n", + "Output:\n" + ] + } + ], + "source": [ + "# You can add an example to an example selector as well.\n", + "new_example = {\"input\": \"big\", \"output\": \"small\"}\n", + "dynamic_prompt.example_selector.add_example(new_example)\n", + "print(dynamic_prompt.format(adjective=\"enthusiastic\"))" + ] + }, { "cell_type": "markdown", "id": "2d007b0a", @@ -401,7 +443,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "241bfe80", "metadata": {}, "outputs": [], @@ -413,7 +455,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "id": "50d0a701", "metadata": {}, "outputs": [], @@ -440,7 +482,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "id": "4c8fdf45", "metadata": {}, "outputs": [ @@ -465,9 +507,11 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "id": "829af21a", - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [ { "name": "stdout", @@ -484,10 +528,36 @@ } ], "source": [ - "# Input is a measurment, so should select the tall/short example\n", + "# Input is a measurement, so should select the tall/short example\n", "print(similar_prompt.format(adjective=\"fat\"))" ] }, + { + "cell_type": "code", + "execution_count": 18, + "id": "3c16fe23", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Give the antonym of every input\n", + "\n", + "Input: enthusiastic\n", + "Output: apathetic\n", + "\n", + "Input: joyful\n", + "Output:\n" + ] + } + ], + "source": [ + "# You can add new examples to the SemanticSimilarityExampleSelector as well\n", + "similar_prompt.example_selector.add_example({\"input\": \"enthusiastic\", \"output\": \"apathetic\"})\n", + "print(similar_prompt.format(adjective=\"joyful\"))" + ] + }, { "cell_type": "markdown", "id": "dbc32551", @@ -532,7 +602,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.10.4" } }, "nbformat": 4, diff --git a/langchain/prompts/example_selector/base.py b/langchain/prompts/example_selector/base.py index 9af9e307fe..00b91cb428 100644 --- a/langchain/prompts/example_selector/base.py +++ b/langchain/prompts/example_selector/base.py @@ -6,6 +6,10 @@ from typing import Dict, List class BaseExampleSelector(ABC): """Interface for selecting examples to include in prompts.""" + @abstractmethod + def add_example(self, example: Dict[str, str]) -> None: + """Add new example to store for a key.""" + @abstractmethod def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: """Select which examples to use based on the inputs.""" diff --git a/langchain/prompts/example_selector/length_based.py b/langchain/prompts/example_selector/length_based.py index ae13e884a9..086b72ce55 100644 --- a/langchain/prompts/example_selector/length_based.py +++ b/langchain/prompts/example_selector/length_based.py @@ -25,6 +25,12 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel): example_text_lengths: List[int] = [] #: :meta private: + def add_example(self, example: Dict[str, str]) -> None: + """Add new example to list.""" + self.examples.append(example) + string_example = self.example_prompt.format(**example) + self.example_text_lengths.append(self.get_text_length(string_example)) + @validator("example_text_lengths", always=True) def calculate_example_text_lengths(cls, v: List[int], values: Dict) -> List[int]: """Calculate text lengths if they don't exist.""" diff --git a/langchain/prompts/example_selector/semantic_similarity.py b/langchain/prompts/example_selector/semantic_similarity.py index 82c087dfa6..499bd9fc7a 100644 --- a/langchain/prompts/example_selector/semantic_similarity.py +++ b/langchain/prompts/example_selector/semantic_similarity.py @@ -24,6 +24,11 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel): extra = Extra.forbid arbitrary_types_allowed = True + def add_example(self, example: Dict[str, str]) -> None: + """Add new example to vectorstore.""" + string_example = " ".join(example.values()) + self.vectorstore.add_texts([string_example], metadatas=[example]) + def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: """Select which examples to use based on semantic similarity.""" # Get the docs with the highest similarity. diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py index 8c9b171c49..066c9a01d8 100644 --- a/langchain/vectorstores/base.py +++ b/langchain/vectorstores/base.py @@ -10,7 +10,9 @@ class VectorStore(ABC): """Interface for vector stores.""" @abstractmethod - def add_texts(self, texts: Iterable[str]) -> None: + def add_texts( + self, texts: Iterable[str], metadatas: Optional[List[dict]] = None + ) -> None: """Run more texts through the embeddings and add to the vectorstore.""" @abstractmethod diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index 9194636405..8620c559fc 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -65,7 +65,9 @@ class ElasticVectorSearch(VectorStore): ) self.client = es_client - def add_texts(self, texts: Iterable[str]) -> None: + def add_texts( + self, texts: Iterable[str], metadatas: Optional[List[dict]] = None + ) -> None: """Run more texts through the embeddings and add to the vectorstore.""" try: from elasticsearch.helpers import bulk @@ -76,11 +78,13 @@ class ElasticVectorSearch(VectorStore): ) requests = [] for i, text in enumerate(texts): + metadata = metadatas[i] if metadatas else {} request = { "_op_type": "index", "_index": self.index_name, "vector": self.embedding_function(text), "text": text, + "metadata": metadata, } requests.append(request) bulk(self.client, requests) diff --git a/langchain/vectorstores/faiss.py b/langchain/vectorstores/faiss.py index 2b3e4d611d..7d26ad12ab 100644 --- a/langchain/vectorstores/faiss.py +++ b/langchain/vectorstores/faiss.py @@ -37,7 +37,9 @@ class FAISS(VectorStore): self.docstore = docstore self.index_to_docstore_id = index_to_docstore_id - def add_texts(self, texts: Iterable[str]) -> None: + def add_texts( + self, texts: Iterable[str], metadatas: Optional[List[dict]] = None + ) -> None: """Run more texts through the embeddings and add to the vectorstore.""" if not isinstance(self.docstore, AddableMixin): raise ValueError( @@ -46,7 +48,10 @@ class FAISS(VectorStore): ) # Embed and create the documents. embeddings = [self.embedding_function(text) for text in texts] - documents = [Document(page_content=text) for text in texts] + documents = [] + for i, text in enumerate(texts): + metadata = metadatas[i] if metadatas else {} + documents.append(Document(page_content=text, metadata=metadata)) # Add to the index, the index_to_id mapping, and the docstore. starting_len = len(self.index_to_docstore_id) self.index.add(np.array(embeddings, dtype=np.float32)) diff --git a/tests/unit_tests/prompts/test_length_based_example_selector.py b/tests/unit_tests/prompts/test_length_based_example_selector.py index 7f9ca3cd7a..def2413230 100644 --- a/tests/unit_tests/prompts/test_length_based_example_selector.py +++ b/tests/unit_tests/prompts/test_length_based_example_selector.py @@ -29,6 +29,15 @@ def test_dynamic_prompt_valid(selector: LengthBasedExampleSelector) -> None: assert output == EXAMPLES +def test_dynamic_prompt_add_example(selector: LengthBasedExampleSelector) -> None: + """Test dynamic prompt can add an example.""" + new_example = {"question": "Question: what are you?\nAnswer: bar"} + selector.add_example(new_example) + short_question = "Short question?" + output = selector.select_examples({"question": short_question}) + assert output == EXAMPLES + [new_example] + + def test_dynamic_prompt_trims_one_example(selector: LengthBasedExampleSelector) -> None: """Test dynamic prompt can trim one example.""" long_question = """I am writing a really long question,