forked from Archives/langchain
Add metadata filter to PGVector search (#1872)
Add ability to filter pgvector documents by metadata.
This commit is contained in:
parent d3d4503ce2
commit f155d9d3ec
@@ -78,7 +78,7 @@ class EmbeddingStore(BaseModel):
|
|||||||
|
|
||||||
embedding: Vector = sqlalchemy.Column(Vector(ADA_TOKEN_COUNT))
|
embedding: Vector = sqlalchemy.Column(Vector(ADA_TOKEN_COUNT))
|
||||||
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||||
cmetadata = sqlalchemy.Column(sqlalchemy.JSON, nullable=True)
|
cmetadata = sqlalchemy.Column(JSON, nullable=True)
|
||||||
|
|
||||||
# custom_id : any user defined id
|
# custom_id : any user defined id
|
||||||
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||||
@@ -233,6 +233,7 @@ class PGVector(VectorStore):
|
|||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
|
filter: Optional[dict] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Run similarity search with PGVector with distance.
|
"""Run similarity search with PGVector with distance.
|
||||||
@@ -249,40 +250,58 @@ class PGVector(VectorStore):
|
|||||||
return self.similarity_search_by_vector(
|
return self.similarity_search_by_vector(
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
k=k,
|
k=k,
|
||||||
|
filter=filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
def similarity_search_with_score(
|
def similarity_search_with_score(
|
||||||
self, query: str, k: int = 4
|
self,
|
||||||
|
query: str,
|
||||||
|
k: int = 4,
|
||||||
|
filter: Optional[dict] = None,
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[Document, float]]:
|
||||||
"""Return docs most similar to query.
|
"""Return docs most similar to query.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: Text to look up documents similar to.
|
query: Text to look up documents similar to.
|
||||||
k: Number of Documents to return. Defaults to 4.
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of Documents most similar to the query and score for each
|
List of Documents most similar to the query and score for each
|
||||||
"""
|
"""
|
||||||
embedding = self.embedding_function.embed_query(query)
|
embedding = self.embedding_function.embed_query(query)
|
||||||
docs = self.similarity_search_with_score_by_vector(embedding, k)
|
docs = self.similarity_search_with_score_by_vector(
|
||||||
|
embedding=embedding, k=k, filter=filter
|
||||||
|
)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def similarity_search_with_score_by_vector(
|
def similarity_search_with_score_by_vector(
|
||||||
self,
|
self,
|
||||||
embedding: List[float],
|
embedding: List[float],
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
|
filter: Optional[dict] = None,
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[Document, float]]:
|
||||||
with Session(self._conn) as session:
|
with Session(self._conn) as session:
|
||||||
collection = self.get_collection(session)
|
collection = self.get_collection(session)
|
||||||
if not collection:
|
if not collection:
|
||||||
raise ValueError("Collection not found")
|
raise ValueError("Collection not found")
|
||||||
|
|
||||||
|
filter_by = EmbeddingStore.collection_id == collection.uuid
|
||||||
|
|
||||||
|
if filter is not None:
|
||||||
|
filter_clauses = []
|
||||||
|
for key, value in filter.items():
|
||||||
|
filter_by_metadata = EmbeddingStore.cmetadata[key].astext == str(value)
|
||||||
|
filter_clauses.append(filter_by_metadata)
|
||||||
|
|
||||||
|
filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
|
||||||
|
|
||||||
results: List[QueryResult] = (
|
results: List[QueryResult] = (
|
||||||
session.query(
|
session.query(
|
||||||
EmbeddingStore,
|
EmbeddingStore,
|
||||||
self.distance_strategy(embedding).label("distance"), # type: ignore
|
self.distance_strategy(embedding).label("distance"), # type: ignore
|
||||||
)
|
)
|
||||||
.filter(EmbeddingStore.collection_id == collection.uuid)
|
.filter(filter_by)
|
||||||
.order_by(sqlalchemy.asc("distance"))
|
.order_by(sqlalchemy.asc("distance"))
|
||||||
.join(
|
.join(
|
||||||
CollectionStore,
|
CollectionStore,
|
||||||
@@ -307,6 +326,7 @@ class PGVector(VectorStore):
|
|||||||
self,
|
self,
|
||||||
embedding: List[float],
|
embedding: List[float],
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
|
filter: Optional[dict] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Return docs most similar to embedding vector.
|
"""Return docs most similar to embedding vector.
|
||||||
@@ -314,12 +334,13 @@ class PGVector(VectorStore):
|
|||||||
Args:
|
Args:
|
||||||
embedding: Embedding to look up documents similar to.
|
embedding: Embedding to look up documents similar to.
|
||||||
k: Number of Documents to return. Defaults to 4.
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of Documents most similar to the query vector.
|
List of Documents most similar to the query vector.
|
||||||
"""
|
"""
|
||||||
docs_and_scores = self.similarity_search_with_score_by_vector(
|
docs_and_scores = self.similarity_search_with_score_by_vector(
|
||||||
embedding=embedding, k=k
|
embedding=embedding, k=k, filter=filter
|
||||||
)
|
)
|
||||||
return [doc for doc, _ in docs_and_scores]
|
return [doc for doc, _ in docs_and_scores]
|
||||||
|
|
||||||
|
@@ -83,6 +83,56 @@ def test_pgvector_with_metadatas_with_scores() -> None:
|
|||||||
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
|
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pgvector_with_filter_match() -> None:
|
||||||
|
"""Test end to end construction and search."""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
||||||
|
docsearch = PGVector.from_texts(
|
||||||
|
texts=texts,
|
||||||
|
collection_name="test_collection_filter",
|
||||||
|
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||||
|
metadatas=metadatas,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"})
|
||||||
|
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pgvector_with_filter_distant_match() -> None:
|
||||||
|
"""Test end to end construction and search."""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
||||||
|
docsearch = PGVector.from_texts(
|
||||||
|
texts=texts,
|
||||||
|
collection_name="test_collection_filter",
|
||||||
|
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||||
|
metadatas=metadatas,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"})
|
||||||
|
assert output == [
|
||||||
|
(Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pgvector_with_filter_no_match() -> None:
|
||||||
|
"""Test end to end construction and search."""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
||||||
|
docsearch = PGVector.from_texts(
|
||||||
|
texts=texts,
|
||||||
|
collection_name="test_collection_filter",
|
||||||
|
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||||
|
metadatas=metadatas,
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"})
|
||||||
|
assert output == []
|
||||||
|
|
||||||
|
|
||||||
def test_pgvector_collection_with_metadata() -> None:
|
def test_pgvector_collection_with_metadata() -> None:
|
||||||
"""Test end to end collection construction"""
|
"""Test end to end collection construction"""
|
||||||
pgvector = PGVector(
|
pgvector = PGVector(
|
||||||
|
Loading…
Reference in New Issue
Block a user