refactor: enable connection pool usage in PGVector (#11514)

- **Description:** `PGVector` refactored to use connection pool.
  - **Issue:** #11433
  - **Tag maintainer:** @hwchase17 @eyurtsev

---------

Co-authored-by: Diego Rani Mazine <diego.mazine@mercadolivre.com>
Co-authored-by: Nuno Campos <nuno@langchain.dev>
pull/15293/head
Diego Rani Mazine 6 months ago committed by GitHub
parent 507c195a4b
commit ec72225265
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -203,8 +203,7 @@ class PGVector(VectorStore):
self.logger = logger or logging.getLogger(__name__)
self.override_relevance_score_fn = relevance_score_fn
self.engine_args = engine_args or {}
# Create a connection if not provided, otherwise use the provided connection
self._conn = connection if connection else self.connect()
self._bind = connection if connection else self._create_engine()
self.__post_init__()
def __post_init__(
@ -220,21 +219,19 @@ class PGVector(VectorStore):
self.create_collection()
def __del__(self) -> None:
if self._conn:
self._conn.close()
if isinstance(self._bind, sqlalchemy.engine.Connection):
self._bind.close()
@property
def embeddings(self) -> Embeddings:
return self.embedding_function
def connect(self) -> sqlalchemy.engine.Connection:
engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
conn = engine.connect()
return conn
def _create_engine(self) -> sqlalchemy.engine.Engine:
return sqlalchemy.create_engine(url=self.connection_string, **self.engine_args)
def create_vector_extension(self) -> None:
try:
with Session(self._conn) as session:
with Session(self._bind) as session:
# The advisor lock fixes issue arising from concurrent
# creation of the vector extension.
# https://github.com/langchain-ai/langchain/issues/12933
@ -252,24 +249,24 @@ class PGVector(VectorStore):
raise Exception(f"Failed to create vector extension: {e}") from e
def create_tables_if_not_exists(self) -> None:
with self._conn.begin():
Base.metadata.create_all(self._conn)
with Session(self._bind) as session, session.begin():
Base.metadata.create_all(session.get_bind())
def drop_tables(self) -> None:
with self._conn.begin():
Base.metadata.drop_all(self._conn)
with Session(self._bind) as session, session.begin():
Base.metadata.drop_all(session.get_bind())
def create_collection(self) -> None:
if self.pre_delete_collection:
self.delete_collection()
with Session(self._conn) as session:
with Session(self._bind) as session:
self.CollectionStore.get_or_create(
session, self.collection_name, cmetadata=self.collection_metadata
)
def delete_collection(self) -> None:
self.logger.debug("Trying to delete collection")
with Session(self._conn) as session:
with Session(self._bind) as session:
collection = self.get_collection(session)
if not collection:
self.logger.warning("Collection not found")
@ -280,7 +277,7 @@ class PGVector(VectorStore):
@contextlib.contextmanager
def _make_session(self) -> Generator[Session, None, None]:
"""Create a context manager for the session, bind to _conn string."""
yield Session(self._conn)
yield Session(self._bind)
def delete(
self,
@ -292,7 +289,7 @@ class PGVector(VectorStore):
Args:
ids: List of ids to delete.
"""
with Session(self._conn) as session:
with Session(self._bind) as session:
if ids is not None:
self.logger.debug(
"Trying to delete vectors by ids (represented by the model "
@ -366,7 +363,7 @@ class PGVector(VectorStore):
if not metadatas:
metadatas = [{} for _ in texts]
with Session(self._conn) as session:
with Session(self._bind) as session:
collection = self.get_collection(session)
if not collection:
raise ValueError("Collection not found")
@ -496,7 +493,7 @@ class PGVector(VectorStore):
filter: Optional[Dict[str, str]] = None,
) -> List[Any]:
"""Query the collection."""
with Session(self._conn) as session:
with Session(self._bind) as session:
collection = self.get_collection(session)
if not collection:
raise ValueError("Collection not found")

@ -0,0 +1,17 @@
version: "3.8"
services:
  pgvector:
    image: ankane/pgvector:latest
    environment:
      POSTGRES_DB: ${PGVECTOR_DB:-postgres}
      POSTGRES_USER: ${PGVECTOR_USER:-postgres}
      POSTGRES_PASSWORD: ${PGVECTOR_PASSWORD:-postgres}
    ports:
      - ${PGVECTOR_PORT:-5432}:5432
    restart: unless-stopped
    healthcheck:
      # Postgres speaks its own wire protocol, not HTTP, so `curl -f
      # http://localhost:5432` can never succeed (and the ankane/pgvector
      # image does not ship curl). Use pg_isready, which is bundled with
      # Postgres and reports server readiness directly. `$$` escapes the
      # variable from docker-compose so it is expanded inside the container.
      test: ["CMD-SHELL", "pg_isready -U $${POSTGRES_USER} -d $${POSTGRES_DB}"]
      interval: 10s
      timeout: 5s
      retries: 5

@ -2,6 +2,7 @@
import os
from typing import List
import sqlalchemy
from langchain_core.documents import Document
from sqlalchemy.orm import Session
@ -155,7 +156,7 @@ def test_pgvector_collection_with_metadata() -> None:
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
session = Session(pgvector.connect())
session = Session(pgvector._create_engine())
collection = pgvector.get_collection(session)
if collection is None:
assert False, "Expected a CollectionStore object but received None"
@ -327,3 +328,43 @@ def test_pgvector_max_marginal_relevance_search_with_score() -> None:
)
output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3)
assert output == [(Document(page_content="foo"), 0.0)]
def test_pgvector_with_custom_connection() -> None:
    """Construct a PGVector store on top of a caller-managed connection.

    Passing ``connection=`` must make the store reuse that connection
    instead of creating its own engine.
    """
    documents = ["foo", "bar", "baz"]
    engine = sqlalchemy.create_engine(CONNECTION_STRING)
    with engine.connect() as conn:
        store = PGVector.from_texts(
            texts=documents,
            collection_name="test_collection",
            embedding=FakeEmbeddingsWithAdaDimension(),
            connection_string=CONNECTION_STRING,
            pre_delete_collection=True,
            connection=conn,
        )
        results = store.similarity_search("foo", k=1)
        assert results == [Document(page_content="foo")]
def test_pgvector_with_custom_engine_args() -> None:
    """Construct a PGVector store with user-supplied engine arguments.

    The ``engine_args`` mapping is forwarded to the underlying SQLAlchemy
    engine; the store must remain fully functional with a customized pool.
    """
    pool_settings = {
        "pool_size": 5,
        "max_overflow": 10,
        "pool_recycle": -1,
        "pool_use_lifo": False,
        "pool_pre_ping": False,
        "pool_timeout": 30,
    }
    store = PGVector.from_texts(
        texts=["foo", "bar", "baz"],
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
        engine_args=pool_settings,
    )
    assert store.similarity_search("foo", k=1) == [Document(page_content="foo")]

@ -0,0 +1,73 @@
"""Test PGVector functionality."""
from unittest import mock
from unittest.mock import Mock
import pytest
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import pgvector
# Connection string pointing at a local Postgres; the tests below mock out
# sqlalchemy.create_engine, so no server actually needs to be running.
_CONNECTION_STRING = pgvector.PGVector.connection_string_from_db_params(
    driver="psycopg2",
    host="localhost",
    port=5432,
    database="postgres",
    user="postgres",
    password="postgres",
)

# Deterministic fake embeddings; size=1536 presumably mirrors the dimension
# of OpenAI's ada embeddings — TODO confirm against the store's schema.
_EMBEDDING_FUNCTION = FakeEmbeddings(size=1536)
@pytest.mark.requires("pgvector")
@mock.patch("sqlalchemy.create_engine")
def test_given_a_connection_is_provided_then_no_engine_should_be_created(
    create_engine: Mock,
) -> None:
    """A caller-supplied connection must suppress engine creation entirely."""
    connection = mock.MagicMock()
    pgvector.PGVector(
        embedding_function=_EMBEDDING_FUNCTION,
        connection_string=_CONNECTION_STRING,
        connection=connection,
    )
    create_engine.assert_not_called()
@pytest.mark.requires("pgvector")
@mock.patch("sqlalchemy.create_engine")
def test_given_no_connection_or_engine_args_provided_default_engine_should_be_used(
    create_engine: Mock,
) -> None:
    """Without a connection or engine args, the engine gets only the URL."""
    pgvector.PGVector(
        embedding_function=_EMBEDDING_FUNCTION,
        connection_string=_CONNECTION_STRING,
    )
    create_engine.assert_called_with(url=_CONNECTION_STRING)
@pytest.mark.requires("pgvector")
@mock.patch("sqlalchemy.create_engine")
def test_given_engine_args_are_provided_then_they_should_be_used(
    create_engine: Mock,
) -> None:
    """Supplied engine args must be forwarded verbatim to create_engine."""
    pool_settings = {
        "pool_size": 5,
        "max_overflow": 10,
        "pool_recycle": -1,
        "pool_use_lifo": False,
        "pool_pre_ping": False,
        "pool_timeout": 30,
    }
    pgvector.PGVector(
        embedding_function=_EMBEDDING_FUNCTION,
        connection_string=_CONNECTION_STRING,
        engine_args=pool_settings,
    )
    create_engine.assert_called_with(url=_CONNECTION_STRING, **pool_settings)
Loading…
Cancel
Save