diff --git a/langchain/vectorstores/pgvector.py b/langchain/vectorstores/pgvector.py index 2e53dd35..9c3d70d7 100644 --- a/langchain/vectorstores/pgvector.py +++ b/langchain/vectorstores/pgvector.py @@ -5,7 +5,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple import sqlalchemy from pgvector.sqlalchemy import Vector -from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.dialects.postgresql import JSON, UUID from sqlalchemy.orm import Mapped, Session, declarative_base, relationship from langchain.docstore.document import Document @@ -29,7 +29,7 @@ class CollectionStore(BaseModel): __tablename__ = "langchain_pg_collection" name = sqlalchemy.Column(sqlalchemy.String) - cmetadata = sqlalchemy.Column(sqlalchemy.JSON) + cmetadata = sqlalchemy.Column(JSON) embeddings = relationship( "EmbeddingStore", @@ -57,7 +57,7 @@ class CollectionStore(BaseModel): if collection: return collection, created - collection = cls(name=name, metadata=cmetadata) + collection = cls(name=name, cmetadata=cmetadata) session.add(collection) session.commit() created = True @@ -121,6 +121,7 @@ class PGVector(VectorStore): connection_string: str, embedding_function: Embeddings, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + collection_metadata: Optional[dict] = None, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, pre_delete_collection: bool = False, logger: Optional[logging.Logger] = None, @@ -128,6 +129,7 @@ class PGVector(VectorStore): self.connection_string = connection_string self.embedding_function = embedding_function self.collection_name = collection_name + self.collection_metadata = collection_metadata self.distance_strategy = distance_strategy self.pre_delete_collection = pre_delete_collection self.logger = logger or logging.getLogger(__name__) @@ -168,7 +170,9 @@ class PGVector(VectorStore): if self.pre_delete_collection: self.delete_collection() with Session(self._conn) as session: - CollectionStore.get_or_create(session, self.collection_name) + CollectionStore.get_or_create( + session, self.collection_name, cmetadata=self.collection_metadata + ) def delete_collection(self) -> None: self.logger.debug("Trying to delete collection") diff --git a/tests/integration_tests/vectorstores/test_pgvector.py b/tests/integration_tests/vectorstores/test_pgvector.py index 34d3ae05..5479a1f5 100644 --- a/tests/integration_tests/vectorstores/test_pgvector.py +++ b/tests/integration_tests/vectorstores/test_pgvector.py @@ -2,6 +2,8 @@ import os from typing import List +from sqlalchemy.orm import Session + from langchain.docstore.document import Document from langchain.vectorstores.pgvector import PGVector from tests.integration_tests.vectorstores.fake_embeddings import ( @@ -79,3 +81,21 @@ def test_pgvector_with_metadatas_with_scores() -> None: ) output = docsearch.similarity_search_with_score("foo", k=1) assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] + + +def test_pgvector_collection_with_metadata() -> None: + """Test end to end collection construction""" + pgvector = PGVector( + collection_name="test_collection", + collection_metadata={"foo": "bar"}, + embedding_function=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + session = Session(pgvector.connect()) + collection = pgvector.get_collection(session) + if collection is None: + assert False, "Expected a CollectionStore object but received None" + else: + assert collection.name == "test_collection" + assert collection.cmetadata == {"foo": "bar"}