Update supabase.py, add filter to query (matches latest supabase docs & js) (#7721)

- Description: Update the Supabase vector store to support an optional
`filter` argument (applied to the query when present; when absent, existing behavior is unchanged)
- Tag maintainer: @rlancemartin, @eyurtsev

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
earonesty 2023-07-24 22:13:52 -04:00 committed by GitHub
parent 00de334f81
commit 59a7c5877a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,6 +5,7 @@ from itertools import repeat
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Dict,
Iterable, Iterable,
List, List,
Optional, Optional,
@ -74,7 +75,7 @@ class SupabaseVectorStore(VectorStore):
def add_texts( def add_texts(
self, self,
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[List[dict[Any, Any]]] = None, metadatas: Optional[List[Dict[Any, Any]]] = None,
ids: Optional[List[str]] = None, ids: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[str]: ) -> List[str]:
@ -125,30 +126,56 @@ class SupabaseVectorStore(VectorStore):
return self._add_vectors(self._client, self.table_name, vectors, documents, ids) return self._add_vectors(self._client, self.table_name, vectors, documents, ids)
def similarity_search( def similarity_search(
self, query: str, k: int = 4, **kwargs: Any self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
vectors = self._embedding.embed_documents([query]) vectors = self._embedding.embed_documents([query])
return self.similarity_search_by_vector(vectors[0], k) return self.similarity_search_by_vector(
vectors[0], k=k, filter=filter, **kwargs
)
def similarity_search_by_vector( def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
result = self.similarity_search_by_vector_with_relevance_scores(embedding, k) result = self.similarity_search_by_vector_with_relevance_scores(
embedding, k=k, filter=filter, **kwargs
)
documents = [doc for doc, _ in result] documents = [doc for doc, _ in result]
return documents return documents
def similarity_search_with_relevance_scores( def similarity_search_with_relevance_scores(
self, query: str, k: int = 4, **kwargs: Any self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
vectors = self._embedding.embed_documents([query]) vectors = self._embedding.embed_documents([query])
return self.similarity_search_by_vector_with_relevance_scores(vectors[0], k) return self.similarity_search_by_vector_with_relevance_scores(
vectors[0], k=k, filter=filter
)
def match_args(
self, query: List[float], k: int, filter: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
ret = dict(query_embedding=query, match_count=k)
if filter:
ret["filter"] = filter
return ret
def similarity_search_by_vector_with_relevance_scores( def similarity_search_by_vector_with_relevance_scores(
self, query: List[float], k: int self, query: List[float], k: int, filter: Optional[Dict[str, Any]] = None
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
match_documents_params = dict(query_embedding=query, match_count=k) match_documents_params = self.match_args(query, k, filter)
res = self._client.rpc(self.query_name, match_documents_params).execute() res = self._client.rpc(self.query_name, match_documents_params).execute()
match_result = [ match_result = [
@ -166,9 +193,9 @@ class SupabaseVectorStore(VectorStore):
return match_result return match_result
def similarity_search_by_vector_returning_embeddings( def similarity_search_by_vector_returning_embeddings(
self, query: List[float], k: int self, query: List[float], k: int, filter: Optional[Dict[str, Any]] = None
) -> List[Tuple[Document, float, np.ndarray[np.float32, Any]]]: ) -> List[Tuple[Document, float, np.ndarray[np.float32, Any]]]:
match_documents_params = dict(query_embedding=query, match_count=k) match_documents_params = self.match_args(query, k, filter)
res = self._client.rpc(self.query_name, match_documents_params).execute() res = self._client.rpc(self.query_name, match_documents_params).execute()
match_result = [ match_result = [
@ -193,7 +220,7 @@ class SupabaseVectorStore(VectorStore):
@staticmethod @staticmethod
def _texts_to_documents( def _texts_to_documents(
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[Iterable[dict[Any, Any]]] = None, metadatas: Optional[Iterable[Dict[Any, Any]]] = None,
) -> List[Document]: ) -> List[Document]:
"""Return list of Documents from list of texts and metadatas.""" """Return list of Documents from list of texts and metadatas."""
if metadatas is None: if metadatas is None:
@ -216,7 +243,7 @@ class SupabaseVectorStore(VectorStore):
) -> List[str]: ) -> List[str]:
"""Add vectors to Supabase table.""" """Add vectors to Supabase table."""
rows: List[dict[str, Any]] = [ rows: List[Dict[str, Any]] = [
{ {
"id": ids[idx], "id": ids[idx],
"content": documents[idx].page_content, "content": documents[idx].page_content,
@ -360,7 +387,7 @@ class SupabaseVectorStore(VectorStore):
if ids is None: if ids is None:
raise ValueError("No ids provided to delete.") raise ValueError("No ids provided to delete.")
rows: List[dict[str, Any]] = [ rows: List[Dict[str, Any]] = [
{ {
"id": id, "id": id,
} }