From 66f8cb015d863757ec306c84472f86e6c8df1f00 Mon Sep 17 00:00:00 2001 From: aubin_mzt <72651332+aubinmazet@users.noreply.github.com> Date: Mon, 23 Oct 2023 20:43:44 +0200 Subject: [PATCH] Add connection args for pgvector vector store (#11930) - **Description:** sqlalchemy create_engine() does not take into account connect_args which are mandatory for managed PGSQL instances on cloud providers (ssl_context for example). Also re-enabled create_vector_extension at post_init for using pgvector class seamlessly - **Tag maintainer:** @baskaryan, @eyurtsev, @hwchase17. --------- Co-authored-by: Sami Bargaoui Co-authored-by: Eugene Yurtsev --- libs/langchain/langchain/vectorstores/pgvector.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/pgvector.py b/libs/langchain/langchain/vectorstores/pgvector.py index 30e1c2d470..3418f6e454 100644 --- a/libs/langchain/langchain/vectorstores/pgvector.py +++ b/libs/langchain/langchain/vectorstores/pgvector.py @@ -84,6 +84,7 @@ class PGVector(VectorStore): distance_strategy: The distance strategy to use. (default: COSINE) pre_delete_collection: If True, will delete the collection if it exists. (default: False). Useful for testing. + engine_args: SQLAlchemy's create engine arguments. Example: .. code-block:: python @@ -114,6 +115,8 @@ class PGVector(VectorStore): pre_delete_collection: bool = False, logger: Optional[logging.Logger] = None, relevance_score_fn: Optional[Callable[[float], float]] = None, + *, + engine_args: Optional[dict[str, Any]] = None, ) -> None: self.connection_string = connection_string self.embedding_function = embedding_function @@ -123,6 +126,7 @@ class PGVector(VectorStore): self.pre_delete_collection = pre_delete_collection self.logger = logger or logging.getLogger(__name__) self.override_relevance_score_fn = relevance_score_fn + self.engine_args = engine_args or {} self.__post_init__() def __post_init__( @@ -132,7 +136,7 @@ class PGVector(VectorStore): Initialize the store. """ self._conn = self.connect() - # self.create_vector_extension() + self.create_vector_extension() from langchain.vectorstores._pgvector_data_models import ( CollectionStore, EmbeddingStore, @@ -148,7 +152,7 @@ class PGVector(VectorStore): return self.embedding_function def connect(self) -> sqlalchemy.engine.Connection: - engine = sqlalchemy.create_engine(self.connection_string) + engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args) conn = engine.connect() return conn @@ -159,7 +163,7 @@ class PGVector(VectorStore): session.execute(statement) session.commit() except Exception as e: - self.logger.exception(e) + raise Exception(f"Failed to create vector extension: {e}") from e def create_tables_if_not_exists(self) -> None: with self._conn.begin():