Add PGVector collection metadata (#1887)

The `CollectionStore` for `PGVector` has a `cmetadata` field, but it is
never used. This PR adds the ability to save metadata to the
collection.
tool-patch
Maurício Maia 1 year ago committed by GitHub
parent d08f940336
commit 2212520a6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,7 +5,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
import sqlalchemy import sqlalchemy
from pgvector.sqlalchemy import Vector from pgvector.sqlalchemy import Vector
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import JSON, UUID
from sqlalchemy.orm import Mapped, Session, declarative_base, relationship from sqlalchemy.orm import Mapped, Session, declarative_base, relationship
from langchain.docstore.document import Document from langchain.docstore.document import Document
@ -29,7 +29,7 @@ class CollectionStore(BaseModel):
__tablename__ = "langchain_pg_collection" __tablename__ = "langchain_pg_collection"
name = sqlalchemy.Column(sqlalchemy.String) name = sqlalchemy.Column(sqlalchemy.String)
cmetadata = sqlalchemy.Column(sqlalchemy.JSON) cmetadata = sqlalchemy.Column(JSON)
embeddings = relationship( embeddings = relationship(
"EmbeddingStore", "EmbeddingStore",
@ -57,7 +57,7 @@ class CollectionStore(BaseModel):
if collection: if collection:
return collection, created return collection, created
collection = cls(name=name, metadata=cmetadata) collection = cls(name=name, cmetadata=cmetadata)
session.add(collection) session.add(collection)
session.commit() session.commit()
created = True created = True
@ -121,6 +121,7 @@ class PGVector(VectorStore):
connection_string: str, connection_string: str,
embedding_function: Embeddings, embedding_function: Embeddings,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
collection_metadata: Optional[dict] = None,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
pre_delete_collection: bool = False, pre_delete_collection: bool = False,
logger: Optional[logging.Logger] = None, logger: Optional[logging.Logger] = None,
@ -128,6 +129,7 @@ class PGVector(VectorStore):
self.connection_string = connection_string self.connection_string = connection_string
self.embedding_function = embedding_function self.embedding_function = embedding_function
self.collection_name = collection_name self.collection_name = collection_name
self.collection_metadata = collection_metadata
self.distance_strategy = distance_strategy self.distance_strategy = distance_strategy
self.pre_delete_collection = pre_delete_collection self.pre_delete_collection = pre_delete_collection
self.logger = logger or logging.getLogger(__name__) self.logger = logger or logging.getLogger(__name__)
@ -168,7 +170,9 @@ class PGVector(VectorStore):
if self.pre_delete_collection: if self.pre_delete_collection:
self.delete_collection() self.delete_collection()
with Session(self._conn) as session: with Session(self._conn) as session:
CollectionStore.get_or_create(session, self.collection_name) CollectionStore.get_or_create(
session, self.collection_name, cmetadata=self.collection_metadata
)
def delete_collection(self) -> None: def delete_collection(self) -> None:
self.logger.debug("Trying to delete collection") self.logger.debug("Trying to delete collection")

@ -2,6 +2,8 @@
import os import os
from typing import List from typing import List
from sqlalchemy.orm import Session
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.vectorstores.pgvector import PGVector from langchain.vectorstores.pgvector import PGVector
from tests.integration_tests.vectorstores.fake_embeddings import ( from tests.integration_tests.vectorstores.fake_embeddings import (
@ -79,3 +81,21 @@ def test_pgvector_with_metadatas_with_scores() -> None:
) )
output = docsearch.similarity_search_with_score("foo", k=1) output = docsearch.similarity_search_with_score("foo", k=1)
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
def test_pgvector_collection_with_metadata() -> None:
    """Check that collection metadata given to PGVector is persisted.

    Builds a ``PGVector`` store with ``collection_metadata={"foo": "bar"}``,
    fetches the collection through ``get_collection`` and verifies both its
    name and its ``cmetadata``.
    """
    pgvector = PGVector(
        collection_name="test_collection",
        collection_metadata={"foo": "bar"},
        embedding_function=FakeEmbeddingsWithAdaDimension(),
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    # Use the session as a context manager so it is closed even when one of
    # the assertions below fails.
    with Session(pgvector.connect()) as session:
        collection = pgvector.get_collection(session)
        assert (
            collection is not None
        ), "Expected a CollectionStore object but received None"
        assert collection.name == "test_collection"
        assert collection.cmetadata == {"foo": "bar"}

Loading…
Cancel
Save