Add support for Xata as a vector store (#8822)

This adds support for [Xata](https://xata.io) (a serverless data platform based on
PostgreSQL) as a vector store. We have recently added [Xata to
LangChain.js](https://github.com/hwchase17/langchainjs/pull/2125) and
would love to have the equivalent in the Python project as well.

The PR includes integration tests and a Jupyter notebook as docs. Please
let me know if anything else would be needed or helpful.

I have added the `xata` Python SDK as an optional dependency.
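
For reference, a minimal usage sketch (the class name and parameters match the code added in this PR; the key values are placeholders):

```
# Minimal sketch; the api_key/db_url values below are placeholders.
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.xata import XataVectorStore

vector_store = XataVectorStore.from_texts(
    ["foo", "bar"],
    embedding=OpenAIEmbeddings(),
    api_key="xau_...",
    db_url="https://....xata.sh/db/langchain",
    table_name="vectors",  # table created with the schema described in the docs
)
print(vector_store.similarity_search("foo", k=1))
```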

## To run the integration tests

You will need to create a database in Xata (see the docs), then run something
like:

```
OPENAI_API_KEY=sk-... XATA_API_KEY=xau_... XATA_DB_URL='https://....xata.sh/db/langchain' poetry run pytest tests/integration_tests/vectorstores/test_xata.py
```


---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Philip Krauss <35487337+philkra@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>

xata.ipynb
@@ -0,0 +1,240 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Xata\n",
"\n",
"> [Xata](https://xata.io) is a serverless data platform, based on PostgreSQL. It provides a Python SDK for interacting with your database, and a UI for managing your data.\n",
"> Xata has a native vector type, which can be added to any table, and supports similarity search. LangChain inserts vectors directly to Xata, and queries it for the nearest neighbors of a given vector, so that you can use all the LangChain Embeddings integrations with Xata."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook guides you how to use Xata as a VectorStore."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"### Create a database to use as a vector store\n",
"\n",
"In the [Xata UI](https://app.xata.io) create a new database. You can name it whatever you want, in this notepad we'll use `langchain`.\n",
"Create a table, again you can name it anything, but we will use `vectors`. Add the following columns via the UI:\n",
"\n",
"* `content` of type \"Text\". This is used to store the `Document.pageContent` values.\n",
"* `embedding` of type \"Vector\". Use the dimension used by the model you plan to use. In this notebook we use OpenAI embeddings, which have 1536 dimensions.\n",
"* `search` of type \"Text\". This is used as a metadata column by this example.\n",
"* any other columns you want to use as metadata. They are populated from the `Document.metadata` object. For example, if in the `Document.metadata` object you have a `title` property, you can create a `title` column in the table and it will be populated.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's first install our dependencies:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"!pip install xata==1.0.0a7 openai tiktoken langchain"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load the OpenAI key to the environemnt. If you don't have one you can create an OpenAI account and create a key on this [page](https://platform.openai.com/account/api-keys)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similarly, we need to get the environment variables for Xata. You can create a new API key by visiting your [account settings](https://app.xata.io/settings). To find the database URL, go to the Settings page of the database that you have created. The database URL should look something like this: `https://demo-uni3q8.eu-west-1.xata.sh/db/langchain`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"api_key = getpass.getpass(\"Xata API key: \")\n",
"db_url = input(\"Xata database URL (copy it from your DB settings):\")"
]
},
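{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, you can create the table schema from code instead of the UI. The next cell is only a sketch: it assumes the Xata Python SDK exposes `table().create()` and `table().set_schema()` helpers, so verify the exact calls against the current Xata docs before running it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch only: assumed SDK surface, verify against the Xata docs.\n",
"from xata.client import XataClient\n",
"\n",
"client = XataClient(api_key=api_key, db_url=db_url)\n",
"client.table().create(\"vectors\")  # assumed helper; creates the table\n",
"client.table().set_schema(\n",
"    \"vectors\",\n",
"    {\n",
"        \"columns\": [\n",
"            {\"name\": \"content\", \"type\": \"text\"},\n",
"            {\"name\": \"search\", \"type\": \"text\"},\n",
"            {\"name\": \"embedding\", \"type\": \"vector\", \"vector\": {\"dimension\": 1536}},\n",
"        ]\n",
"    },\n",
")"
]
},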
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.document_loaders import TextLoader\n",
"from langchain.vectorstores.xata import XataVectorStore\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create the Xata vector store\n",
"Let's import our test dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)\n",
"\n",
"embeddings = OpenAIEmbeddings()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now create the actual vector store, backed by the Xata table."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"vector_store = XataVectorStore.from_documents(docs, embeddings, api_key=api_key, db_url=db_url, table_name=\"vectors\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After running the above command, if you go to the Xata UI, you should see the documents loaded together with their embeddings."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Similarity Search"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"found_docs = vector_store.similarity_search(query)\n",
"print(found_docs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Similarity Search with score (vector distance)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"result = vector_store.similarity_search_with_score(query)\n",
"for doc, score in result:\n",
" print(f\"document={doc}, score={score}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

langchain/vectorstores/xata.py
@@ -0,0 +1,263 @@
"""Wrapper around Xata as a vector database."""
from __future__ import annotations
import time
from itertools import repeat
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore
class XataVectorStore(VectorStore):
"""VectorStore for a Xata database. Assumes you have a Xata database
created with the right schema. See the guide at:
https://integrations.langchain.com/vectorstores?integration_name=XataVectorStore
"""
def __init__(
self,
api_key: str,
db_url: str,
embedding: Embeddings,
table_name: str,
) -> None:
"""Initialize with Xata client."""
try:
from xata.client import XataClient # noqa: F401
except ImportError:
raise ValueError(
"Could not import xata python package. "
"Please install it with `pip install xata`."
)
self._client = XataClient(api_key=api_key, db_url=db_url)
self._embedding: Embeddings = embedding
self._table_name = table_name or "vectors"
@property
def embeddings(self) -> Embeddings:
return self._embedding
def add_vectors(
self,
vectors: List[List[float]],
documents: List[Document],
ids: Optional[List[str]] = None,
) -> List[str]:
return self._add_vectors(vectors, documents, ids)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[Dict[Any, Any]]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
        docs = self._texts_to_documents(texts, metadatas)
        vectors = self._embedding.embed_documents(list(texts))
        return self.add_vectors(vectors, docs, ids)

    def _add_vectors(
        self,
        vectors: List[List[float]],
        documents: List[Document],
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        """Add vectors to the Xata database."""
        rows: List[Dict[str, Any]] = []
        for idx, embedding in enumerate(vectors):
            row = {
                "content": documents[idx].page_content,
                "embedding": embedding,
            }
            if ids:
                row["id"] = ids[idx]
            for key, val in documents[idx].metadata.items():
                if key not in ["id", "content", "embedding"]:
                    row[key] = val
            rows.append(row)

        # XXX: I would have liked to use the BulkProcessor here, but it
        # doesn't return the IDs, which we need here. Manual chunking it is.
        chunk_size = 1000
        id_list: List[str] = []
        for i in range(0, len(rows), chunk_size):
            chunk = rows[i : i + chunk_size]
            r = self._client.records().bulk_insert(self._table_name, {"records": chunk})
            if r.status_code != 200:
                raise Exception(f"Error adding vectors to Xata: {r.status_code} {r}")
            id_list.extend(r["recordIDs"])
        return id_list

    @staticmethod
    def _texts_to_documents(
        texts: Iterable[str],
        metadatas: Optional[Iterable[Dict[Any, Any]]] = None,
    ) -> List[Document]:
        """Return list of Documents from list of texts and metadatas."""
        if metadatas is None:
            metadatas = repeat({})
        docs = [
            Document(page_content=text, metadata=metadata)
            for text, metadata in zip(texts, metadatas)
        ]
        return docs

    @classmethod
    def from_texts(
        cls: Type["XataVectorStore"],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        api_key: Optional[str] = None,
        db_url: Optional[str] = None,
        table_name: str = "vectors",
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> "XataVectorStore":
        """Return VectorStore initialized from texts and embeddings."""
        if not api_key or not db_url:
            raise ValueError("Xata api_key and db_url must be set.")
        embeddings = embedding.embed_documents(texts)
        ids = None  # Xata will generate them for us
        docs = cls._texts_to_documents(texts, metadatas)
        vector_db = cls(
            api_key=api_key,
            db_url=db_url,
            embedding=embedding,
            table_name=table_name,
        )
        vector_db._add_vectors(embeddings, docs, ids)
        return vector_db

    def similarity_search(
        self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query.
        """
        docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
        documents = [d[0] for d in docs_and_scores]
        return documents

    def similarity_search_with_score(
        self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
"""Run similarity search with Chroma with distance.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[dict]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: List of documents most similar to the query
                text with distance in float.
        """
        embedding = self._embedding.embed_query(query)
        payload = {
            "queryVector": embedding,
            "column": "embedding",
            "size": k,
        }
        if filter:
            payload["filter"] = filter
        r = self._client.data().vector_search(self._table_name, payload=payload)
        if r.status_code != 200:
            raise Exception(f"Error running similarity search: {r.status_code} {r}")
        hits = r["records"]
        docs_and_scores = [
            (
                Document(
                    page_content=hit["content"],
                    metadata=self._extractMetadata(hit),
                ),
                hit["xata"]["score"],
            )
            for hit in hits
        ]
        return docs_and_scores

    def _extractMetadata(self, record: dict) -> dict:
        """Extract metadata from a record. Filters out known columns."""
        metadata = {}
        for key, val in record.items():
            if key not in ["id", "content", "embedding", "xata"]:
                metadata[key] = val
        return metadata

    def delete(
        self,
        ids: Optional[List[str]] = None,
        delete_all: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
            delete_all: Delete all records in the table.
        """
        if delete_all:
            self._delete_all()
            self.wait_for_indexing(ndocs=0)
        elif ids is not None:
            chunk_size = 500
            for i in range(0, len(ids), chunk_size):
                chunk = ids[i : i + chunk_size]
                operations = [
                    {"delete": {"table": self._table_name, "id": id}} for id in chunk
                ]
                self._client.records().transaction(payload={"operations": operations})
        else:
            raise ValueError("Either ids or delete_all must be set.")

    def _delete_all(self) -> None:
        """Delete all records in the table."""
        while True:
            r = self._client.data().query(self._table_name, payload={"columns": ["id"]})
            if r.status_code != 200:
                raise Exception(f"Error running query: {r.status_code} {r}")
            ids = [rec["id"] for rec in r["records"]]
            if len(ids) == 0:
                break
            operations = [
                {"delete": {"table": self._table_name, "id": id}} for id in ids
            ]
            self._client.records().transaction(payload={"operations": operations})

    def wait_for_indexing(self, timeout: float = 5, ndocs: int = 1) -> None:
        """Wait for the search index to contain a certain number of
        documents. Useful in tests.
        """
        start = time.time()
        while True:
            r = self._client.data().search_table(
                self._table_name, payload={"query": "", "page": {"size": 0}}
            )
            if r.status_code != 200:
                raise Exception(f"Error running search: {r.status_code} {r}")
            if r["totalCount"] == ndocs:
                break
            if time.time() - start > timeout:
                raise Exception("Timed out waiting for indexing to complete.")
            time.sleep(0.5)

poetry.lock
@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand.
# This file is automatically @generated by Poetry and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -2439,6 +2439,21 @@ wrapt = ">=1.10,<2"
[package.extras]
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
[[package]]
name = "deprecation"
version = "2.1.0"
description = "A library to handle automated deprecations"
category = "main"
optional = true
python-versions = "*"
files = [
{file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"},
{file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"},
]
[package.dependencies]
packaging = "*"
[[package]]
name = "dill"
version = "0.3.6"
@@ -4738,7 +4753,6 @@ optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
files = [
{file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"},
{file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"},
]
[[package]]
@@ -12098,7 +12112,7 @@ files = [
]
[package.dependencies]
accelerate = {version = ">=0.20.2", optional = true, markers = "extra == \"accelerate\" or extra == \"torch\""}
accelerate = {version = ">=0.20.2", optional = true, markers = "extra == \"accelerate\""}
filelock = "*"
huggingface-hub = ">=0.14.1,<1.0"
numpy = ">=1.17"
@@ -13070,6 +13084,24 @@ files = [
{file = "wrapt-1.15.0.tar.gz", hash = "sha256:d06730c6aed78cee4126234cf2d071e01b44b915e725a6cb439a879ec9754a3a"},
]
[[package]]
name = "xata"
version = "1.0.0a7"
description = "Python client for Xata.io"
category = "main"
optional = true
python-versions = ">=3.8,<4.0"
files = [
{file = "xata-1.0.0a7-py3-none-any.whl", hash = "sha256:1427e97bccddfd5fa8fba56ba993b2d78f1dc074e729d06ccc79c48d07bd023a"},
{file = "xata-1.0.0a7.tar.gz", hash = "sha256:32769ddc22cc091bf133e66b91662185047fff05aa431e7c760b55cd0ddef6c3"},
]
[package.dependencies]
deprecation = ">=2.1.0,<3.0.0"
orjson = ">=3.8.1,<4.0.0"
python-dotenv = ">=0.21,<2.0"
requests = ">=2.28.1,<3.0.0"
[[package]]
name = "xinference"
version = "0.0.6"
@@ -13570,15 +13602,15 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
cffi = ["cffi (>=1.11)"]
[extras]
all = ["O365", "aleph-alpha-client", "amadeus", "anthropic", "arxiv", "atlassian-python-api", "awadb", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clarifai", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "esprima", "faiss-cpu", "google-api-python-client", "google-auth", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "libdeeplake", "librosa", "lxml", "manifest-ml", "marqo", "momento", "nebula3-python", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "octoai-sdk", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pymongo", "pyowm", "pypdf", "pytesseract", "python-arango", "pyvespa", "qdrant-client", "rdflib", "redis", "requests-toolbelt", "sentence-transformers", "singlestoredb", "spacy", "steamship", "tensorflow-text", "tigrisdb", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha", "xinference"]
azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "azure-search-documents", "openai"]
all = ["anthropic", "clarifai", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "marqo", "pymongo", "weaviate-client", "redis", "google-api-python-client", "google-auth", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "libdeeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "langkit", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "steamship", "pdfminer-six", "lxml", "requests-toolbelt", "neo4j", "openlm", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "momento", "singlestoredb", "tigrisdb", "nebula3-python", "awadb", "esprima", "octoai-sdk", "rdflib", "amadeus", "xinference", "librosa", "python-arango"]
azure = ["azure-identity", "azure-cosmos", "openai", "azure-core", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-search-documents"]
clarifai = ["clarifai"]
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
extended-testing = ["amazon-textract-caller", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xinference", "zep-python"]
extended-testing = ["amazon-textract-caller", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "jq", "pdfminer-six", "pgvector", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "mwparserfromhell", "mwxml", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "streamlit", "pyspark", "openai", "sympy", "rapidfuzz", "openai", "rank-bm25", "geopandas", "jinja2", "xinference", "gitpython", "newspaper3k", "feedparser", "xata"]
javascript = ["esprima"]
llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers", "xinference"]
llms = ["anthropic", "clarifai", "cohere", "openai", "openllm", "openlm", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers", "xinference"]
openai = ["openai", "tiktoken"]
qdrant = ["qdrant-client"]
text-helpers = ["chardet"]
@@ -13586,4 +13618,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "39305f23d3d69179d247d643631133ac50f5e944d98518c8a56c5f839b8e7a04"
content-hash = "9c970917244d05f76c8592b986007e689495e94c6c47e2609677e2907dd0a312"

pyproject.toml
@@ -131,6 +131,7 @@ librosa = {version="^0.10.0.post2", optional = true }
feedparser = {version = "^6.0.10", optional = true}
newspaper3k = {version = "^0.2.8", optional = true}
amazon-textract-caller = {version = "<2", optional = true}
xata = {version = "^1.0.0a7", optional = true}
[tool.poetry.group.test.dependencies]
# The only dependencies that should be added are
@@ -369,6 +370,7 @@ extended_testing = [
"gitpython",
"newspaper3k",
"feedparser",
"xata",
]
[tool.ruff]
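
Since `xata` is declared as an optional dependency above (and added to the `extended_testing` extras list), it is not installed by default. Two ways to pull it in, assuming you are installing from this branch:

```
# just the SDK, at the alpha version pinned in this PR
pip install xata==1.0.0a7
# or langchain together with the extras group that now includes xata
pip install "langchain[extended_testing]"
```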

tests/integration_tests/vectorstores/test_xata.py
@@ -0,0 +1,56 @@
"""Test Xata vector store functionality.
Before running this test, please create a Xata database by following
the instructions from:
https://python.langchain.com/docs/integrations/vectorstores/xata
"""
import os
from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.xata import XataVectorStore
class TestXata:
@classmethod
def setup_class(cls) -> None:
assert os.getenv("XATA_API_KEY"), "XATA_API_KEY environment variable is not set"
assert os.getenv("XATA_DB_URL"), "XATA_DB_URL environment variable is not set"
def test_similarity_search_without_metadata(
self, embedding_openai: OpenAIEmbeddings
) -> None:
"""Test end to end constructions and search without metadata."""
texts = ["foo", "bar", "baz"]
docsearch = XataVectorStore.from_texts(
api_key=os.getenv("XATA_API_KEY"),
db_url=os.getenv("XATA_DB_URL"),
texts=texts,
embedding=embedding_openai,
)
docsearch.wait_for_indexing(ndocs=3)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
docsearch.delete(delete_all=True)
def test_similarity_search_with_metadata(
self, embedding_openai: OpenAIEmbeddings
) -> None:
"""Test end to end construction and search with a metadata filter.
This test requires a column named "a" of type integer to be present
in the Xata table."""
texts = ["foo", "foo", "foo"]
metadatas = [{"a": i} for i in range(len(texts))]
docsearch = XataVectorStore.from_texts(
api_key=os.getenv("XATA_API_KEY"),
db_url=os.getenv("XATA_DB_URL"),
texts=texts,
embedding=embedding_openai,
metadatas=metadatas,
)
docsearch.wait_for_indexing(ndocs=3)
output = docsearch.similarity_search("foo", k=1, filter={"a": 1})
assert output == [Document(page_content="foo", metadata={"a": 1})]
docsearch.delete(delete_all=True)
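
Note: the tests above reference an `embedding_openai` fixture that is not part of this diff; it presumably comes from the integration tests' shared conftest. A minimal, hypothetical sketch of such a fixture, assuming it simply wraps `OpenAIEmbeddings`:

```
# Hypothetical conftest.py fixture -- not included in this PR's diff.
import pytest

from langchain.embeddings.openai import OpenAIEmbeddings


@pytest.fixture
def embedding_openai() -> OpenAIEmbeddings:
    """Return an OpenAI embeddings client for the Xata tests."""
    return OpenAIEmbeddings()
```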