mirror of https://github.com/hwchase17/langchain
Add support for Xata as a vector store (#8822)
This adds support for [Xata](https://xata.io) (a serverless data platform based on PostgreSQL) as a vector store. We have recently added [Xata to Langchain.js](https://github.com/hwchase17/langchainjs/pull/2125) and would love to have the equivalent in the Python project as well. The PR includes integration tests and a Jupyter notebook as docs. Please let me know if anything else would be needed or helpful. I have added the xata Python SDK as an optional dependency.

## To run the integration tests

You will need to create a DB in Xata (see the docs), then run something like:

```
OPENAI_API_KEY=sk-... XATA_API_KEY=xau_... XATA_DB_URL='https://....xata.sh/db/langchain' poetry run pytest tests/integration_tests/vectorstores/test_xata.py
```

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Philip Krauss <35487337+philkra@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
parent 472f00ada7
commit aeaef8f3a3
@ -0,0 +1,240 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Xata\n",
    "\n",
    "> [Xata](https://xata.io) is a serverless data platform, based on PostgreSQL. It provides a Python SDK for interacting with your database, and a UI for managing your data.\n",
    "> Xata has a native vector type, which can be added to any table, and supports similarity search. LangChain inserts vectors directly into Xata, and queries it for the nearest neighbors of a given vector, so that you can use all the LangChain Embeddings integrations with Xata."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how to use Xata as a VectorStore."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "### Create a database to use as a vector store\n",
    "\n",
    "In the [Xata UI](https://app.xata.io) create a new database. You can name it whatever you want; in this notebook we'll use `langchain`.\n",
    "Create a table; again, you can name it anything, but we will use `vectors`. Add the following columns via the UI:\n",
    "\n",
    "* `content` of type \"Text\". This is used to store the `Document.pageContent` values.\n",
    "* `embedding` of type \"Vector\". Use the dimension used by the model you plan to use. In this notebook we use OpenAI embeddings, which have 1536 dimensions.\n",
    "* `search` of type \"Text\". This is used as a metadata column by this example.\n",
    "* any other columns you want to use as metadata. They are populated from the `Document.metadata` object. For example, if in the `Document.metadata` object you have a `title` property, you can create a `title` column in the table and it will be populated.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's first install our dependencies:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "!pip install xata==1.0.0a7 openai tiktoken langchain"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's load the OpenAI key into the environment. If you don't have one, you can create an OpenAI account and create a key on this [page](https://platform.openai.com/account/api-keys)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import getpass\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly, we need to get the environment variables for Xata. You can create a new API key by visiting your [account settings](https://app.xata.io/settings). To find the database URL, go to the Settings page of the database that you have created. The database URL should look something like this: `https://demo-uni3q8.eu-west-1.xata.sh/db/langchain`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "api_key = getpass.getpass(\"Xata API key: \")\n",
    "db_url = input(\"Xata database URL (copy it from your DB settings):\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "from langchain.embeddings.openai import OpenAIEmbeddings\n",
    "from langchain.text_splitter import CharacterTextSplitter\n",
    "from langchain.document_loaders import TextLoader\n",
    "from langchain.vectorstores.xata import XataVectorStore"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the Xata vector store\n",
    "Let's import our test dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
    "documents = loader.load()\n",
    "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
    "docs = text_splitter.split_documents(documents)\n",
    "\n",
    "embeddings = OpenAIEmbeddings()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now create the actual vector store, backed by the Xata table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "vector_store = XataVectorStore.from_documents(docs, embeddings, api_key=api_key, db_url=db_url, table_name=\"vectors\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After running the above command, if you go to the Xata UI, you should see the documents loaded together with their embeddings."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Similarity Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
    "found_docs = vector_store.similarity_search(query)\n",
    "print(found_docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Similarity Search with score (vector distance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
    "result = vector_store.similarity_search_with_score(query)\n",
    "for doc, score in result:\n",
    "    print(f\"document={doc}, score={score}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
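The notebook relies on `CharacterTextSplitter` with `chunk_size=1000, chunk_overlap=0` before embedding. As a rough illustration of the chunk-size/overlap idea only: LangChain's actual splitter works on separators such as blank lines, so this fixed-window version is a simplified sketch, not the real implementation.

```python
from typing import List


def split_text(text: str, chunk_size: int = 1000, chunk_overlap: int = 0) -> List[str]:
    """Naive fixed-window splitter: emit chunk_size-character windows,
    stepping by (chunk_size - chunk_overlap) characters each time."""
    step = chunk_size - chunk_overlap
    return [text[i : i + chunk_size] for i in range(0, len(text), step)]
```

With `chunk_overlap=0` the windows tile the text exactly; a positive overlap repeats the tail of each chunk at the head of the next, which helps keep context intact across chunk boundaries.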
@ -0,0 +1,263 @@
"""Wrapper around Xata as a vector database."""

from __future__ import annotations

import time
from itertools import repeat
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore


class XataVectorStore(VectorStore):
    """VectorStore for a Xata database. Assumes you have a Xata database
    created with the right schema. See the guide at:
    https://integrations.langchain.com/vectorstores?integration_name=XataVectorStore
    """

    def __init__(
        self,
        api_key: str,
        db_url: str,
        embedding: Embeddings,
        table_name: str,
    ) -> None:
        """Initialize with Xata client."""
        try:
            from xata.client import XataClient  # noqa: F401
        except ImportError:
            raise ImportError(
                "Could not import xata python package. "
                "Please install it with `pip install xata`."
            )
        self._client = XataClient(api_key=api_key, db_url=db_url)
        self._embedding: Embeddings = embedding
        self._table_name = table_name or "vectors"

    @property
    def embeddings(self) -> Embeddings:
        return self._embedding

    def add_vectors(
        self,
        vectors: List[List[float]],
        documents: List[Document],
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        return self._add_vectors(vectors, documents, ids)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[Dict[Any, Any]]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        docs = self._texts_to_documents(texts, metadatas)
        vectors = self._embedding.embed_documents(list(texts))
        return self.add_vectors(vectors, docs, ids)

    def _add_vectors(
        self,
        vectors: List[List[float]],
        documents: List[Document],
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        """Add vectors to the Xata database."""
        rows: List[Dict[str, Any]] = []
        for idx, embedding in enumerate(vectors):
            row = {
                "content": documents[idx].page_content,
                "embedding": embedding,
            }
            if ids:
                row["id"] = ids[idx]
            for key, val in documents[idx].metadata.items():
                if key not in ["id", "content", "embedding"]:
                    row[key] = val
            rows.append(row)

        # XXX: I would have liked to use the BulkProcessor here, but it
        # doesn't return the IDs, which we need here. Manual chunking it is.
        chunk_size = 1000
        id_list: List[str] = []
        for i in range(0, len(rows), chunk_size):
            chunk = rows[i : i + chunk_size]

            r = self._client.records().bulk_insert(self._table_name, {"records": chunk})
            if r.status_code != 200:
                raise Exception(f"Error adding vectors to Xata: {r.status_code} {r}")
            id_list.extend(r["recordIDs"])
        return id_list
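`_add_vectors` above builds one row per document and inserts in batches of 1000 to keep each request bounded, collecting the record IDs the server returns. The chunking pattern in isolation, with a stub insert function standing in for the Xata client (names here are illustrative, not part of the SDK):

```python
from typing import Any, Callable, Dict, List


def insert_in_chunks(
    rows: List[Dict[str, Any]],
    bulk_insert: Callable[[List[Dict[str, Any]]], List[str]],
    chunk_size: int = 1000,
) -> List[str]:
    """Send rows in chunk_size batches, collecting the returned record IDs
    so the caller gets one ID per row, in insertion order."""
    id_list: List[str] = []
    for i in range(0, len(rows), chunk_size):
        id_list.extend(bulk_insert(rows[i : i + chunk_size]))
    return id_list
```

The slice `rows[i : i + chunk_size]` naturally handles a short final batch, so no special-casing of the remainder is needed.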
    @staticmethod
    def _texts_to_documents(
        texts: Iterable[str],
        metadatas: Optional[Iterable[Dict[Any, Any]]] = None,
    ) -> List[Document]:
        """Return list of Documents from list of texts and metadatas."""
        if metadatas is None:
            metadatas = repeat({})

        docs = [
            Document(page_content=text, metadata=metadata)
            for text, metadata in zip(texts, metadatas)
        ]

        return docs

    @classmethod
    def from_texts(
        cls: Type["XataVectorStore"],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        api_key: Optional[str] = None,
        db_url: Optional[str] = None,
        table_name: str = "vectors",
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> "XataVectorStore":
        """Return VectorStore initialized from texts and embeddings."""

        if not api_key or not db_url:
            raise ValueError("Xata api_key and db_url must be set.")

        embeddings = embedding.embed_documents(texts)
        ids = None  # Xata will generate them for us
        docs = cls._texts_to_documents(texts, metadatas)

        vector_db = cls(
            api_key=api_key,
            db_url=db_url,
            embedding=embedding,
            table_name=table_name,
        )

        vector_db._add_vectors(embeddings, docs, ids)
        return vector_db

    def similarity_search(
        self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query.
        """
        docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
        documents = [d[0] for d in docs_and_scores]
        return documents

    def similarity_search_with_score(
        self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Run similarity search with Xata and return documents with distance.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[dict]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: List of documents most similar to the query
                text with distance in float.
        """
        embedding = self._embedding.embed_query(query)
        payload = {
            "queryVector": embedding,
            "column": "embedding",
            "size": k,
        }
        if filter:
            payload["filter"] = filter
        r = self._client.data().vector_search(self._table_name, payload=payload)
        if r.status_code != 200:
            raise Exception(f"Error running similarity search: {r.status_code} {r}")
        hits = r["records"]
        docs_and_scores = [
            (
                Document(
                    page_content=hit["content"],
                    metadata=self._extractMetadata(hit),
                ),
                hit["xata"]["score"],
            )
            for hit in hits
        ]
        return docs_and_scores
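`similarity_search_with_score` above sends Xata a vector-search body with `queryVector`, `column`, `size`, and an optional `filter`. That payload construction, pulled out on its own as a sketch (the field names match what the wrapper sends; the function itself is illustrative, not part of the SDK):

```python
from typing import Any, Dict, List, Optional


def build_vector_search_payload(
    query_vector: List[float], k: int = 4, filter: Optional[dict] = None
) -> Dict[str, Any]:
    """Assemble the search body: nearest-k on the 'embedding' column,
    with the metadata filter included only when one is given."""
    payload: Dict[str, Any] = {
        "queryVector": query_vector,
        "column": "embedding",
        "size": k,
    }
    if filter:
        payload["filter"] = filter
    return payload
```

Omitting the `filter` key entirely when no filter is given (rather than sending `"filter": None`) keeps the request body minimal.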
    def _extractMetadata(self, record: dict) -> dict:
        """Extract metadata from a record. Filters out known columns."""
        metadata = {}
        for key, val in record.items():
            if key not in ["id", "content", "embedding", "xata"]:
                metadata[key] = val
        return metadata

    def delete(
        self,
        ids: Optional[List[str]] = None,
        delete_all: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
            delete_all: Delete all records in the table.
        """
        if delete_all:
            self._delete_all()
            self.wait_for_indexing(ndocs=0)
        elif ids is not None:
            chunk_size = 500
            for i in range(0, len(ids), chunk_size):
                chunk = ids[i : i + chunk_size]
                operations = [
                    {"delete": {"table": self._table_name, "id": id}} for id in chunk
                ]
                self._client.records().transaction(payload={"operations": operations})
        else:
            raise ValueError("Either ids or delete_all must be set.")

    def _delete_all(self) -> None:
        """Delete all records in the table."""
        while True:
            r = self._client.data().query(self._table_name, payload={"columns": ["id"]})
            if r.status_code != 200:
                raise Exception(f"Error running query: {r.status_code} {r}")
            ids = [rec["id"] for rec in r["records"]]
            if len(ids) == 0:
                break
            operations = [
                {"delete": {"table": self._table_name, "id": id}} for id in ids
            ]
            self._client.records().transaction(payload={"operations": operations})

    def wait_for_indexing(self, timeout: float = 5, ndocs: int = 1) -> None:
        """Wait for the search index to contain a certain number of
        documents. Useful in tests.
        """
        start = time.time()
        while True:
            r = self._client.data().search_table(
                self._table_name, payload={"query": "", "page": {"size": 0}}
            )
            if r.status_code != 200:
                raise Exception(f"Error running search: {r.status_code} {r}")
            if r["totalCount"] == ndocs:
                break
            if time.time() - start > timeout:
                raise Exception("Timed out waiting for indexing to complete.")
            time.sleep(0.5)
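`wait_for_indexing` above polls Xata's search endpoint until `totalCount` reaches the expected document count, because search indexing is asynchronous, and raises once the timeout elapses. The generic poll-until-true-or-timeout pattern, with the Xata call replaced by an arbitrary predicate (an illustrative helper, not part of this PR):

```python
import time
from typing import Callable


def wait_until(
    predicate: Callable[[], bool], timeout: float = 5, interval: float = 0.5
) -> None:
    """Poll predicate() until it returns True, raising TimeoutError if
    the timeout elapses first; sleep between attempts to avoid busy-waiting."""
    start = time.time()
    while not predicate():
        if time.time() - start > timeout:
            raise TimeoutError("Timed out waiting for condition.")
        time.sleep(interval)
```

The same helper would express `wait_for_indexing` as `wait_until(lambda: current_count() == ndocs)`, keeping the timeout bookkeeping in one place.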
@ -0,0 +1,56 @@
"""Test Xata vector store functionality.

Before running this test, please create a Xata database by following
the instructions from:
https://python.langchain.com/docs/integrations/vectorstores/xata
"""
import os

from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.xata import XataVectorStore


class TestXata:
    @classmethod
    def setup_class(cls) -> None:
        assert os.getenv("XATA_API_KEY"), "XATA_API_KEY environment variable is not set"
        assert os.getenv("XATA_DB_URL"), "XATA_DB_URL environment variable is not set"

    def test_similarity_search_without_metadata(
        self, embedding_openai: OpenAIEmbeddings
    ) -> None:
        """Test end to end construction and search without metadata."""
        texts = ["foo", "bar", "baz"]
        docsearch = XataVectorStore.from_texts(
            api_key=os.getenv("XATA_API_KEY"),
            db_url=os.getenv("XATA_DB_URL"),
            texts=texts,
            embedding=embedding_openai,
        )
        docsearch.wait_for_indexing(ndocs=3)

        output = docsearch.similarity_search("foo", k=1)
        assert output == [Document(page_content="foo")]
        docsearch.delete(delete_all=True)

    def test_similarity_search_with_metadata(
        self, embedding_openai: OpenAIEmbeddings
    ) -> None:
        """Test end to end construction and search with a metadata filter.

        This test requires a column named "a" of type integer to be present
        in the Xata table."""
        texts = ["foo", "foo", "foo"]
        metadatas = [{"a": i} for i in range(len(texts))]
        docsearch = XataVectorStore.from_texts(
            api_key=os.getenv("XATA_API_KEY"),
            db_url=os.getenv("XATA_DB_URL"),
            texts=texts,
            embedding=embedding_openai,
            metadatas=metadatas,
        )
        docsearch.wait_for_indexing(ndocs=3)
        output = docsearch.similarity_search("foo", k=1, filter={"a": 1})
        assert output == [Document(page_content="foo", metadata={"a": 1})]
        docsearch.delete(delete_all=True)
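These tests hit both OpenAI and Xata, so they are integration tests by nature (the `embedding_openai` fixture is supplied by the test suite's conftest). For offline unit testing of pieces like row construction, a deterministic fake could stand in for `OpenAIEmbeddings` — a hypothetical helper, not included in this PR:

```python
import hashlib
from typing import List


class FakeEmbeddings:
    """Deterministic embeddings stand-in: derive a fixed-dimension vector
    from the SHA-256 hash of the text, so the same text always maps to
    the same vector and no network access is needed."""

    def __init__(self, dim: int = 8) -> None:
        self.dim = dim

    def embed_query(self, text: str) -> List[float]:
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        # Scale the first `dim` bytes into [0, 1] floats.
        return [b / 255.0 for b in digest[: self.dim]]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(t) for t in texts]
```

A real test would pick `dim` to match the table's vector column (1536 for OpenAI embeddings); the small default here just keeps the fake cheap.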