From 88a8f59aa76736cde41daa265ac5220d578aa77b Mon Sep 17 00:00:00 2001
From: Richy Wang
Date: Sat, 22 Apr 2023 23:25:41 +0800
Subject: [PATCH] Add a full PostgreSQL-syntax database 'AnalyticDB' as a
 vector store. (#3135)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Hi there!
I'm excited to open this PR to add support for using 'AnalyticDB', a fully
PostgreSQL-syntax-compatible database, as a vector store.

AnalyticDB has already been shown to work with AutoGPT,
ChatGPT-Retrieval-Plugin, and LlamaIndex, so I think it will be a good fit
here as well.

AnalyticDB is a distributed, cloud-native vector database from Alibaba Cloud.
It performs particularly well when the data grows to a large scale.
The PR includes:

- [x] A new vector store: `AnalyticDB`
- [x] A suite of integration tests that verifies the AnalyticDB integration

I have read your [contributing guidelines](https://github.com/hwchase17/langchain/blob/72b7d76d79b0e187426787616d96257b64292119/.github/CONTRIBUTING.md),
and the checks below pass:

- [x] make format
- [x] make lint
- [x] make coverage
- [x] make test
---
 docs/ecosystem/analyticdb.md                  |  15 +
 .../vectorstores/examples/analyticdb.ipynb    | 162 +++++++
 langchain/vectorstores/__init__.py            |   2 +
 langchain/vectorstores/analyticdb.py          | 432 ++++++++++++++++++
 .../vectorstores/test_analyticdb.py           | 148 ++++++
 5 files changed, 759 insertions(+)
 create mode 100644 docs/ecosystem/analyticdb.md
 create mode 100644 docs/modules/indexes/vectorstores/examples/analyticdb.ipynb
 create mode 100644 langchain/vectorstores/analyticdb.py
 create mode 100644 tests/integration_tests/vectorstores/test_analyticdb.py

diff --git a/docs/ecosystem/analyticdb.md b/docs/ecosystem/analyticdb.md
new file mode 100644
index 00000000..59cf8832
--- /dev/null
+++ b/docs/ecosystem/analyticdb.md
@@ -0,0 +1,15 @@
+# AnalyticDB
+
+This page covers how to use the AnalyticDB ecosystem within LangChain.
+
+### VectorStore
+
+There exists a wrapper around AnalyticDB, allowing you to use it as a vectorstore,
+whether for semantic search or example selection.
+
+To import this vectorstore:
+```python
+from langchain.vectorstores import AnalyticDB
+```
+
+For a more detailed walkthrough of the AnalyticDB wrapper, see [this notebook](../modules/indexes/vectorstores/examples/analyticdb.ipynb)
diff --git a/docs/modules/indexes/vectorstores/examples/analyticdb.ipynb b/docs/modules/indexes/vectorstores/examples/analyticdb.ipynb
new file mode 100644
index 00000000..c5178c68
--- /dev/null
+++ b/docs/modules/indexes/vectorstores/examples/analyticdb.ipynb
@@ -0,0 +1,162 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# AnalyticDB\n",
+    "\n",
+    "This notebook shows how to use functionality related to the AnalyticDB vector database.\n",
+    "To run, you should have an [AnalyticDB](https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/latest/product-introduction-overview) instance up and running:\n",
+    "- Using [AnalyticDB Cloud Vector Database](https://www.alibabacloud.com/product/hybriddb-postgresql). Follow the link to deploy one quickly.\n",
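+    "- A Python PostgreSQL driver is also assumed to be installed; the connection example below defaults to `psycopg2cffi` (e.g. `pip install psycopg2cffi`)."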
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.embeddings.openai import OpenAIEmbeddings\n",
+    "from langchain.text_splitter import CharacterTextSplitter\n",
+    "from langchain.vectorstores import AnalyticDB"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "Split the documents and compute embeddings by calling the OpenAI API"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.document_loaders import TextLoader\n",
+    "loader = TextLoader('../../../state_of_the_union.txt')\n",
+    "documents = loader.load()\n",
+    "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
+    "docs = text_splitter.split_documents(documents)\n",
+    "\n",
+    "embeddings = OpenAIEmbeddings()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "Connect to AnalyticDB by setting the related environment variables.\n",
+    "```\n",
+    "export PG_HOST={your_analyticdb_hostname}\n",
+    "export PG_PORT={your_analyticdb_port} # Optional, default is 5432\n",
+    "export PG_DATABASE={your_database} # Optional, default is postgres\n",
+    "export PG_USER={database_username}\n",
+    "export PG_PASSWORD={database_password}\n",
+    "```\n",
+    "\n",
+    "Then store your documents and embeddings in AnalyticDB."
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "connection_string = AnalyticDB.connection_string_from_db_params(\n",
+    "    driver=os.environ.get(\"PG_DRIVER\", \"psycopg2cffi\"),\n",
+    "    host=os.environ.get(\"PG_HOST\", \"localhost\"),\n",
+    "    port=int(os.environ.get(\"PG_PORT\", \"5432\")),\n",
+    "    database=os.environ.get(\"PG_DATABASE\", \"postgres\"),\n",
+    "    user=os.environ.get(\"PG_USER\", \"postgres\"),\n",
+    "    password=os.environ.get(\"PG_PASSWORD\", \"postgres\"),\n",
+    ")\n",
+    "\n",
+    "vector_db = AnalyticDB.from_documents(\n",
+    "    docs,\n",
+    "    embeddings,\n",
+    "    connection_string=connection_string,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "Query and retrieve data"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
+    "docs = vector_db.similarity_search(query)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
+      "\n",
+      "Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
+      "\n",
+      "One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
+      "\n",
+      "And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. 
One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(docs[0].page_content)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/langchain/vectorstores/__init__.py b/langchain/vectorstores/__init__.py
index 485dc712..30743967 100644
--- a/langchain/vectorstores/__init__.py
+++ b/langchain/vectorstores/__init__.py
@@ -1,4 +1,5 @@
 """Wrappers on top of vector stores."""
+from langchain.vectorstores.analyticdb import AnalyticDB
 from langchain.vectorstores.annoy import Annoy
 from langchain.vectorstores.atlas import AtlasDB
 from langchain.vectorstores.base import VectorStore
@@ -27,4 +28,5 @@ __all__ = [
     "DeepLake",
     "Annoy",
     "SupabaseVectorStore",
+    "AnalyticDB",
 ]
diff --git a/langchain/vectorstores/analyticdb.py b/langchain/vectorstores/analyticdb.py
new file mode 100644
index 00000000..6ed8e5b0
--- /dev/null
+++ b/langchain/vectorstores/analyticdb.py
@@ -0,0 +1,432 @@
+"""VectorStore wrapper around AnalyticDB, a PostgreSQL-compatible vector database."""
+from __future__ import annotations
+
+import logging
+import uuid
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import sqlalchemy
+from sqlalchemy import REAL, Index
+from sqlalchemy.dialects.postgresql import ARRAY, JSON, UUID
+from sqlalchemy.orm import Mapped, Session, declarative_base, relationship
+from sqlalchemy.sql.expression import func
+
+from langchain.docstore.document import Document
+from langchain.embeddings.base import Embeddings
+from langchain.utils import get_from_dict_or_env
+from langchain.vectorstores.base import VectorStore
+
+Base = declarative_base()  # type: Any
+
+
+ADA_TOKEN_COUNT = 1536
+_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
+
+
+class BaseModel(Base):
+    __abstract__ = True
+    uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+
+
+class CollectionStore(BaseModel):
+    __tablename__ = "langchain_pg_collection"
+
+    name = sqlalchemy.Column(sqlalchemy.String)
+    cmetadata = sqlalchemy.Column(JSON)
+
+    embeddings = relationship(
+        "EmbeddingStore",
+        back_populates="collection",
+        passive_deletes=True,
+    )
+
+    @classmethod
+    def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
+        return session.query(cls).filter(cls.name == name).first()
+
+    @classmethod
+    def get_or_create(
+        cls,
+        session: Session,
+        name: str,
+        cmetadata: Optional[dict] = None,
+    ) -> Tuple["CollectionStore", bool]:
+        """
+        Get or create a collection.
+        Returns a tuple ``(collection, created)``, where ``created`` is True
+        if the collection was newly created.
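+
+        A minimal sketch of the intended call pattern (``session`` is an open
+        SQLAlchemy session):
+            collection, created = CollectionStore.get_or_create(session, "docs")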
+        """
+        created = False
+        collection = cls.get_by_name(session, name)
+        if collection:
+            return collection, created
+
+        collection = cls(name=name, cmetadata=cmetadata)
+        session.add(collection)
+        session.commit()
+        created = True
+        return collection, created
+
+
+class EmbeddingStore(BaseModel):
+    __tablename__ = "langchain_pg_embedding"
+
+    collection_id: Mapped[UUID] = sqlalchemy.Column(
+        UUID(as_uuid=True),
+        sqlalchemy.ForeignKey(
+            f"{CollectionStore.__tablename__}.uuid",
+            ondelete="CASCADE",
+        ),
+    )
+    collection = relationship(CollectionStore, back_populates="embeddings")
+
+    embedding = sqlalchemy.Column(ARRAY(REAL))
+    document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+    cmetadata = sqlalchemy.Column(JSON, nullable=True)
+
+    # custom_id: any user-defined id
+    custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+
+    # ANN index named 'langchain_pg_embedding_vector_idx' on the embedding column
+    langchain_pg_embedding_vector_idx = Index(
+        "langchain_pg_embedding_vector_idx",
+        embedding,
+        postgresql_using="ann",
+        postgresql_with={
+            "distancemeasure": "L2",
+            "dim": 1536,
+            "pq_segments": 64,
+            "hnsw_m": 100,
+            "pq_centers": 2048,
+        },
+    )
+
+
+class QueryResult:
+    EmbeddingStore: EmbeddingStore
+    distance: float
+
+
+class AnalyticDB(VectorStore):
+    """
+    VectorStore implementation using AnalyticDB.
+    AnalyticDB is a distributed, cloud-native database with full PostgreSQL
+    syntax compatibility.
+    - `connection_string` is a postgres connection string.
+    - `embedding_function` is any embedding function implementing the
+        `langchain.embeddings.base.Embeddings` interface.
+    - `collection_name` is the name of the collection to use. (default: langchain)
+        - NOTE: This is not the name of the table, but the name of the collection.
+            The tables will be created when initializing the store (if they do
+            not exist), so make sure the user has the right permissions to
+            create tables.
+    - `pre_delete_collection` if True, will delete the collection if it exists.
+        (default: False)
+        - Useful for testing.
+    """
+
+    def __init__(
+        self,
+        connection_string: str,
+        embedding_function: Embeddings,
+        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+        collection_metadata: Optional[dict] = None,
+        pre_delete_collection: bool = False,
+        logger: Optional[logging.Logger] = None,
+    ) -> None:
+        self.connection_string = connection_string
+        self.embedding_function = embedding_function
+        self.collection_name = collection_name
+        self.collection_metadata = collection_metadata
+        self.pre_delete_collection = pre_delete_collection
+        self.logger = logger or logging.getLogger(__name__)
+        self.__post_init__()
+
+    def __post_init__(
+        self,
+    ) -> None:
+        """
+        Initialize the store.
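+
+        Connects to the database, creates the tables if they do not already
+        exist, and creates (or fetches) the collection row for
+        ``collection_name``.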
+        """
+        self._conn = self.connect()
+        self.create_tables_if_not_exists()
+        self.create_collection()
+
+    def connect(self) -> sqlalchemy.engine.Connection:
+        engine = sqlalchemy.create_engine(self.connection_string)
+        conn = engine.connect()
+        return conn
+
+    def create_tables_if_not_exists(self) -> None:
+        Base.metadata.create_all(self._conn)
+
+    def drop_tables(self) -> None:
+        Base.metadata.drop_all(self._conn)
+
+    def create_collection(self) -> None:
+        if self.pre_delete_collection:
+            self.delete_collection()
+        with Session(self._conn) as session:
+            CollectionStore.get_or_create(
+                session, self.collection_name, cmetadata=self.collection_metadata
+            )
+
+    def delete_collection(self) -> None:
+        self.logger.debug("Trying to delete collection")
+        with Session(self._conn) as session:
+            collection = self.get_collection(session)
+            if not collection:
+                self.logger.error("Collection not found")
+                return
+            session.delete(collection)
+            session.commit()
+
+    def get_collection(self, session: Session) -> Optional["CollectionStore"]:
+        return CollectionStore.get_by_name(session, self.collection_name)
+
+    def add_texts(
+        self,
+        texts: Iterable[str],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> List[str]:
+        """Run more texts through the embeddings and add to the vectorstore.
+
+        Args:
+            texts: Iterable of strings to add to the vectorstore.
+            metadatas: Optional list of metadatas associated with the texts.
+            ids: Optional list of ids for the texts. Generated if not provided.
+            kwargs: vectorstore specific parameters
+
+        Returns:
+            List of ids from adding the texts into the vectorstore.
+        """
+        if ids is None:
+            ids = [str(uuid.uuid1()) for _ in texts]
+
+        embeddings = self.embedding_function.embed_documents(list(texts))
+
+        if not metadatas:
+            metadatas = [{} for _ in texts]
+
+        with Session(self._conn) as session:
+            collection = self.get_collection(session)
+            if not collection:
+                raise ValueError("Collection not found")
+            for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
+                embedding_store = EmbeddingStore(
+                    embedding=embedding,
+                    document=text,
+                    cmetadata=metadata,
+                    custom_id=id,
+                )
+                collection.embeddings.append(embedding_store)
+                session.add(embedding_store)
+            session.commit()
+
+        return ids
+
+    def similarity_search(
+        self,
+        query: str,
+        k: int = 4,
+        filter: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Run a similarity search against AnalyticDB, ranked by distance.
+
+        Args:
+            query (str): Query text to search for.
+            k (int): Number of results to return. Defaults to 4.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+        Returns:
+            List of Documents most similar to the query.
+        """
+        embedding = self.embedding_function.embed_query(text=query)
+        return self.similarity_search_by_vector(
+            embedding=embedding,
+            k=k,
+            filter=filter,
+        )
+
+    def similarity_search_with_score(
+        self,
+        query: str,
+        k: int = 4,
+        filter: Optional[dict] = None,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs most similar to query.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
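+                For example, ``filter={"page": "0"}`` only returns rows whose
+                metadata field ``page`` equals ``"0"``.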
+
+        Returns:
+            List of Documents most similar to the query and score for each.
+        """
+        embedding = self.embedding_function.embed_query(query)
+        docs = self.similarity_search_with_score_by_vector(
+            embedding=embedding, k=k, filter=filter
+        )
+        return docs
+
+    def similarity_search_with_score_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[dict] = None,
+    ) -> List[Tuple[Document, float]]:
+        with Session(self._conn) as session:
+            collection = self.get_collection(session)
+            if not collection:
+                raise ValueError("Collection not found")
+
+            filter_by = EmbeddingStore.collection_id == collection.uuid
+
+            if filter is not None:
+                filter_clauses = []
+                for key, value in filter.items():
+                    filter_by_metadata = EmbeddingStore.cmetadata[key].astext == str(value)
+                    filter_clauses.append(filter_by_metadata)
+
+                filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
+
+            results: List[QueryResult] = (
+                session.query(
+                    EmbeddingStore,
+                    func.l2_distance(EmbeddingStore.embedding, embedding).label("distance"),
+                )
+                .filter(filter_by)
+                .order_by(EmbeddingStore.embedding.op("<->")(embedding))
+                .join(
+                    CollectionStore,
+                    EmbeddingStore.collection_id == CollectionStore.uuid,
+                )
+                .limit(k)
+                .all()
+            )
+        docs = [
+            (
+                Document(
+                    page_content=result.EmbeddingStore.document,
+                    metadata=result.EmbeddingStore.cmetadata,
+                ),
+                result.distance if self.embedding_function is not None else None,
+            )
+            for result in results
+        ]
+        return docs
+
+    def similarity_search_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs most similar to embedding vector.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+        Returns:
+            List of Documents most similar to the query vector.
+        """
+        docs_and_scores = self.similarity_search_with_score_by_vector(
+            embedding=embedding, k=k, filter=filter
+        )
+        return [doc for doc, _ in docs_and_scores]
+
+    @classmethod
+    def from_texts(
+        cls,
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+        ids: Optional[List[str]] = None,
+        pre_delete_collection: bool = False,
+        **kwargs: Any,
+    ) -> AnalyticDB:
+        """
+        Return VectorStore initialized from texts and embeddings.
+        A Postgres connection string is required.
+        Either pass it as a parameter,
+        or set the PGVECTOR_CONNECTION_STRING environment variable.
+        """
+
+        connection_string = cls.get_connection_string(kwargs)
+
+        store = cls(
+            connection_string=connection_string,
+            collection_name=collection_name,
+            embedding_function=embedding,
+            pre_delete_collection=pre_delete_collection,
+        )
+
+        store.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs)
+        return store
+
+    @classmethod
+    def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
+        connection_string: str = get_from_dict_or_env(
+            data=kwargs,
+            key="connection_string",
+            env_key="PGVECTOR_CONNECTION_STRING",
+        )
+
+        if not connection_string:
+            raise ValueError(
+                "Postgres connection string is required. "
+                "Either pass it as a parameter, "
+                "or set the PGVECTOR_CONNECTION_STRING environment variable."
+            )
+
+        return connection_string
+
+    @classmethod
+    def from_documents(
+        cls,
+        documents: List[Document],
+        embedding: Embeddings,
+        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+        ids: Optional[List[str]] = None,
+        pre_delete_collection: bool = False,
+        **kwargs: Any,
+    ) -> AnalyticDB:
+        """
+        Return VectorStore initialized from documents and embeddings.
+        A Postgres connection string is required.
+        Either pass it as a parameter,
+        or set the PGVECTOR_CONNECTION_STRING environment variable.
+        """
+
+        texts = [d.page_content for d in documents]
+        metadatas = [d.metadata for d in documents]
+        connection_string = cls.get_connection_string(kwargs)
+
+        kwargs["connection_string"] = connection_string
+
+        return cls.from_texts(
+            texts=texts,
+            pre_delete_collection=pre_delete_collection,
+            embedding=embedding,
+            metadatas=metadatas,
+            ids=ids,
+            collection_name=collection_name,
+            **kwargs,
+        )
+
+    @classmethod
+    def connection_string_from_db_params(
+        cls,
+        driver: str,
+        host: str,
+        port: int,
+        database: str,
+        user: str,
+        password: str,
+    ) -> str:
+        """Return connection string from database parameters."""
+        return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}"
diff --git a/tests/integration_tests/vectorstores/test_analyticdb.py b/tests/integration_tests/vectorstores/test_analyticdb.py
new file mode 100644
index 00000000..d3bbe0e6
--- /dev/null
+++ b/tests/integration_tests/vectorstores/test_analyticdb.py
@@ -0,0 +1,148 @@
+"""Test AnalyticDB functionality."""
+import os
+from typing import List
+
+from sqlalchemy.orm import Session
+
+from langchain.docstore.document import Document
+from langchain.vectorstores.analyticdb import AnalyticDB
+from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
+
+CONNECTION_STRING = AnalyticDB.connection_string_from_db_params(
+    driver=os.environ.get("PG_DRIVER", "psycopg2cffi"),
+    host=os.environ.get("PG_HOST", "localhost"),
+    port=int(os.environ.get("PG_PORT", "5432")),
+    database=os.environ.get("PG_DATABASE", "postgres"),
+    user=os.environ.get("PG_USER", "postgres"),
+    password=os.environ.get("PG_PASSWORD", "postgres"),
+)
+
+
+ADA_TOKEN_COUNT = 1536
+
+
+class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
+    """Fake embeddings functionality for testing."""
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Return simple embeddings."""
+        return [
+            [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))
+        ]
+
+    def embed_query(self, text: str) -> List[float]:
+        """Return simple embeddings."""
+        return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]
+
+
+def test_analyticdb() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    docsearch = AnalyticDB.from_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search("foo", k=1)
+    assert output == [Document(page_content="foo")]
+
+
+def test_analyticdb_with_metadatas() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = AnalyticDB.from_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search("foo", k=1)
+    assert output
== [Document(page_content="foo", metadata={"page": "0"})]
+
+
+def test_analyticdb_with_metadatas_with_scores() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = AnalyticDB.from_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score("foo", k=1)
+    assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
+
+
+def test_analyticdb_with_filter_match() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = AnalyticDB.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"})
+    assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
+
+
+def test_analyticdb_with_filter_distant_match() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = AnalyticDB.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"})
+    assert output == [(Document(page_content="baz", metadata={"page": "2"}), 4.0)]
+
+
+def test_analyticdb_with_filter_no_match() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = AnalyticDB.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"})
+    assert output == []
+
+
+def test_analyticdb_collection_with_metadata() -> None:
+    """Test end to end collection construction."""
+    pgvector = AnalyticDB(
+        collection_name="test_collection",
+        collection_metadata={"foo": "bar"},
+        embedding_function=FakeEmbeddingsWithAdaDimension(),
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    session = Session(pgvector.connect())
+    collection = pgvector.get_collection(session)
+    if collection is None:
+        assert False, "Expected a CollectionStore object but received None"
+    else:
+        assert collection.name == "test_collection"
+        assert collection.cmetadata == {"foo": "bar"}
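
A minimal end-to-end usage sketch of the new store, for reference. This is a sketch only: it assumes a reachable AnalyticDB instance and a valid OpenAI API key in the environment, and the connection parameters below default to placeholders you must replace.

```python
import os

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import AnalyticDB

# Placeholder connection details -- replace with your own instance's values.
connection_string = AnalyticDB.connection_string_from_db_params(
    driver=os.environ.get("PG_DRIVER", "psycopg2cffi"),
    host=os.environ.get("PG_HOST", "localhost"),
    port=int(os.environ.get("PG_PORT", "5432")),
    database=os.environ.get("PG_DATABASE", "postgres"),
    user=os.environ.get("PG_USER", "postgres"),
    password=os.environ.get("PG_PASSWORD", "postgres"),
)

# Build the store from raw texts; the tables and the collection row are
# created on first use if they do not already exist.
vector_db = AnalyticDB.from_texts(
    texts=["foo", "bar", "baz"],
    embedding=OpenAIEmbeddings(),
    connection_string=connection_string,
)

# Nearest-neighbour lookup (L2 distance) over the stored embeddings.
print(vector_db.similarity_search("foo", k=1))
```

Note that `connection_string` is forwarded through `**kwargs` to `get_connection_string`, so it can instead be supplied via the `PGVECTOR_CONNECTION_STRING` environment variable.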