From 6cdca4355d095e39813b003c68786a38e4cbda5b Mon Sep 17 00:00:00 2001
From: Eugene Yurtsev
Date: Thu, 14 Mar 2024 16:56:00 -0400
Subject: [PATCH] community[minor]: Revamp PGVector Filtering (#18992)

This PR makes the following updates in the pgvector database:

1. Use JSONB field for metadata instead of JSON
2. Update operator syntax to include the required `$` prefix before the
   operators (otherwise there will be name collisions with fields)
3. The change is non-breaking: old functionality is still the default,
   but it will emit a deprecation warning
4. Previous functionality has bugs associated with comparisons due to
   casting to text (so lexical ordering is used incorrectly for numeric
   fields)
5. Adds a GIN index on the JSONB field for more efficient querying
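
Illustrative example of the new syntax (hypothetical `store` constructed
with `use_jsonb=True`; not part of this diff):

    # Old JSON-era syntax (deprecated):    {"count": {"gt": 1}}
    # New JSONB syntax (note the $ prefix):
    store.similarity_search(
        "query",
        k=4,
        filter={"$and": [{"count": {"$gte": 2}}, {"name": {"$in": ["adam", "bob"]}}]},
    )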
---
 .../vectorstores/pgvector.py                  | 403 ++++++++++++++++--
 .../fixtures/filtering_test_cases.py          | 222 ++++++++++
 .../vectorstores/test_pgvector.py             | 271 +++++++++++-
 3 files changed, 851 insertions(+), 45 deletions(-)
 create mode 100644 libs/community/tests/integration_tests/vectorstores/fixtures/filtering_test_cases.py

diff --git a/libs/community/langchain_community/vectorstores/pgvector.py b/libs/community/langchain_community/vectorstores/pgvector.py
index 755c72bb4d..fb4bb6bb6a 100644
--- a/libs/community/langchain_community/vectorstores/pgvector.py
+++ b/libs/community/langchain_community/vectorstores/pgvector.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import contextlib
 import enum
+import json
 import logging
 import uuid
 from typing import (
@@ -18,8 +19,9 @@ from typing import (
 
 import numpy as np
 import sqlalchemy
-from sqlalchemy import delete
-from sqlalchemy.dialects.postgresql import JSON, UUID
+from langchain_core._api import warn_deprecated
+from sqlalchemy import SQLColumnExpression, delete, func
+from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID
 from sqlalchemy.orm import Session, relationship
 
 try:
@@ -61,8 +63,39 @@ class BaseModel(Base):
 
 _classes: Any = None
 
+COMPARISONS_TO_NATIVE = {
+    "$eq": "==",
+    "$ne": "!=",
+    "$lt": "<",
+    "$lte": "<=",
+    "$gt": ">",
+    "$gte": ">=",
+}
+
+SPECIAL_CASED_OPERATORS = {
+    "$in",
+    "$nin",
+    "$between",
+}
+
+TEXT_OPERATORS = {
+    "$like",
+    "$ilike",
+}
+
+LOGICAL_OPERATORS = {"$and", "$or"}
+
+SUPPORTED_OPERATORS = (
+    set(COMPARISONS_TO_NATIVE)
+    .union(TEXT_OPERATORS)
+    .union(LOGICAL_OPERATORS)
+    .union(SPECIAL_CASED_OPERATORS)
+)
+
 
-def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any:
+def _get_embedding_collection_store(
+    vector_dimension: Optional[int] = None, *, use_jsonb: bool = True
+) -> Any:
     global _classes
     if _classes is not None:
         return _classes
@@ -111,26 +144,60 @@ def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> A
             created = True
             return collection, created
 
-    class EmbeddingStore(BaseModel):
-        """Embedding store."""
+    if use_jsonb:
+        # A GIN index on the cmetadata field is created via __table_args__ below.
+        class EmbeddingStore(BaseModel):
+            """Embedding store."""
 
-        __tablename__ = "langchain_pg_embedding"
+            __tablename__ = "langchain_pg_embedding"
 
-        collection_id = sqlalchemy.Column(
-            UUID(as_uuid=True),
-            sqlalchemy.ForeignKey(
-                f"{CollectionStore.__tablename__}.uuid",
-                ondelete="CASCADE",
-            ),
-        )
-        collection = relationship(CollectionStore, back_populates="embeddings")
+            collection_id = sqlalchemy.Column(
+                UUID(as_uuid=True),
+                sqlalchemy.ForeignKey(
+                    f"{CollectionStore.__tablename__}.uuid",
+                    ondelete="CASCADE",
+                ),
+            )
+            collection = relationship(CollectionStore, back_populates="embeddings")
+
+            embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
+            document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+            cmetadata = sqlalchemy.Column(JSONB, nullable=True)
+
+            # custom_id : any user defined id
+            custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+
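+            # Note: the jsonb_path_ops opclass below yields a smaller, faster
+            # GIN index than the default jsonb_ops, but it only accelerates
+            # containment (@>) and jsonpath match (@?, @@) operators.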
" + ), + alternative=( + "Instantiate with use_jsonb=True to use JSONB instead " + "of JSON for metadata." + ), + ) self.__post_init__() def __post_init__( @@ -218,7 +314,7 @@ class PGVector(VectorStore): self.create_vector_extension() EmbeddingStore, CollectionStore = _get_embedding_collection_store( - self._embedding_length + self._embedding_length, use_jsonb=self.use_jsonb ) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore @@ -336,6 +432,8 @@ class PGVector(VectorStore): distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, connection_string: Optional[str] = None, pre_delete_collection: bool = False, + *, + use_jsonb: bool = False, **kwargs: Any, ) -> PGVector: if ids is None: @@ -352,6 +450,7 @@ class PGVector(VectorStore): embedding_function=embedding, distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, + use_jsonb=use_jsonb, **kwargs, ) @@ -508,7 +607,117 @@ class PGVector(VectorStore): ] return docs - def _create_filter_clause(self, key, value): # type: ignore[no-untyped-def] + def _handle_field_filter( + self, + field: str, + value: Any, + ) -> SQLColumnExpression: + """Create a filter for a specific field. + + Args: + field: name of field + value: value to filter + If provided as is then this will be an equality filter + If provided as a dictionary then this will be a filter, the key + will be the operator and the value will be the value to filter by + + Returns: + sqlalchemy expression + """ + if not isinstance(field, str): + raise ValueError( + f"field should be a string but got: {type(field)} with value: {field}" + ) + + if field.startswith("$"): + raise ValueError( + f"Invalid filter condition. Expected a field but got an operator: " + f"{field}" + ) + + # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters + if not field.isidentifier(): + raise ValueError( + f"Invalid field name: {field}. Expected a valid identifier." + ) + + if isinstance(value, dict): + # This is a filter specification + if len(value) != 1: + raise ValueError( + "Invalid filter condition. Expected a value which " + "is a dictionary with a single key that corresponds to an operator " + f"but got a dictionary with {len(value)} keys. The first few " + f"keys are: {list(value.keys())[:3]}" + ) + operator, filter_value = list(value.items())[0] + # Verify that that operator is an operator + if operator not in SUPPORTED_OPERATORS: + raise ValueError( + f"Invalid operator: {operator}. 
" + f"Expected one of {SUPPORTED_OPERATORS}" + ) + else: # Then we assume an equality operator + operator = "$eq" + filter_value = value + + if operator in COMPARISONS_TO_NATIVE: + # Then we implement an equality filter + # native is trusted input + native = COMPARISONS_TO_NATIVE[operator] + return func.jsonb_path_match( + self.EmbeddingStore.cmetadata, + f"$.{field} {native} $value", + json.dumps({"value": filter_value}), + ) + elif operator == "$between": + # Use AND with two comparisons + low, high = filter_value + + lower_bound = func.jsonb_path_match( + self.EmbeddingStore.cmetadata, + f"$.{field} >= $value", + json.dumps({"value": low}), + ) + upper_bound = func.jsonb_path_match( + self.EmbeddingStore.cmetadata, + f"$.{field} <= $value", + json.dumps({"value": high}), + ) + return sqlalchemy.and_(lower_bound, upper_bound) + elif operator in {"$in", "$nin", "$like", "$ilike"}: + # We'll do force coercion to text + if operator in {"$in", "$nin"}: + for val in filter_value: + if not isinstance(val, (str, int, float)): + raise NotImplementedError( + f"Unsupported type: {type(val)} for value: {val}" + ) + + queried_field = self.EmbeddingStore.cmetadata[field].astext + + if operator in {"$in"}: + return queried_field.in_([str(val) for val in filter_value]) + elif operator in {"$nin"}: + return queried_field.nin_([str(val) for val in filter_value]) + elif operator in {"$like"}: + return queried_field.like(filter_value) + elif operator in {"$ilike"}: + return queried_field.ilike(filter_value) + else: + raise NotImplementedError() + else: + raise NotImplementedError() + + def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyped-def] + """Deprecated functionality. + + This is for backwards compatibility with the JSON based schema for metadata. + It uses incorrect operator syntax (operators are not prefixed with $). + + This implementation is not efficient, and has bugs associated with + the way that it handles numeric filter clauses. + """ IN, NIN, BETWEEN, GT, LT, NE = "in", "nin", "between", "gt", "lt", "ne" EQ, LIKE, CONTAINS, OR, AND = "eq", "like", "contains", "or", "and" @@ -568,6 +777,117 @@ class PGVector(VectorStore): return filter_by_metadata + def _create_filter_clause_json_deprecated( + self, filter: Any + ) -> List[SQLColumnExpression]: + """Convert filters from IR to SQL clauses. + + **DEPRECATED** This functionality will be deprecated in the future. + + It implements translation of filters for a schema that uses JSON + for metadata rather than the JSONB field which is more efficient + for querying. + """ + filter_clauses = [] + for key, value in filter.items(): + if isinstance(value, dict): + filter_by_metadata = self._create_filter_clause_deprecated(key, value) + + if filter_by_metadata is not None: + filter_clauses.append(filter_by_metadata) + else: + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str( + value + ) + filter_clauses.append(filter_by_metadata) + return filter_clauses + + def _create_filter_clause(self, filters: Any) -> Any: + """Convert LangChain IR filter representation to matching SQLAlchemy clauses. + + At the top level, we still don't know if we're working with a field + or an operator for the keys. After we've determined that we can + call the appropriate logic to handle filter creation. + + Args: + filters: Dictionary of filters to apply to the query. + + Returns: + SQLAlchemy clause to apply to the query. 
+        """
+        if isinstance(filters, dict):
+            if len(filters) == 1:
+                # The only operators allowed at the top level are $and and $or
+                # First check if an operator or a field
+                key, value = list(filters.items())[0]
+                if key.startswith("$"):
+                    # Then it's an operator
+                    if key.lower() not in ["$and", "$or"]:
+                        raise ValueError(
+                            f"Invalid filter condition. Expected $and or $or "
+                            f"but got: {key}"
+                        )
+                else:
+                    # Then it's a field
+                    return self._handle_field_filter(key, filters[key])
+
+                # Here we handle the $and and $or operators
+                if not isinstance(value, list):
+                    raise ValueError(
+                        f"Expected a list, but got {type(value)} for value: {value}"
+                    )
+                if key.lower() == "$and":
+                    and_ = [self._create_filter_clause(el) for el in value]
+                    if len(and_) > 1:
+                        return sqlalchemy.and_(*and_)
+                    elif len(and_) == 1:
+                        return and_[0]
+                    else:
+                        raise ValueError(
+                            "Invalid filter condition. Expected a non-empty list "
+                            "of filters for $and but got an empty list"
+                        )
+                elif key.lower() == "$or":
+                    or_ = [self._create_filter_clause(el) for el in value]
+                    if len(or_) > 1:
+                        return sqlalchemy.or_(*or_)
+                    elif len(or_) == 1:
+                        return or_[0]
+                    else:
+                        raise ValueError(
+                            "Invalid filter condition. Expected a non-empty list "
+                            "of filters for $or but got an empty list"
+                        )
+                else:
+                    raise ValueError(
+                        f"Invalid filter condition. Expected $and or $or "
+                        f"but got: {key}"
+                    )
+            elif len(filters) > 1:
+                # Then all keys have to be fields (they cannot be operators)
+                for key in filters.keys():
+                    if key.startswith("$"):
+                        raise ValueError(
+                            f"Invalid filter condition. Expected a field but got: {key}"
+                        )
+                # These should all be fields and combined using an $and operator
+                and_ = [self._handle_field_filter(k, v) for k, v in filters.items()]
+                if len(and_) > 1:
+                    return sqlalchemy.and_(*and_)
+                elif len(and_) == 1:
+                    return and_[0]
+                else:
+                    raise ValueError(
+                        "Invalid filter condition. Expected a dictionary "
+                        "but got an empty dictionary"
+                    )
+            else:
+                raise ValueError("Got an empty dictionary for filters.")
+        else:
+            raise ValueError(
+                f"Invalid type: Expected a dictionary but got type: {type(filters)}"
+            )
+
     def __query_collection(
         self,
         embedding: List[float],
@@ -580,24 +900,16 @@ class PGVector(VectorStore):
         if not collection:
             raise ValueError("Collection not found")
 
-        filter_by = self.EmbeddingStore.collection_id == collection.uuid
-
-        if filter is not None:
-            filter_clauses = []
-
-            for key, value in filter.items():
-                if isinstance(value, dict):
-                    filter_by_metadata = self._create_filter_clause(key, value)
-
-                    if filter_by_metadata is not None:
-                        filter_clauses.append(filter_by_metadata)
-                else:
-                    filter_by_metadata = self.EmbeddingStore.cmetadata[
-                        key
-                    ].astext == str(value)
-                    filter_clauses.append(filter_by_metadata)
-
-            filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
+        filter_by = [self.EmbeddingStore.collection_id == collection.uuid]
+        if filter:
+            if self.use_jsonb:
+                filter_clauses = self._create_filter_clause(filter)
+                if filter_clauses is not None:
+                    filter_by.append(filter_clauses)
+            else:
+                # Old way of doing things
+                filter_clauses = self._create_filter_clause_json_deprecated(filter)
+                filter_by.extend(filter_clauses)
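+        # Everything collected in filter_by (the collection id check plus any
+        # metadata clauses) is combined with AND by .filter(*filter_by) below.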
 
         _type = self.EmbeddingStore
 
         results = (
             session.query(
                 self.EmbeddingStore,
                 self.distance_strategy(embedding).label("distance"),  # type: ignore
             )
-            .filter(filter_by)
+            .filter(*filter_by)
             .order_by(sqlalchemy.asc("distance"))
             .join(
                 self.CollectionStore,
             )
             .limit(k)
             .all()
         )
+
         return results
 
     def similarity_search_by_vector(
@@ -649,6 +962,8 @@ class PGVector(VectorStore):
         distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
         ids: Optional[List[str]] = None,
         pre_delete_collection: bool = False,
+        *,
+        use_jsonb: bool = False,
         **kwargs: Any,
     ) -> PGVector:
         """
@@ -668,6 +983,7 @@ class PGVector(VectorStore):
             collection_name=collection_name,
             distance_strategy=distance_strategy,
             pre_delete_collection=pre_delete_collection,
+            use_jsonb=use_jsonb,
             **kwargs,
         )
 
@@ -769,6 +1085,8 @@ class PGVector(VectorStore):
         distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
         ids: Optional[List[str]] = None,
         pre_delete_collection: bool = False,
+        *,
+        use_jsonb: bool = False,
         **kwargs: Any,
     ) -> PGVector:
         """
@@ -792,6 +1110,7 @@ class PGVector(VectorStore):
             metadatas=metadatas,
             ids=ids,
             collection_name=collection_name,
+            use_jsonb=use_jsonb,
             **kwargs,
         )
 
diff --git a/libs/community/tests/integration_tests/vectorstores/fixtures/filtering_test_cases.py b/libs/community/tests/integration_tests/vectorstores/fixtures/filtering_test_cases.py
new file mode 100644
index 0000000000..de04ee7eb8
--- /dev/null
+++ b/libs/community/tests/integration_tests/vectorstores/fixtures/filtering_test_cases.py
@@ -0,0 +1,222 @@
+"""Module contains test cases for testing filtering of documents in vector stores."""
+from langchain_core.documents import Document
+
+metadatas = [
+    {
+        "name": "adam",
+        "date": "2021-01-01",
+        "count": 1,
+        "is_active": True,
+        "tags": ["a", "b"],
+        "location": [1.0, 2.0],
+        "info": {"address": "123 main st", "phone": "123-456-7890"},
+        "id": 1,
+        "height": 10.0,  # Float column
+        "happiness": 0.9,  # Float column
+        "sadness": 0.1,  # Float column
+    },
+    {
+        "name": "bob",
+        "date": "2021-01-02",
+        "count": 2,
+        "is_active": False,
+        "tags": ["b", "c"],
+        "location": [2.0, 3.0],
+        "info": {"address": "456 main st", "phone": "123-456-7890"},
+        "id": 2,
+        "height": 5.7,  # Float column
+        "happiness": 0.8,  # Float column
+        "sadness": 0.1,  # Float column
+    },
+    {
+        "name": "jane",
+        "date": "2021-01-01",
+        "count": 3,
+        "is_active": True,
+        "tags": ["b", "d"],
+        "location": [3.0, 4.0],
+        "info": {"address": "789 main st", "phone": "123-456-7890"},
+        "id": 3,
+        "height": 2.4,  # Float column
+        "happiness": None,
+        # Sadness missing intentionally
+    },
+]
+texts = ["id {id}".format(id=metadata["id"]) for metadata in metadatas]
+
+DOCUMENTS = [
+    Document(page_content=text, metadata=metadata)
+    for text, metadata in zip(texts, metadatas)
+]
+
+
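+# Each test case below is a (filter, expected_ids) pair: the filter is run
+# against DOCUMENTS and must return exactly the documents whose metadata
+# "id" values are listed.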
+""" +from langchain_core.documents import Document + +metadatas = [ + { + "name": "adam", + "date": "2021-01-01", + "count": 1, + "is_active": True, + "tags": ["a", "b"], + "location": [1.0, 2.0], + "info": {"address": "123 main st", "phone": "123-456-7890"}, + "id": 1, + "height": 10.0, # Float column + "happiness": 0.9, # Float column + "sadness": 0.1, # Float column + }, + { + "name": "bob", + "date": "2021-01-02", + "count": 2, + "is_active": False, + "tags": ["b", "c"], + "location": [2.0, 3.0], + "info": {"address": "456 main st", "phone": "123-456-7890"}, + "id": 2, + "height": 5.7, # Float column + "happiness": 0.8, # Float column + "sadness": 0.1, # Float column + }, + { + "name": "jane", + "date": "2021-01-01", + "count": 3, + "is_active": True, + "tags": ["b", "d"], + "location": [3.0, 4.0], + "info": {"address": "789 main st", "phone": "123-456-7890"}, + "id": 3, + "height": 2.4, # Float column + "happiness": None, + # Sadness missing intentionally + }, +] +texts = ["id {id}".format(id=metadata["id"]) for metadata in metadatas] + +DOCUMENTS = [ + Document(page_content=text, metadata=metadata) + for text, metadata in zip(texts, metadatas) +] + + +TYPE_1_FILTERING_TEST_CASES = [ + # These tests only involve equality checks + ( + {"id": 1}, + [1], + ), + # String field + ( + # check name + {"name": "adam"}, + [1], + ), + # Boolean fields + ( + {"is_active": True}, + [1, 3], + ), + ( + {"is_active": False}, + [2], + ), + # And semantics for top level filtering + ( + {"id": 1, "is_active": True}, + [1], + ), + ( + {"id": 1, "is_active": False}, + [], + ), +] + +TYPE_2_FILTERING_TEST_CASES = [ + # These involve equality checks and other operators + # like $ne, $gt, $gte, $lt, $lte, $not + ( + {"id": 1}, + [1], + ), + ( + {"id": {"$ne": 1}}, + [2, 3], + ), + ( + {"id": {"$gt": 1}}, + [2, 3], + ), + ( + {"id": {"$gte": 1}}, + [1, 2, 3], + ), + ( + {"id": {"$lt": 1}}, + [], + ), + ( + {"id": {"$lte": 1}}, + [1], + ), + # Repeat all the same tests with name (string column) + ( + {"name": "adam"}, + [1], + ), + ( + {"name": "bob"}, + [2], + ), + ( + {"name": {"$eq": "adam"}}, + [1], + ), + ( + {"name": {"$ne": "adam"}}, + [2, 3], + ), + # And also gt, gte, lt, lte relying on lexicographical ordering + ( + {"name": {"$gt": "jane"}}, + [], + ), + ( + {"name": {"$gte": "jane"}}, + [3], + ), + ( + {"name": {"$lt": "jane"}}, + [1, 2], + ), + ( + {"name": {"$lte": "jane"}}, + [1, 2, 3], + ), + ( + {"is_active": {"$eq": True}}, + [1, 3], + ), + ( + {"is_active": {"$ne": True}}, + [2], + ), + # Test float column. + ( + {"height": {"$gt": 5.0}}, + [1, 2], + ), + ( + {"height": {"$gte": 5.0}}, + [1, 2], + ), + ( + {"height": {"$lt": 5.0}}, + [3], + ), + ( + {"height": {"$lte": 5.8}}, + [2, 3], + ), +] + +TYPE_3_FILTERING_TEST_CASES = [ + # These involve usage of AND and OR operators + ( + {"$or": [{"id": 1}, {"id": 2}]}, + [1, 2], + ), + ( + {"$or": [{"id": 1}, {"name": "bob"}]}, + [1, 2], + ), + ( + {"$and": [{"id": 1}, {"id": 2}]}, + [], + ), + ( + {"$or": [{"id": 1}, {"id": 2}, {"id": 3}]}, + [1, 2, 3], + ), +] + +TYPE_4_FILTERING_TEST_CASES = [ + # These involve special operators like $in, $nin, $between + # Test between + ( + {"id": {"$between": (1, 2)}}, + [1, 2], + ), + ( + {"id": {"$between": (1, 1)}}, + [1], + ), + ( + {"name": {"$in": ["adam", "bob"]}}, + [1, 2], + ), +] + +TYPE_5_FILTERING_TEST_CASES = [ + # These involve special operators like $like, $ilike that + # may be specified to certain databases. 
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES)
+def test_pgvector_with_metadata_filters_1(
+    pgvector: PGVector,
+    test_filter: Dict[str, Any],
+    expected_ids: List[int],
+) -> None:
+    """Test end to end construction and search with metadata filters."""
+    docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+
+
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES)
+def test_pgvector_with_metadata_filters_2(
+    pgvector: PGVector,
+    test_filter: Dict[str, Any],
+    expected_ids: List[int],
+) -> None:
+    """Test end to end construction and search with metadata filters."""
+    docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+
+
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_3_FILTERING_TEST_CASES)
+def test_pgvector_with_metadata_filters_3(
+    pgvector: PGVector,
+    test_filter: Dict[str, Any],
+    expected_ids: List[int],
+) -> None:
+    """Test end to end construction and search with metadata filters."""
+    docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+
+
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES)
+def test_pgvector_with_metadata_filters_4(
+    pgvector: PGVector,
+    test_filter: Dict[str, Any],
+    expected_ids: List[int],
+) -> None:
+    """Test end to end construction and search with metadata filters."""
+    docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+
+
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_5_FILTERING_TEST_CASES)
+def test_pgvector_with_metadata_filters_5(
+    pgvector: PGVector,
+    test_filter: Dict[str, Any],
+    expected_ids: List[int],
+) -> None:
+    """Test end to end construction and search with metadata filters."""
+    docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+
+
+@pytest.mark.parametrize(
+    "invalid_filter",
+    [
+        ["hello"],
+        {
+            "id": 2,
+            "$name": "foo",
+        },
+        {"$or": {}},
+        {"$and": {}},
+        {"$between": {}},
+        {"$eq": {}},
+    ],
+)
+def test_invalid_filters(pgvector: PGVector, invalid_filter: Any) -> None:
+    """Verify that invalid filters raise an error."""
+    with pytest.raises(ValueError):
+        pgvector._create_filter_clause(invalid_filter)
+
+
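+# The "evil code" cases below compile the filter clause with literal_binds
+# and assert on the rendered SQL: malicious-looking field values must end up
+# inside the JSON argument of jsonb_path_match, never as raw SQL.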
+@pytest.mark.parametrize(
+    "filter,compiled",
+    [
+        ({"id 'evil code'": 2}, ValueError),
+        (
+            {"id": "'evil code' == 2"},
+            (
+                "jsonb_path_match(langchain_pg_embedding.cmetadata, "
+                "'$.id == $value', "
+                "'{\"value\": \"''evil code'' == 2\"}')"
+            ),
+        ),
+        (
+            {"name": 'a"b'},
+            (
+                "jsonb_path_match(langchain_pg_embedding.cmetadata, "
+                "'$.name == $value', "
+                '\'{"value": "a\\\\"b"}\')'
+            ),
+        ),
+    ],
+)
+def test_evil_code(
+    pgvector: PGVector, filter: Any, compiled: Union[Type[Exception], str]
+) -> None:
+    """Test evil code."""
+    if isinstance(compiled, str):
+        clause = pgvector._create_filter_clause(filter)
+        compiled_stmt = str(
+            clause.compile(
+                dialect=postgresql.dialect(),
+                compile_kwargs={
+                    # This substitutes the parameters with their actual values
+                    "literal_binds": True
+                },
+            )
+        )
+        assert compiled_stmt == compiled
+    else:
+        with pytest.raises(compiled):
+            pgvector._create_filter_clause(filter)
+
+
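+# Each case below pairs a filter with the exact SQL it should compile to
+# (rendered with literal_binds, so parameter values appear inline).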
+@pytest.mark.parametrize(
+    "filter,compiled",
+    [
+        (
+            {"id": 2},
+            "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
+            "'{\"value\": 2}')",
+        ),
+        (
+            {"id": {"$eq": 2}},
+            (
+                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
+                "'{\"value\": 2}')"
+            ),
+        ),
+        (
+            {"name": "foo"},
+            (
+                "jsonb_path_match(langchain_pg_embedding.cmetadata, "
+                "'$.name == $value', "
+                '\'{"value": "foo"}\')'
+            ),
+        ),
+        (
+            {"id": {"$ne": 2}},
+            (
+                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id != $value', "
+                "'{\"value\": 2}')"
+            ),
+        ),
+        (
+            {"id": {"$gt": 2}},
+            (
+                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id > $value', "
+                "'{\"value\": 2}')"
+            ),
+        ),
+        (
+            {"id": {"$gte": 2}},
+            (
+                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id >= $value', "
+                "'{\"value\": 2}')"
+            ),
+        ),
+        (
+            {"id": {"$lt": 2}},
+            (
+                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id < $value', "
+                "'{\"value\": 2}')"
+            ),
+        ),
+        (
+            {"id": {"$lte": 2}},
+            (
+                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id <= $value', "
+                "'{\"value\": 2}')"
+            ),
+        ),
+        (
+            {"name": {"$ilike": "foo"}},
+            "langchain_pg_embedding.cmetadata ->> 'name' ILIKE 'foo'",
+        ),
+        (
+            {"name": {"$like": "foo"}},
+            "langchain_pg_embedding.cmetadata ->> 'name' LIKE 'foo'",
+        ),
+        (
+            {"$or": [{"id": 1}, {"id": 2}]},
+            # Please note that this might not be super optimized
+            # Another way to phrase the query is as
+            # langchain_pg_embedding.cmetadata @@ '($.id == 1 || $.id == 2)'
+            "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
+            "'{\"value\": 1}') OR jsonb_path_match(langchain_pg_embedding.cmetadata, "
+            "'$.id == $value', '{\"value\": 2}')",
+        ),
+    ],
+)
+def test_pgvector_query_compilation(
+    pgvector: PGVector, filter: Any, compiled: str
+) -> None:
+    """Test translation from IR to SQL."""
+    clause = pgvector._create_filter_clause(filter)
+    compiled_stmt = str(
+        clause.compile(
+            dialect=postgresql.dialect(),
+            compile_kwargs={
+                # This substitutes the parameters with their actual values
+                "literal_binds": True
+            },
+        )
+    )
+    assert compiled_stmt == compiled
+
+
+def test_validate_operators() -> None:
+    """Verify that all operators have been categorized."""
+    assert sorted(SUPPORTED_OPERATORS) == [
+        "$and",
+        "$between",
+        "$eq",
+        "$gt",
+        "$gte",
+        "$ilike",
+        "$in",
+        "$like",
+        "$lt",
+        "$lte",
+        "$ne",
+        "$nin",
+        "$or",
+    ]