diff --git a/libs/community/langchain_community/vectorstores/pgvector.py b/libs/community/langchain_community/vectorstores/pgvector.py index d5c2294b10..b7b3e529c6 100644 --- a/libs/community/langchain_community/vectorstores/pgvector.py +++ b/libs/community/langchain_community/vectorstores/pgvector.py @@ -203,8 +203,7 @@ class PGVector(VectorStore): self.logger = logger or logging.getLogger(__name__) self.override_relevance_score_fn = relevance_score_fn self.engine_args = engine_args or {} - # Create a connection if not provided, otherwise use the provided connection - self._conn = connection if connection else self.connect() + self._bind = connection if connection else self._create_engine() self.__post_init__() def __post_init__( @@ -220,21 +219,19 @@ class PGVector(VectorStore): self.create_collection() def __del__(self) -> None: - if self._conn: - self._conn.close() + if isinstance(self._bind, sqlalchemy.engine.Connection): + self._bind.close() @property def embeddings(self) -> Embeddings: return self.embedding_function - def connect(self) -> sqlalchemy.engine.Connection: - engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args) - conn = engine.connect() - return conn + def _create_engine(self) -> sqlalchemy.engine.Engine: + return sqlalchemy.create_engine(url=self.connection_string, **self.engine_args) def create_vector_extension(self) -> None: try: - with Session(self._conn) as session: + with Session(self._bind) as session: # The advisor lock fixes issue arising from concurrent # creation of the vector extension. # https://github.com/langchain-ai/langchain/issues/12933 @@ -252,24 +249,24 @@ class PGVector(VectorStore): raise Exception(f"Failed to create vector extension: {e}") from e def create_tables_if_not_exists(self) -> None: - with self._conn.begin(): - Base.metadata.create_all(self._conn) + with Session(self._bind) as session, session.begin(): + Base.metadata.create_all(session.get_bind()) def drop_tables(self) -> None: - with self._conn.begin(): - Base.metadata.drop_all(self._conn) + with Session(self._bind) as session, session.begin(): + Base.metadata.drop_all(session.get_bind()) def create_collection(self) -> None: if self.pre_delete_collection: self.delete_collection() - with Session(self._conn) as session: + with Session(self._bind) as session: self.CollectionStore.get_or_create( session, self.collection_name, cmetadata=self.collection_metadata ) def delete_collection(self) -> None: self.logger.debug("Trying to delete collection") - with Session(self._conn) as session: + with Session(self._bind) as session: collection = self.get_collection(session) if not collection: self.logger.warning("Collection not found") @@ -280,7 +277,7 @@ class PGVector(VectorStore): @contextlib.contextmanager def _make_session(self) -> Generator[Session, None, None]: """Create a context manager for the session, bind to _conn string.""" - yield Session(self._conn) + yield Session(self._bind) def delete( self, @@ -292,7 +289,7 @@ class PGVector(VectorStore): Args: ids: List of ids to delete. """ - with Session(self._conn) as session: + with Session(self._bind) as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " @@ -366,7 +363,7 @@ class PGVector(VectorStore): if not metadatas: metadatas = [{} for _ in texts] - with Session(self._conn) as session: + with Session(self._bind) as session: collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -496,7 +493,7 @@ class PGVector(VectorStore): filter: Optional[Dict[str, str]] = None, ) -> List[Any]: """Query the collection.""" - with Session(self._conn) as session: + with Session(self._bind) as session: collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") diff --git a/libs/community/tests/integration_tests/vectorstores/docker-compose/pgvector.yml b/libs/community/tests/integration_tests/vectorstores/docker-compose/pgvector.yml new file mode 100644 index 0000000000..5dd0b46036 --- /dev/null +++ b/libs/community/tests/integration_tests/vectorstores/docker-compose/pgvector.yml @@ -0,0 +1,17 @@ +version: "3.8" + +services: + pgvector: + image: ankane/pgvector:latest + environment: + POSTGRES_DB: ${PGVECTOR_DB:-postgres} + POSTGRES_USER: ${PGVECTOR_USER:-postgres} + POSTGRES_PASSWORD: ${PGVECTOR_PASSWORD:-postgres} + ports: + - ${PGVECTOR_PORT:-5432}:5432 + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:5432"] + interval: 10s + timeout: 5s + retries: 5 \ No newline at end of file diff --git a/libs/community/tests/integration_tests/vectorstores/test_pgvector.py b/libs/community/tests/integration_tests/vectorstores/test_pgvector.py index 0a9d5265fa..db0b9ee7a0 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_pgvector.py +++ b/libs/community/tests/integration_tests/vectorstores/test_pgvector.py @@ -2,6 +2,7 @@ import os from typing import List +import sqlalchemy from langchain_core.documents import Document from sqlalchemy.orm import Session @@ -155,7 +156,7 @@ def test_pgvector_collection_with_metadata() -> None: connection_string=CONNECTION_STRING, pre_delete_collection=True, ) - session = Session(pgvector.connect()) + session = Session(pgvector._create_engine()) collection = pgvector.get_collection(session) if collection is None: assert False, "Expected a CollectionStore object but received None" @@ -327,3 +328,43 @@ def test_pgvector_max_marginal_relevance_search_with_score() -> None: ) output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3) assert output == [(Document(page_content="foo"), 0.0)] + + +def test_pgvector_with_custom_connection() -> None: + """Test construction using a custom connection.""" + texts = ["foo", "bar", "baz"] + engine = sqlalchemy.create_engine(CONNECTION_STRING) + with engine.connect() as connection: + docsearch = PGVector.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + connection=connection, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + +def test_pgvector_with_custom_engine_args() -> None: + """Test construction using custom engine arguments.""" + texts = ["foo", "bar", "baz"] + engine_args = { + "pool_size": 5, + "max_overflow": 10, + "pool_recycle": -1, + "pool_use_lifo": False, + "pool_pre_ping": False, + "pool_timeout": 30, + } + docsearch = PGVector.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + engine_args=engine_args, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] diff --git a/libs/community/tests/unit_tests/vectorstores/test_pgvector.py b/libs/community/tests/unit_tests/vectorstores/test_pgvector.py new file mode 100644 index 0000000000..26e57abc98 --- /dev/null +++ b/libs/community/tests/unit_tests/vectorstores/test_pgvector.py @@ -0,0 +1,73 @@ +"""Test PGVector functionality.""" +from unittest import mock +from unittest.mock import Mock + +import pytest + +from langchain_community.embeddings import FakeEmbeddings +from langchain_community.vectorstores import pgvector + +_CONNECTION_STRING = pgvector.PGVector.connection_string_from_db_params( + driver="psycopg2", + host="localhost", + port=5432, + database="postgres", + user="postgres", + password="postgres", +) + +_EMBEDDING_FUNCTION = FakeEmbeddings(size=1536) + + +@pytest.mark.requires("pgvector") +@mock.patch("sqlalchemy.create_engine") +def test_given_a_connection_is_provided_then_no_engine_should_be_created( + create_engine: Mock, +) -> None: + """When a connection is provided then no engine should be created.""" + pgvector.PGVector( + connection_string=_CONNECTION_STRING, + embedding_function=_EMBEDDING_FUNCTION, + connection=mock.MagicMock(), + ) + create_engine.assert_not_called() + + +@pytest.mark.requires("pgvector") +@mock.patch("sqlalchemy.create_engine") +def test_given_no_connection_or_engine_args_provided_default_engine_should_be_used( + create_engine: Mock, +) -> None: + """When no connection or engine arguments are provided then the default configuration must be used.""" # noqa: E501 + pgvector.PGVector( + connection_string=_CONNECTION_STRING, + embedding_function=_EMBEDDING_FUNCTION, + ) + create_engine.assert_called_with( + url=_CONNECTION_STRING, + ) + + +@pytest.mark.requires("pgvector") +@mock.patch("sqlalchemy.create_engine") +def test_given_engine_args_are_provided_then_they_should_be_used( + create_engine: Mock, +) -> None: + """When engine arguments are provided then they must be used to create the underlying engine.""" # noqa: E501 + engine_args = { + "pool_size": 5, + "max_overflow": 10, + "pool_recycle": -1, + "pool_use_lifo": False, + "pool_pre_ping": False, + "pool_timeout": 30, + } + pgvector.PGVector( + connection_string=_CONNECTION_STRING, + embedding_function=_EMBEDDING_FUNCTION, + engine_args=engine_args, + ) + create_engine.assert_called_with( + url=_CONNECTION_STRING, + **engine_args, + )