diff --git a/libs/langchain/langchain/vectorstores/pgvector.py b/libs/langchain/langchain/vectorstores/pgvector.py index 3418f6e454..bcf05eacf9 100644 --- a/libs/langchain/langchain/vectorstores/pgvector.py +++ b/libs/langchain/langchain/vectorstores/pgvector.py @@ -116,6 +116,7 @@ class PGVector(VectorStore): logger: Optional[logging.Logger] = None, relevance_score_fn: Optional[Callable[[float], float]] = None, *, + connection: Optional[sqlalchemy.engine.Connection] = None, engine_args: Optional[dict[str, Any]] = None, ) -> None: self.connection_string = connection_string @@ -127,15 +128,13 @@ class PGVector(VectorStore): self.logger = logger or logging.getLogger(__name__) self.override_relevance_score_fn = relevance_score_fn self.engine_args = engine_args or {} - self.__post_init__() + # Create a connection if not provided, otherwise use the provided connection + self._conn = connection if connection else self.connect() def __post_init__( self, ) -> None: - """ - Initialize the store. - """ - self._conn = self.connect() + """Initialize the store.""" self.create_vector_extension() from langchain.vectorstores._pgvector_data_models import ( CollectionStore,