Add connection args for pgvector vector store (#11930)

- **Description:** sqlalchemy create_engine() does not take into account
connect_args which are mandatory for managed PGSQL instances on cloud
providers (ssl_context for example).
Also re-enabled create_vector_extension at post_init for using pgvector
class seamlessly
- **Tag maintainer:** @baskaryan, @eyurtsev, @hwchase17.

---------

Co-authored-by: Sami Bargaoui <bargaoui.sam@gmail.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
pull/12578/head
aubin_mzt 12 months ago committed by GitHub
parent 4d6243fa87
commit 66f8cb015d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -84,6 +84,7 @@ class PGVector(VectorStore):
distance_strategy: The distance strategy to use. (default: COSINE)
pre_delete_collection: If True, will delete the collection if it exists.
(default: False). Useful for testing.
engine_args: SQLAlchemy's create engine arguments.
Example:
.. code-block:: python
@ -114,6 +115,8 @@ class PGVector(VectorStore):
pre_delete_collection: bool = False,
logger: Optional[logging.Logger] = None,
relevance_score_fn: Optional[Callable[[float], float]] = None,
*,
engine_args: Optional[dict[str, Any]] = None,
) -> None:
self.connection_string = connection_string
self.embedding_function = embedding_function
@ -123,6 +126,7 @@ class PGVector(VectorStore):
self.pre_delete_collection = pre_delete_collection
self.logger = logger or logging.getLogger(__name__)
self.override_relevance_score_fn = relevance_score_fn
self.engine_args = engine_args or {}
self.__post_init__()
def __post_init__(
@ -132,7 +136,7 @@ class PGVector(VectorStore):
Initialize the store.
"""
self._conn = self.connect()
# self.create_vector_extension()
self.create_vector_extension()
from langchain.vectorstores._pgvector_data_models import (
CollectionStore,
EmbeddingStore,
@ -148,7 +152,7 @@ class PGVector(VectorStore):
return self.embedding_function
def connect(self) -> sqlalchemy.engine.Connection:
engine = sqlalchemy.create_engine(self.connection_string)
engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
conn = engine.connect()
return conn
@ -159,7 +163,7 @@ class PGVector(VectorStore):
session.execute(statement)
session.commit()
except Exception as e:
self.logger.exception(e)
raise Exception(f"Failed to create vector extension: {e}") from e
def create_tables_if_not_exists(self) -> None:
with self._conn.begin():

Loading…
Cancel
Save