Add metadata filter to PGVector search (#1872)

Add ability to filter pgvector documents by metadata.
Maurício Maia 2023-03-22 19:21:40 -03:00 committed by GitHub
parent d3d4503ce2
commit f155d9d3ec
2 changed files with 99 additions and 28 deletions
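
For reference, a minimal usage sketch of the new filter argument (the connection string and embedding model below are illustrative placeholders, not part of this commit):

# A minimal sketch, assuming a pgvector-enabled Postgres database;
# the connection string and embedding model are placeholder assumptions.
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.pgvector import PGVector

db = PGVector.from_texts(
    texts=["foo", "bar", "baz"],
    metadatas=[{"page": "0"}, {"page": "1"}, {"page": "2"}],
    collection_name="my_collection",
    embedding=OpenAIEmbeddings(),
    connection_string="postgresql+psycopg2://user:pass@localhost:5432/dbname",
)

# Only documents whose metadata has page == "1" are considered.
docs_and_scores = db.similarity_search_with_score("foo", k=4, filter={"page": "1"})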

View File

@@ -78,7 +78,7 @@ class EmbeddingStore(BaseModel):
     embedding: Vector = sqlalchemy.Column(Vector(ADA_TOKEN_COUNT))
     document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
-    cmetadata = sqlalchemy.Column(sqlalchemy.JSON, nullable=True)
+    cmetadata = sqlalchemy.Column(JSON, nullable=True)
     # custom_id : any user defined id
     custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
@@ -233,6 +233,7 @@ class PGVector(VectorStore):
         self,
         query: str,
         k: int = 4,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Run similarity search with PGVector with distance.
@@ -249,64 +250,83 @@ class PGVector(VectorStore):
         return self.similarity_search_by_vector(
             embedding=embedding,
             k=k,
+            filter=filter,
         )

     def similarity_search_with_score(
-        self, query: str, k: int = 4
+        self,
+        query: str,
+        k: int = 4,
+        filter: Optional[dict] = None,
     ) -> List[Tuple[Document, float]]:
         """Return docs most similar to query.

         Args:
             query: Text to look up documents similar to.
             k: Number of Documents to return. Defaults to 4.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

         Returns:
             List of Documents most similar to the query and score for each
         """
         embedding = self.embedding_function.embed_query(query)
-        docs = self.similarity_search_with_score_by_vector(embedding, k)
+        docs = self.similarity_search_with_score_by_vector(
+            embedding=embedding, k=k, filter=filter
+        )
         return docs

     def similarity_search_with_score_by_vector(
         self,
         embedding: List[float],
         k: int = 4,
+        filter: Optional[dict] = None,
     ) -> List[Tuple[Document, float]]:
         with Session(self._conn) as session:
             collection = self.get_collection(session)
             if not collection:
                 raise ValueError("Collection not found")
+
+            filter_by = EmbeddingStore.collection_id == collection.uuid
+
+            if filter is not None:
+                filter_clauses = []
+                for key, value in filter.items():
+                    filter_by_metadata = EmbeddingStore.cmetadata[key].astext == str(value)
+                    filter_clauses.append(filter_by_metadata)
+
+                filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
+
             results: List[QueryResult] = (
                 session.query(
                     EmbeddingStore,
                     self.distance_strategy(embedding).label("distance"),  # type: ignore
                 )
-                .filter(EmbeddingStore.collection_id == collection.uuid)
+                .filter(filter_by)
                 .order_by(sqlalchemy.asc("distance"))
                 .join(
                     CollectionStore,
                     EmbeddingStore.collection_id == CollectionStore.uuid,
                 )
                 .limit(k)
                 .all()
             )
         docs = [
             (
                 Document(
                     page_content=result.EmbeddingStore.document,
                     metadata=result.EmbeddingStore.cmetadata,
                 ),
                 result.distance if self.embedding_function is not None else None,
             )
             for result in results
         ]
         return docs

     def similarity_search_by_vector(
         self,
         embedding: List[float],
         k: int = 4,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Return docs most similar to embedding vector.
@@ -314,12 +334,13 @@ class PGVector(VectorStore):
         Args:
             embedding: Embedding to look up documents similar to.
             k: Number of Documents to return. Defaults to 4.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

         Returns:
             List of Documents most similar to the query vector.
         """
         docs_and_scores = self.similarity_search_with_score_by_vector(
-            embedding=embedding, k=k
+            embedding=embedding, k=k, filter=filter
         )
         return [doc for doc, _ in docs_and_scores]
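
The per-key comparison above uses SQLAlchemy's Postgres JSON accessor (cmetadata[key].astext), and every clause is ANDed with the collection constraint. A standalone sketch of how such a comparison compiles to SQL, using an illustrative table that is not part of this commit:

# Standalone sketch: how a JSON .astext comparison compiles for Postgres.
# Table and column names here are illustrative assumptions.
import sqlalchemy
from sqlalchemy.dialects import postgresql

meta = sqlalchemy.MetaData()
embeddings = sqlalchemy.Table(
    "embeddings",
    meta,
    sqlalchemy.Column("cmetadata", postgresql.JSON),
)

clause = embeddings.c.cmetadata["page"].astext == "0"
print(
    clause.compile(
        dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}
    )
)
# Expected to render roughly as: embeddings.cmetadata ->> 'page' = '0'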

View File

@@ -83,6 +83,56 @@ def test_pgvector_with_metadatas_with_scores() -> None:
     assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]


+def test_pgvector_with_filter_match() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"})
+    assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
+
+
+def test_pgvector_with_filter_distant_match() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"})
+    assert output == [
+        (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406)
+    ]
+
+
+def test_pgvector_with_filter_no_match() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"})
+    assert output == []
+
+
 def test_pgvector_collection_with_metadata() -> None:
     """Test end to end collection construction"""
     pgvector = PGVector(
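
These tests rely on a FakeEmbeddingsWithAdaDimension helper defined elsewhere in the test module. One plausible shape for such a deterministic stub, shown only as an assumption for readers of this diff (not the project's actual helper):

# Hypothetical deterministic embedding stub with ada's 1536 dimensions;
# the exact vectors are an assumption, not taken from the test suite.
from typing import List

ADA_TOKEN_COUNT = 1536


class FakeEmbeddingsWithAdaDimension:
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # the i-th text differs from the others only in its last component
        return [[1.0] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))]

    def embed_query(self, text: str) -> List[float]:
        # every query maps to the same vector as the first document
        return [1.0] * (ADA_TOKEN_COUNT - 1) + [0.0]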