2023-12-11 21:53:30 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import time
|
|
|
|
from itertools import repeat
|
|
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
|
|
|
|
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
from langchain_core.embeddings import Embeddings
|
|
|
|
from langchain_core.vectorstores import VectorStore
|
|
|
|
|
|
|
|
|
|
|
|
class XataVectorStore(VectorStore):
|
|
|
|
"""`Xata` vector store.
|
|
|
|
|
|
|
|
It assumes you have a Xata database
|
|
|
|
created with the right schema. See the guide at:
|
|
|
|
https://integrations.langchain.com/vectorstores?integration_name=XataVectorStore
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
api_key: str,
|
|
|
|
db_url: str,
|
|
|
|
embedding: Embeddings,
|
|
|
|
table_name: str,
|
|
|
|
) -> None:
|
|
|
|
"""Initialize with Xata client."""
|
|
|
|
try:
|
2024-04-30 17:13:48 +00:00
|
|
|
from xata.client import XataClient
|
2023-12-11 21:53:30 +00:00
|
|
|
except ImportError:
|
|
|
|
raise ImportError(
|
|
|
|
"Could not import xata python package. "
|
|
|
|
"Please install it with `pip install xata`."
|
|
|
|
)
|
|
|
|
self._client = XataClient(api_key=api_key, db_url=db_url)
|
|
|
|
self._embedding: Embeddings = embedding
|
|
|
|
self._table_name = table_name or "vectors"
|
|
|
|
|
|
|
|
@property
|
|
|
|
def embeddings(self) -> Embeddings:
|
|
|
|
return self._embedding
|
|
|
|
|
|
|
|
def add_vectors(
|
|
|
|
self,
|
|
|
|
vectors: List[List[float]],
|
|
|
|
documents: List[Document],
|
|
|
|
ids: Optional[List[str]] = None,
|
|
|
|
) -> List[str]:
|
|
|
|
return self._add_vectors(vectors, documents, ids)
|
|
|
|
|
|
|
|
def add_texts(
|
|
|
|
self,
|
|
|
|
texts: Iterable[str],
|
|
|
|
metadatas: Optional[List[Dict[Any, Any]]] = None,
|
|
|
|
ids: Optional[List[str]] = None,
|
|
|
|
**kwargs: Any,
|
|
|
|
) -> List[str]:
|
|
|
|
ids = ids
|
|
|
|
docs = self._texts_to_documents(texts, metadatas)
|
|
|
|
|
|
|
|
vectors = self._embedding.embed_documents(list(texts))
|
|
|
|
return self.add_vectors(vectors, docs, ids)
|
|
|
|
|
|
|
|
def _add_vectors(
|
|
|
|
self,
|
|
|
|
vectors: List[List[float]],
|
|
|
|
documents: List[Document],
|
|
|
|
ids: Optional[List[str]] = None,
|
|
|
|
) -> List[str]:
|
|
|
|
"""Add vectors to the Xata database."""
|
|
|
|
|
|
|
|
rows: List[Dict[str, Any]] = []
|
|
|
|
for idx, embedding in enumerate(vectors):
|
|
|
|
row = {
|
|
|
|
"content": documents[idx].page_content,
|
|
|
|
"embedding": embedding,
|
|
|
|
}
|
|
|
|
if ids:
|
|
|
|
row["id"] = ids[idx]
|
|
|
|
for key, val in documents[idx].metadata.items():
|
|
|
|
if key not in ["id", "content", "embedding"]:
|
|
|
|
row[key] = val
|
|
|
|
rows.append(row)
|
|
|
|
|
|
|
|
# XXX: I would have liked to use the BulkProcessor here, but it
|
|
|
|
# doesn't return the IDs, which we need here. Manual chunking it is.
|
|
|
|
chunk_size = 1000
|
|
|
|
id_list: List[str] = []
|
|
|
|
for i in range(0, len(rows), chunk_size):
|
|
|
|
chunk = rows[i : i + chunk_size]
|
|
|
|
|
|
|
|
r = self._client.records().bulk_insert(self._table_name, {"records": chunk})
|
|
|
|
if r.status_code != 200:
|
|
|
|
raise Exception(f"Error adding vectors to Xata: {r.status_code} {r}")
|
|
|
|
id_list.extend(r["recordIDs"])
|
|
|
|
return id_list
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def _texts_to_documents(
|
|
|
|
texts: Iterable[str],
|
|
|
|
metadatas: Optional[Iterable[Dict[Any, Any]]] = None,
|
|
|
|
) -> List[Document]:
|
|
|
|
"""Return list of Documents from list of texts and metadatas."""
|
|
|
|
if metadatas is None:
|
|
|
|
metadatas = repeat({})
|
|
|
|
|
|
|
|
docs = [
|
|
|
|
Document(page_content=text, metadata=metadata)
|
|
|
|
for text, metadata in zip(texts, metadatas)
|
|
|
|
]
|
|
|
|
|
|
|
|
return docs
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def from_texts(
|
|
|
|
cls: Type["XataVectorStore"],
|
|
|
|
texts: List[str],
|
|
|
|
embedding: Embeddings,
|
|
|
|
metadatas: Optional[List[dict]] = None,
|
|
|
|
api_key: Optional[str] = None,
|
|
|
|
db_url: Optional[str] = None,
|
|
|
|
table_name: str = "vectors",
|
|
|
|
ids: Optional[List[str]] = None,
|
|
|
|
**kwargs: Any,
|
|
|
|
) -> "XataVectorStore":
|
|
|
|
"""Return VectorStore initialized from texts and embeddings."""
|
|
|
|
|
|
|
|
if not api_key or not db_url:
|
|
|
|
raise ValueError("Xata api_key and db_url must be set.")
|
|
|
|
|
|
|
|
embeddings = embedding.embed_documents(texts)
|
|
|
|
ids = None # Xata will generate them for us
|
|
|
|
docs = cls._texts_to_documents(texts, metadatas)
|
|
|
|
|
|
|
|
vector_db = cls(
|
|
|
|
api_key=api_key,
|
|
|
|
db_url=db_url,
|
|
|
|
embedding=embedding,
|
|
|
|
table_name=table_name,
|
|
|
|
)
|
|
|
|
|
|
|
|
vector_db._add_vectors(embeddings, docs, ids)
|
|
|
|
return vector_db
|
|
|
|
|
|
|
|
def similarity_search(
|
|
|
|
self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
|
|
|
|
) -> List[Document]:
|
|
|
|
"""Return docs most similar to query.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
query: Text to look up documents similar to.
|
|
|
|
k: Number of Documents to return. Defaults to 4.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
List of Documents most similar to the query.
|
|
|
|
"""
|
|
|
|
docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
|
|
|
|
documents = [d[0] for d in docs_and_scores]
|
|
|
|
return documents
|
|
|
|
|
|
|
|
def similarity_search_with_score(
|
|
|
|
self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
|
|
|
|
) -> List[Tuple[Document, float]]:
|
|
|
|
"""Run similarity search with Chroma with distance.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
query (str): Query text to search for.
|
|
|
|
k (int): Number of results to return. Defaults to 4.
|
|
|
|
filter (Optional[dict]): Filter by metadata. Defaults to None.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
List[Tuple[Document, float]]: List of documents most similar to the query
|
|
|
|
text with distance in float.
|
|
|
|
"""
|
|
|
|
embedding = self._embedding.embed_query(query)
|
|
|
|
payload = {
|
|
|
|
"queryVector": embedding,
|
|
|
|
"column": "embedding",
|
|
|
|
"size": k,
|
|
|
|
}
|
|
|
|
if filter:
|
|
|
|
payload["filter"] = filter
|
|
|
|
r = self._client.data().vector_search(self._table_name, payload=payload)
|
|
|
|
if r.status_code != 200:
|
|
|
|
raise Exception(f"Error running similarity search: {r.status_code} {r}")
|
|
|
|
hits = r["records"]
|
|
|
|
docs_and_scores = [
|
|
|
|
(
|
|
|
|
Document(
|
|
|
|
page_content=hit["content"],
|
|
|
|
metadata=self._extractMetadata(hit),
|
|
|
|
),
|
|
|
|
hit["xata"]["score"],
|
|
|
|
)
|
|
|
|
for hit in hits
|
|
|
|
]
|
|
|
|
return docs_and_scores
|
|
|
|
|
|
|
|
def _extractMetadata(self, record: dict) -> dict:
|
|
|
|
"""Extract metadata from a record. Filters out known columns."""
|
|
|
|
metadata = {}
|
|
|
|
for key, val in record.items():
|
|
|
|
if key not in ["id", "content", "embedding", "xata"]:
|
|
|
|
metadata[key] = val
|
|
|
|
return metadata
|
|
|
|
|
|
|
|
def delete(
|
|
|
|
self,
|
|
|
|
ids: Optional[List[str]] = None,
|
|
|
|
delete_all: Optional[bool] = None,
|
|
|
|
**kwargs: Any,
|
|
|
|
) -> None:
|
|
|
|
"""Delete by vector IDs.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
ids: List of ids to delete.
|
|
|
|
delete_all: Delete all records in the table.
|
|
|
|
"""
|
|
|
|
if delete_all:
|
|
|
|
self._delete_all()
|
|
|
|
self.wait_for_indexing(ndocs=0)
|
|
|
|
elif ids is not None:
|
|
|
|
chunk_size = 500
|
|
|
|
for i in range(0, len(ids), chunk_size):
|
|
|
|
chunk = ids[i : i + chunk_size]
|
|
|
|
operations = [
|
|
|
|
{"delete": {"table": self._table_name, "id": id}} for id in chunk
|
|
|
|
]
|
|
|
|
self._client.records().transaction(payload={"operations": operations})
|
|
|
|
else:
|
|
|
|
raise ValueError("Either ids or delete_all must be set.")
|
|
|
|
|
|
|
|
def _delete_all(self) -> None:
|
|
|
|
"""Delete all records in the table."""
|
|
|
|
while True:
|
|
|
|
r = self._client.data().query(self._table_name, payload={"columns": ["id"]})
|
|
|
|
if r.status_code != 200:
|
|
|
|
raise Exception(f"Error running query: {r.status_code} {r}")
|
|
|
|
ids = [rec["id"] for rec in r["records"]]
|
|
|
|
if len(ids) == 0:
|
|
|
|
break
|
|
|
|
operations = [
|
|
|
|
{"delete": {"table": self._table_name, "id": id}} for id in ids
|
|
|
|
]
|
|
|
|
self._client.records().transaction(payload={"operations": operations})
|
|
|
|
|
|
|
|
def wait_for_indexing(self, timeout: float = 5, ndocs: int = 1) -> None:
|
|
|
|
"""Wait for the search index to contain a certain number of
|
|
|
|
documents. Useful in tests.
|
|
|
|
"""
|
|
|
|
start = time.time()
|
|
|
|
while True:
|
|
|
|
r = self._client.data().search_table(
|
|
|
|
self._table_name, payload={"query": "", "page": {"size": 0}}
|
|
|
|
)
|
|
|
|
if r.status_code != 200:
|
|
|
|
raise Exception(f"Error running search: {r.status_code} {r}")
|
|
|
|
if r["totalCount"] == ndocs:
|
|
|
|
break
|
|
|
|
if time.time() - start > timeout:
|
|
|
|
raise Exception("Timed out waiting for indexing to complete.")
|
|
|
|
time.sleep(0.5)
|