diff --git a/langchain/vectorstores/pgvector.py b/langchain/vectorstores/pgvector.py index 9c3d70d778..941a9378cb 100644 --- a/langchain/vectorstores/pgvector.py +++ b/langchain/vectorstores/pgvector.py @@ -78,7 +78,7 @@ class EmbeddingStore(BaseModel): embedding: Vector = sqlalchemy.Column(Vector(ADA_TOKEN_COUNT)) document = sqlalchemy.Column(sqlalchemy.String, nullable=True) - cmetadata = sqlalchemy.Column(sqlalchemy.JSON, nullable=True) + cmetadata = sqlalchemy.Column(JSON, nullable=True) # custom_id : any user defined id custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) @@ -233,6 +233,7 @@ class PGVector(VectorStore): self, query: str, k: int = 4, + filter: Optional[dict] = None, **kwargs: Any, ) -> List[Document]: """Run similarity search with PGVector with distance. @@ -249,64 +250,83 @@ class PGVector(VectorStore): return self.similarity_search_by_vector( embedding=embedding, k=k, + filter=filter, ) def similarity_search_with_score( - self, query: str, k: int = 4 + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: """Return docs most similar to query. Args: query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. Returns: List of Documents most similar to the query and score for each """ embedding = self.embedding_function.embed_query(query) - docs = self.similarity_search_with_score_by_vector(embedding, k) + docs = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, filter=filter + ) return docs def similarity_search_with_score_by_vector( self, embedding: List[float], k: int = 4, + filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: with Session(self._conn) as session: collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") - results: List[QueryResult] = ( - session.query( - EmbeddingStore, - self.distance_strategy(embedding).label("distance"), # type: ignore - ) - .filter(EmbeddingStore.collection_id == collection.uuid) - .order_by(sqlalchemy.asc("distance")) - .join( - CollectionStore, - EmbeddingStore.collection_id == CollectionStore.uuid, - ) - .limit(k) - .all() + filter_by = EmbeddingStore.collection_id == collection.uuid + + if filter is not None: + filter_clauses = [] + for key, value in filter.items(): + filter_by_metadata = EmbeddingStore.cmetadata[key].astext == str(value) + filter_clauses.append(filter_by_metadata) + + filter_by = sqlalchemy.and_(filter_by, *filter_clauses) + + results: List[QueryResult] = ( + session.query( + EmbeddingStore, + self.distance_strategy(embedding).label("distance"), # type: ignore ) - docs = [ - ( - Document( - page_content=result.EmbeddingStore.document, - metadata=result.EmbeddingStore.cmetadata, - ), - result.distance if self.embedding_function is not None else None, - ) - for result in results - ] - return docs + .filter(filter_by) + .order_by(sqlalchemy.asc("distance")) + .join( + CollectionStore, + EmbeddingStore.collection_id == CollectionStore.uuid, + ) + .limit(k) + .all() + ) + docs = [ + ( + Document( + page_content=result.EmbeddingStore.document, + metadata=result.EmbeddingStore.cmetadata, + ), + result.distance if self.embedding_function is not None else None, + ) + for result in results + ] + return docs def similarity_search_by_vector( self, embedding: List[float], k: int = 4, + filter: Optional[dict] = None, **kwargs: Any, ) -> List[Document]: """Return docs most similar to embedding vector. @@ -314,12 +334,13 @@ class PGVector(VectorStore): Args: embedding: Embedding to look up documents similar to. k: Number of Documents to return. Defaults to 4. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. Returns: List of Documents most similar to the query vector. """ docs_and_scores = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k + embedding=embedding, k=k, filter=filter ) return [doc for doc, _ in docs_and_scores] diff --git a/tests/integration_tests/vectorstores/test_pgvector.py b/tests/integration_tests/vectorstores/test_pgvector.py index 5479a1f5a8..023d04d9ec 100644 --- a/tests/integration_tests/vectorstores/test_pgvector.py +++ b/tests/integration_tests/vectorstores/test_pgvector.py @@ -83,6 +83,56 @@ def test_pgvector_with_metadatas_with_scores() -> None: assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] +def test_pgvector_with_filter_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = PGVector.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"}) + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] + + +def test_pgvector_with_filter_distant_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = PGVector.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"}) + assert output == [ + (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) + ] + + +def test_pgvector_with_filter_no_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = PGVector.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"}) + assert output == [] + + def test_pgvector_collection_with_metadata() -> None: """Test end to end collection construction""" pgvector = PGVector(