"""Test PGVector functionality.""" import os from typing import List from sqlalchemy.orm import Session from langchain.docstore.document import Document from langchain.vectorstores.pgvector import PGVector from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings CONNECTION_STRING = PGVector.connection_string_from_db_params( driver=os.environ.get("TEST_PGVECTOR_DRIVER", "psycopg2"), host=os.environ.get("TEST_PGVECTOR_HOST", "localhost"), port=int(os.environ.get("TEST_PGVECTOR_PORT", "5432")), database=os.environ.get("TEST_PGVECTOR_DATABASE", "postgres"), user=os.environ.get("TEST_PGVECTOR_USER", "postgres"), password=os.environ.get("TEST_PGVECTOR_PASSWORD", "postgres"), ) ADA_TOKEN_COUNT = 1536 class FakeEmbeddingsWithAdaDimension(FakeEmbeddings): """Fake embeddings functionality for testing.""" def embed_documents(self, texts: List[str]) -> List[List[float]]: """Return simple embeddings.""" return [ [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts)) ] def embed_query(self, text: str) -> List[float]: """Return simple embeddings.""" return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)] def test_pgvector() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] docsearch = PGVector.from_texts( texts=texts, collection_name="test_collection", embedding=FakeEmbeddingsWithAdaDimension(), connection_string=CONNECTION_STRING, pre_delete_collection=True, ) output = docsearch.similarity_search("foo", k=1) assert output == [Document(page_content="foo")] def test_pgvector_embeddings() -> None: """Test end to end construction with embeddings and search.""" texts = ["foo", "bar", "baz"] text_embeddings = FakeEmbeddingsWithAdaDimension().embed_documents(texts) text_embedding_pairs = list(zip(texts, text_embeddings)) docsearch = PGVector.from_embeddings( text_embeddings=text_embedding_pairs, collection_name="test_collection", embedding=FakeEmbeddingsWithAdaDimension(), connection_string=CONNECTION_STRING, pre_delete_collection=True, ) output = docsearch.similarity_search("foo", k=1) assert output == [Document(page_content="foo")] def test_pgvector_with_metadatas() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": str(i)} for i in range(len(texts))] docsearch = PGVector.from_texts( texts=texts, collection_name="test_collection", embedding=FakeEmbeddingsWithAdaDimension(), metadatas=metadatas, connection_string=CONNECTION_STRING, pre_delete_collection=True, ) output = docsearch.similarity_search("foo", k=1) assert output == [Document(page_content="foo", metadata={"page": "0"})] def test_pgvector_with_metadatas_with_scores() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": str(i)} for i in range(len(texts))] docsearch = PGVector.from_texts( texts=texts, collection_name="test_collection", embedding=FakeEmbeddingsWithAdaDimension(), metadatas=metadatas, connection_string=CONNECTION_STRING, pre_delete_collection=True, ) output = docsearch.similarity_search_with_score("foo", k=1) assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] def test_pgvector_with_filter_match() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": str(i)} for i in range(len(texts))] docsearch = PGVector.from_texts( texts=texts, collection_name="test_collection_filter", embedding=FakeEmbeddingsWithAdaDimension(), metadatas=metadatas, 
connection_string=CONNECTION_STRING, pre_delete_collection=True, ) output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"}) assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] def test_pgvector_with_filter_distant_match() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": str(i)} for i in range(len(texts))] docsearch = PGVector.from_texts( texts=texts, collection_name="test_collection_filter", embedding=FakeEmbeddingsWithAdaDimension(), metadatas=metadatas, connection_string=CONNECTION_STRING, pre_delete_collection=True, ) output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"}) assert output == [ (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) ] def test_pgvector_with_filter_no_match() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": str(i)} for i in range(len(texts))] docsearch = PGVector.from_texts( texts=texts, collection_name="test_collection_filter", embedding=FakeEmbeddingsWithAdaDimension(), metadatas=metadatas, connection_string=CONNECTION_STRING, pre_delete_collection=True, ) output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"}) assert output == [] def test_pgvector_collection_with_metadata() -> None: """Test end to end collection construction""" pgvector = PGVector( collection_name="test_collection", collection_metadata={"foo": "bar"}, embedding_function=FakeEmbeddingsWithAdaDimension(), connection_string=CONNECTION_STRING, pre_delete_collection=True, ) session = Session(pgvector.connect()) collection = pgvector.get_collection(session) if collection is None: assert False, "Expected a CollectionStore object but received None" else: assert collection.name == "test_collection" assert collection.cmetadata == {"foo": "bar"} def test_pgvector_with_filter_in_set() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": str(i)} for i in range(len(texts))] docsearch = PGVector.from_texts( texts=texts, collection_name="test_collection_filter", embedding=FakeEmbeddingsWithAdaDimension(), metadatas=metadatas, connection_string=CONNECTION_STRING, pre_delete_collection=True, ) output = docsearch.similarity_search_with_score( "foo", k=2, filter={"page": {"IN": ["0", "2"]}} ) assert output == [ (Document(page_content="foo", metadata={"page": "0"}), 0.0), (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406), ] def test_pgvector_relevance_score() -> None: """Test to make sure the relevance score is scaled to 0-1.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": str(i)} for i in range(len(texts))] docsearch = PGVector.from_texts( texts=texts, collection_name="test_collection", embedding=FakeEmbeddingsWithAdaDimension(), metadatas=metadatas, connection_string=CONNECTION_STRING, pre_delete_collection=True, ) output = docsearch.similarity_search_with_relevance_scores("foo", k=3) assert output == [ (Document(page_content="foo", metadata={"page": "0"}), 1.0), (Document(page_content="bar", metadata={"page": "1"}), 0.9996744261675065), (Document(page_content="baz", metadata={"page": "2"}), 0.9986996093328621), ] def test_pgvector_retriever_search_threshold() -> None: """Test using retriever for searching with threshold.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": str(i)} for i in range(len(texts))] docsearch = PGVector.from_texts( texts=texts, 
collection_name="test_collection", embedding=FakeEmbeddingsWithAdaDimension(), metadatas=metadatas, connection_string=CONNECTION_STRING, pre_delete_collection=True, ) retriever = docsearch.as_retriever( search_type="similarity_score_threshold", search_kwargs={"k": 3, "score_threshold": 0.999}, ) output = retriever.get_relevant_documents("summer") assert output == [ Document(page_content="foo", metadata={"page": "0"}), Document(page_content="bar", metadata={"page": "1"}), ] def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None: """Test searching with threshold and custom normalization function""" texts = ["foo", "bar", "baz"] metadatas = [{"page": str(i)} for i in range(len(texts))] docsearch = PGVector.from_texts( texts=texts, collection_name="test_collection", embedding=FakeEmbeddingsWithAdaDimension(), metadatas=metadatas, connection_string=CONNECTION_STRING, pre_delete_collection=True, relevance_score_fn=lambda d: d * 0, ) retriever = docsearch.as_retriever( search_type="similarity_score_threshold", search_kwargs={"k": 3, "score_threshold": 0.5}, ) output = retriever.get_relevant_documents("foo") assert output == []