Harrison/semantic subset (#1079)

Co-authored-by: Chen Wu (吴尘) <henrychenwu@cmu.edu>
searx-api
Harrison Chase 1 year ago committed by GitHub
parent 19c2797bed
commit c96ac3e591
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -24,6 +24,9 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
"""Number of examples to select."""
example_keys: Optional[List[str]] = None
"""Optional keys to filter examples to."""
input_keys: Optional[List[str]] = None
"""Optional keys to filter input to. If provided, the search is based on
the input variables instead of all variables."""
class Config:
"""Configuration for this pydantic object."""
@ -33,13 +36,20 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
def add_example(self, example: Dict[str, str]) -> str:
"""Add new example to vectorstore."""
string_example = " ".join(sorted_values(example))
if self.input_keys:
string_example = " ".join(
sorted_values({key: example[key] for key in self.input_keys})
)
else:
string_example = " ".join(sorted_values(example))
ids = self.vectorstore.add_texts([string_example], metadatas=[example])
return ids[0]
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on semantic similarity."""
# Get the docs with the highest similarity.
if self.input_keys:
input_variables = {key: input_variables[key] for key in self.input_keys}
query = " ".join(sorted_values(input_variables))
example_docs = self.vectorstore.similarity_search(query, k=self.k)
# Get the examples from the metadata.
@ -57,6 +67,7 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
embeddings: Embeddings,
vectorstore_cls: VectorStore,
k: int = 4,
input_keys: Optional[List[str]] = None,
**vectorstore_cls_kwargs: Any,
) -> SemanticSimilarityExampleSelector:
"""Create k-shot example selector using example list and embeddings.
@ -68,16 +79,24 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
embeddings: An iniialized embedding API interface, e.g. OpenAIEmbeddings().
vectorstore_cls: A vector store DB interface class, e.g. FAISS.
k: Number of examples to select
input_keys: If provided, the search is based on the input variables
instead of all variables.
vectorstore_cls_kwargs: optional kwargs containing url for vector store
Returns:
The ExampleSelector instantiated, backed by a vector store.
"""
string_examples = [" ".join(sorted_values(eg)) for eg in examples]
if input_keys:
string_examples = [
" ".join(sorted_values({k: eg[k] for k in input_keys}))
for eg in examples
]
else:
string_examples = [" ".join(sorted_values(eg)) for eg in examples]
vectorstore = vectorstore_cls.from_texts(
string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
)
return cls(vectorstore=vectorstore, k=k)
return cls(vectorstore=vectorstore, k=k, input_keys=input_keys)
class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector, BaseModel):
@ -93,6 +112,8 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector, Bas
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on semantic similarity."""
# Get the docs with the highest similarity.
if self.input_keys:
input_variables = {key: input_variables[key] for key in self.input_keys}
query = " ".join(sorted_values(input_variables))
example_docs = self.vectorstore.max_marginal_relevance_search(
query, k=self.k, fetch_k=self.fetch_k
@ -112,6 +133,7 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector, Bas
embeddings: Embeddings,
vectorstore_cls: VectorStore,
k: int = 4,
input_keys: Optional[List[str]] = None,
fetch_k: int = 20,
**vectorstore_cls_kwargs: Any,
) -> MaxMarginalRelevanceExampleSelector:
@ -124,13 +146,21 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector, Bas
embeddings: An iniialized embedding API interface, e.g. OpenAIEmbeddings().
vectorstore_cls: A vector store DB interface class, e.g. FAISS.
k: Number of examples to select
input_keys: If provided, the search is based on the input variables
instead of all variables.
vectorstore_cls_kwargs: optional kwargs containing url for vector store
Returns:
The ExampleSelector instantiated, backed by a vector store.
"""
string_examples = [" ".join(sorted_values(eg)) for eg in examples]
if input_keys:
string_examples = [
" ".join(sorted_values({k: eg[k] for k in input_keys}))
for eg in examples
]
else:
string_examples = [" ".join(sorted_values(eg)) for eg in examples]
vectorstore = vectorstore_cls.from_texts(
string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
)
return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k)
return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k, input_keys=input_keys)

Loading…
Cancel
Save