Fixes scope of query Session in PGVector (#5194)

`vectorstore.PGVector`: The transactional boundary should be increased
to cover the query itself

Currently, within the `similarity_search_with_score_by_vector` the
transactional boundary (created via the `Session` call) does not include
the select query being made.

This can result in un-intended consequences when interacting with the
PGVector instance methods directly


---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
Matt Wells 2023-05-24 18:37:45 +01:00 committed by GitHub
parent 52714cedd4
commit c173bf1c62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -291,40 +291,43 @@ class PGVector(VectorStore):
if not collection:
raise ValueError("Collection not found")
filter_by = EmbeddingStore.collection_id == collection.uuid
filter_by = EmbeddingStore.collection_id == collection.uuid
if filter is not None:
filter_clauses = []
for key, value in filter.items():
IN = "in"
if isinstance(value, dict) and IN in map(str.lower, value):
value_case_insensitive = {k.lower(): v for k, v in value.items()}
filter_by_metadata = EmbeddingStore.cmetadata[key].astext.in_(
value_case_insensitive[IN]
)
filter_clauses.append(filter_by_metadata)
else:
filter_by_metadata = EmbeddingStore.cmetadata[key].astext == str(
value
)
filter_clauses.append(filter_by_metadata)
if filter is not None:
filter_clauses = []
for key, value in filter.items():
IN = "in"
if isinstance(value, dict) and IN in map(str.lower, value):
value_case_insensitive = {
k.lower(): v for k, v in value.items()
}
filter_by_metadata = EmbeddingStore.cmetadata[key].astext.in_(
value_case_insensitive[IN]
)
filter_clauses.append(filter_by_metadata)
else:
filter_by_metadata = EmbeddingStore.cmetadata[
key
].astext == str(value)
filter_clauses.append(filter_by_metadata)
filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
results: List[QueryResult] = (
session.query(
EmbeddingStore,
self.distance_strategy(embedding).label("distance"), # type: ignore
results: List[QueryResult] = (
session.query(
EmbeddingStore,
self.distance_strategy(embedding).label("distance"), # type: ignore
)
.filter(filter_by)
.order_by(sqlalchemy.asc("distance"))
.join(
CollectionStore,
EmbeddingStore.collection_id == CollectionStore.uuid,
)
.limit(k)
.all()
)
.filter(filter_by)
.order_by(sqlalchemy.asc("distance"))
.join(
CollectionStore,
EmbeddingStore.collection_id == CollectionStore.uuid,
)
.limit(k)
.all()
)
docs = [
(
Document(