From c96ac3e5910087920a5b716c2d09260e128bdae2 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Wed, 15 Feb 2023 23:06:48 -0800
Subject: [PATCH] Harrison/semantic subset (#1079)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Chen Wu (吴尘)
---
 .../example_selector/semantic_similarity.py | 40 ++++++++++++++++---
 1 file changed, 35 insertions(+), 5 deletions(-)

diff --git a/langchain/prompts/example_selector/semantic_similarity.py b/langchain/prompts/example_selector/semantic_similarity.py
index cf02d995..e4c8e553 100644
--- a/langchain/prompts/example_selector/semantic_similarity.py
+++ b/langchain/prompts/example_selector/semantic_similarity.py
@@ -24,6 +24,9 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
     """Number of examples to select."""
     example_keys: Optional[List[str]] = None
     """Optional keys to filter examples to."""
+    input_keys: Optional[List[str]] = None
+    """Optional keys to filter input to. If provided, the search is based on
+    the input variables instead of all variables."""

     class Config:
         """Configuration for this pydantic object."""
@@ -33,13 +36,20 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):

     def add_example(self, example: Dict[str, str]) -> str:
         """Add new example to vectorstore."""
-        string_example = " ".join(sorted_values(example))
+        if self.input_keys:
+            string_example = " ".join(
+                sorted_values({key: example[key] for key in self.input_keys})
+            )
+        else:
+            string_example = " ".join(sorted_values(example))
         ids = self.vectorstore.add_texts([string_example], metadatas=[example])
         return ids[0]

     def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
         """Select which examples to use based on semantic similarity."""
         # Get the docs with the highest similarity.
+        if self.input_keys:
+            input_variables = {key: input_variables[key] for key in self.input_keys}
         query = " ".join(sorted_values(input_variables))
         example_docs = self.vectorstore.similarity_search(query, k=self.k)
         # Get the examples from the metadata.
@@ -57,6 +67,7 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
         embeddings: Embeddings,
         vectorstore_cls: VectorStore,
         k: int = 4,
+        input_keys: Optional[List[str]] = None,
         **vectorstore_cls_kwargs: Any,
     ) -> SemanticSimilarityExampleSelector:
         """Create k-shot example selector using example list and embeddings.
@@ -68,16 +79,24 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
             embeddings: An iniialized embedding API interface, e.g. OpenAIEmbeddings().
             vectorstore_cls: A vector store DB interface class, e.g. FAISS.
             k: Number of examples to select
+            input_keys: If provided, the search is based on the input variables
+                instead of all variables.
             vectorstore_cls_kwargs: optional kwargs containing url for vector store

         Returns:
             The ExampleSelector instantiated, backed by a vector store.
         """
-        string_examples = [" ".join(sorted_values(eg)) for eg in examples]
+        if input_keys:
+            string_examples = [
+                " ".join(sorted_values({k: eg[k] for k in input_keys}))
+                for eg in examples
+            ]
+        else:
+            string_examples = [" ".join(sorted_values(eg)) for eg in examples]
         vectorstore = vectorstore_cls.from_texts(
             string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
         )
-        return cls(vectorstore=vectorstore, k=k)
+        return cls(vectorstore=vectorstore, k=k, input_keys=input_keys)


 class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector, BaseModel):
@@ -93,6 +112,8 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector, Bas
     def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
         """Select which examples to use based on semantic similarity."""
         # Get the docs with the highest similarity.
+        if self.input_keys:
+            input_variables = {key: input_variables[key] for key in self.input_keys}
         query = " ".join(sorted_values(input_variables))
         example_docs = self.vectorstore.max_marginal_relevance_search(
             query, k=self.k, fetch_k=self.fetch_k
@@ -112,6 +133,7 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector, Bas
         embeddings: Embeddings,
         vectorstore_cls: VectorStore,
         k: int = 4,
+        input_keys: Optional[List[str]] = None,
         fetch_k: int = 20,
         **vectorstore_cls_kwargs: Any,
     ) -> MaxMarginalRelevanceExampleSelector:
@@ -124,13 +146,21 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector, Bas
             embeddings: An iniialized embedding API interface, e.g. OpenAIEmbeddings().
             vectorstore_cls: A vector store DB interface class, e.g. FAISS.
             k: Number of examples to select
+            input_keys: If provided, the search is based on the input variables
+                instead of all variables.
             vectorstore_cls_kwargs: optional kwargs containing url for vector store

         Returns:
             The ExampleSelector instantiated, backed by a vector store.
         """
-        string_examples = [" ".join(sorted_values(eg)) for eg in examples]
+        if input_keys:
+            string_examples = [
+                " ".join(sorted_values({k: eg[k] for k in input_keys}))
+                for eg in examples
+            ]
+        else:
+            string_examples = [" ".join(sorted_values(eg)) for eg in examples]
         vectorstore = vectorstore_cls.from_texts(
             string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
         )
-        return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k)
+        return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k, input_keys=input_keys)