mirror of https://github.com/hwchase17/langchain
Add support for Xata as a vector store (#8822)
This adds support for [Xata](https://xata.io) (data platform based on Postgres) as a vector store. We have recently added [Xata to Langchain.js](https://github.com/hwchase17/langchainjs/pull/2125) and would love to have the equivalent in the Python project as well.

The PR includes integration tests and a Jupyter notebook as docs. Please let me know if anything else would be needed or helpful.

I have added the xata python SDK as an optional dependency.

## To run the integration tests

You will need to create a DB in xata (see the docs), then run something like:

```
OPENAI_API_KEY=sk-... XATA_API_KEY=xau_... XATA_DB_URL='https://....xata.sh/db/langchain' poetry run pytest tests/integration_tests/vectorstores/test_xata.py
```

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Philip Krauss <35487337+philkra@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>

pull/8870/head
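The test command above relies on three environment variables (`OPENAI_API_KEY`, `XATA_API_KEY`, `XATA_DB_URL`). A minimal sketch of a pre-flight check, assuming a hypothetical helper `missing_env_vars` that is not part of this PR:

```python
import os
from typing import List, Mapping

# Variable names taken from the integration-test command above.
REQUIRED_ENV_VARS = ("OPENAI_API_KEY", "XATA_API_KEY", "XATA_DB_URL")


def missing_env_vars(env: Mapping[str, str]) -> List[str]:
    """Return the required variable names that are unset or empty in `env`.

    Hypothetical helper for illustration only; it is not part of LangChain.
    """
    return [name for name in REQUIRED_ENV_VARS if not env.get(name)]


if __name__ == "__main__":
    missing = missing_env_vars(os.environ)
    if missing:
        print("Missing environment variables: " + ", ".join(missing))
    else:
        print("All required environment variables are set.")
```

Running this before `pytest` reports any variable that is unset or empty in one message, rather than leaving the problem to surface inside the test run.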
parent 472f00ada7
commit aeaef8f3a3
@ -0,0 +1,240 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Xata\n",
    "\n",
    "> [Xata](https://xata.io) is a serverless data platform, based on PostgreSQL. It provides a Python SDK for interacting with your database, and a UI for managing your data.\n",
    "> Xata has a native vector type, which can be added to any table, and supports similarity search. LangChain inserts vectors directly into Xata, and queries it for the nearest neighbors of a given vector, so that you can use all the LangChain Embeddings integrations with Xata."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how to use Xata as a VectorStore."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "### Create a database to use as a vector store\n",
    "\n",
    "In the [Xata UI](https://app.xata.io), create a new database. You can name it whatever you want; in this notebook we'll use `langchain`.\n",
    "Create a table. Again, you can name it anything, but we will use `vectors`. Add the following columns via the UI:\n",
    "\n",
    "* `content` of type \"Text\". This is used to store the `Document.page_content` values.\n",
    "* `embedding` of type \"Vector\". Use the dimension used by the model you plan to use. In this notebook we use OpenAI embeddings, which have 1536 dimensions.\n",
    "* `search` of type \"Text\". This is used as a metadata column by this example.\n",
    "* any other columns you want to use as metadata. They are populated from the `Document.metadata` object. For example, if in the `Document.metadata` object you have a `title` property, you can create a `title` column in the table and it will be populated.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's first install our dependencies:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "!pip install xata==1.0.0a7 openai tiktoken langchain"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's load the OpenAI key into the environment. If you don't have one, you can create an OpenAI account and create a key on this [page](https://platform.openai.com/account/api-keys)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import getpass\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly, we need to get the environment variables for Xata. You can create a new API key by visiting your [account settings](https://app.xata.io/settings). To find the database URL, go to the Settings page of the database that you have created. The database URL should look something like this: `https://demo-uni3q8.eu-west-1.xata.sh/db/langchain`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "api_key = getpass.getpass(\"Xata API key: \")\n",
    "db_url = input(\"Xata database URL (copy it from your DB settings):\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "from langchain.embeddings.openai import OpenAIEmbeddings\n",
    "from langchain.text_splitter import CharacterTextSplitter\n",
    "from langchain.document_loaders import TextLoader\n",
    "from langchain.vectorstores.xata import XataVectorStore\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the Xata vector store\n",
    "Let's import our test dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
    "documents = loader.load()\n",
    "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
    "docs = text_splitter.split_documents(documents)\n",
    "\n",
    "embeddings = OpenAIEmbeddings()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now create the actual vector store, backed by the Xata table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "vector_store = XataVectorStore.from_documents(docs, embeddings, api_key=api_key, db_url=db_url, table_name=\"vectors\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After running the above command, if you go to the Xata UI, you should see the documents loaded together with their embeddings."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Similarity Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
    "found_docs = vector_store.similarity_search(query)\n",
    "print(found_docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Similarity Search with score (vector distance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
    "result = vector_store.similarity_search_with_score(query)\n",
    "for doc, score in result:\n",
    "    print(f\"document={doc}, score={score}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
@ -0,0 +1,263 @@
"""Wrapper around Xata as a vector database."""

from __future__ import annotations

import time
from itertools import repeat
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore


class XataVectorStore(VectorStore):
    """VectorStore for a Xata database. Assumes you have a Xata database
    created with the right schema. See the guide at:
    https://integrations.langchain.com/vectorstores?integration_name=XataVectorStore

    """

    def __init__(
        self,
        api_key: str,
        db_url: str,
        embedding: Embeddings,
        table_name: str,
    ) -> None:
        """Initialize with Xata client."""
        try:
            from xata.client import XataClient  # noqa: F401
        except ImportError:
            raise ValueError(
                "Could not import xata python package. "
                "Please install it with `pip install xata`."
            )
        self._client = XataClient(api_key=api_key, db_url=db_url)
        self._embedding: Embeddings = embedding
        self._table_name = table_name or "vectors"

    @property
    def embeddings(self) -> Embeddings:
        return self._embedding

    def add_vectors(
        self,
        vectors: List[List[float]],
        documents: List[Document],
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        return self._add_vectors(vectors, documents, ids)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[Dict[Any, Any]]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        # Materialize once: `texts` may be a one-shot iterable (e.g. a
        # generator) and is consumed twice below.
        text_list = list(texts)
        docs = self._texts_to_documents(text_list, metadatas)

        vectors = self._embedding.embed_documents(text_list)
        return self.add_vectors(vectors, docs, ids)

    def _add_vectors(
        self,
        vectors: List[List[float]],
        documents: List[Document],
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        """Add vectors to the Xata database."""

        rows: List[Dict[str, Any]] = []
        for idx, embedding in enumerate(vectors):
            row = {
                "content": documents[idx].page_content,
                "embedding": embedding,
            }
            if ids:
                row["id"] = ids[idx]
            for key, val in documents[idx].metadata.items():
                if key not in ["id", "content", "embedding"]:
                    row[key] = val
            rows.append(row)

        # XXX: I would have liked to use the BulkProcessor here, but it
        # doesn't return the IDs, which we need here. Manual chunking it is.
        chunk_size = 1000
        id_list: List[str] = []
        for i in range(0, len(rows), chunk_size):
            chunk = rows[i : i + chunk_size]

            r = self._client.records().bulk_insert(self._table_name, {"records": chunk})
            if r.status_code != 200:
                raise Exception(f"Error adding vectors to Xata: {r.status_code} {r}")
            id_list.extend(r["recordIDs"])
        return id_list

    @staticmethod
    def _texts_to_documents(
        texts: Iterable[str],
        metadatas: Optional[Iterable[Dict[Any, Any]]] = None,
    ) -> List[Document]:
        """Return list of Documents from list of texts and metadatas."""
        if metadatas is None:
            metadatas = repeat({})

        docs = [
            Document(page_content=text, metadata=metadata)
            for text, metadata in zip(texts, metadatas)
        ]

        return docs

    @classmethod
    def from_texts(
        cls: Type["XataVectorStore"],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        api_key: Optional[str] = None,
        db_url: Optional[str] = None,
        table_name: str = "vectors",
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> "XataVectorStore":
        """Return VectorStore initialized from texts and embeddings."""

        if not api_key or not db_url:
            raise ValueError("Xata api_key and db_url must be set.")

        embeddings = embedding.embed_documents(texts)
        ids = None  # Xata will generate them for us
        docs = cls._texts_to_documents(texts, metadatas)

        vector_db = cls(
            api_key=api_key,
            db_url=db_url,
            embedding=embedding,
            table_name=table_name,
        )

        vector_db._add_vectors(embeddings, docs, ids)
        return vector_db

    def similarity_search(
        self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None.

        Returns:
            List of Documents most similar to the query.
        """
        docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
        documents = [d[0] for d in docs_and_scores]
        return documents

    def similarity_search_with_score(
        self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Run similarity search with Xata with distance.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[dict]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: List of documents most similar to the query
                text with distance in float.
        """
        embedding = self._embedding.embed_query(query)
        payload = {
            "queryVector": embedding,
            "column": "embedding",
            "size": k,
        }
        if filter:
            payload["filter"] = filter
        r = self._client.data().vector_search(self._table_name, payload=payload)
        if r.status_code != 200:
            raise Exception(f"Error running similarity search: {r.status_code} {r}")
        hits = r["records"]
        docs_and_scores = [
            (
                Document(
                    page_content=hit["content"],
                    metadata=self._extractMetadata(hit),
                ),
                hit["xata"]["score"],
            )
            for hit in hits
        ]
        return docs_and_scores

    def _extractMetadata(self, record: dict) -> dict:
        """Extract metadata from a record. Filters out known columns."""
        metadata = {}
        for key, val in record.items():
            if key not in ["id", "content", "embedding", "xata"]:
                metadata[key] = val
        return metadata

    def delete(
        self,
        ids: Optional[List[str]] = None,
        delete_all: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
            delete_all: Delete all records in the table.
        """
        if delete_all:
            self._delete_all()
            self.wait_for_indexing(ndocs=0)
        elif ids is not None:
            chunk_size = 500
            for i in range(0, len(ids), chunk_size):
                chunk = ids[i : i + chunk_size]
                operations = [
                    {"delete": {"table": self._table_name, "id": id}} for id in chunk
                ]
                self._client.records().transaction(payload={"operations": operations})
        else:
            raise ValueError("Either ids or delete_all must be set.")

    def _delete_all(self) -> None:
        """Delete all records in the table."""
        while True:
            r = self._client.data().query(self._table_name, payload={"columns": ["id"]})
            if r.status_code != 200:
                raise Exception(f"Error running query: {r.status_code} {r}")
            ids = [rec["id"] for rec in r["records"]]
            if len(ids) == 0:
                break
            operations = [
                {"delete": {"table": self._table_name, "id": id}} for id in ids
            ]
            self._client.records().transaction(payload={"operations": operations})

    def wait_for_indexing(self, timeout: float = 5, ndocs: int = 1) -> None:
        """Wait for the search index to contain a certain number of
        documents. Useful in tests.
        """
        start = time.time()
        while True:
            r = self._client.data().search_table(
                self._table_name, payload={"query": "", "page": {"size": 0}}
            )
            if r.status_code != 200:
                raise Exception(f"Error running search: {r.status_code} {r}")
            if r["totalCount"] == ndocs:
                break
            if time.time() - start > timeout:
                raise Exception("Timed out waiting for indexing to complete.")
            time.sleep(0.5)
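`_add_vectors` and `delete` in the file above both split their work into fixed-size batches (1000 rows for inserts, 500 ids for deletes). The slicing pattern in isolation, as a minimal sketch; `iter_chunks` is a hypothetical helper for illustration and is not part of this PR or of the Xata SDK:

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_chunks(items: Sequence[T], chunk_size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of `items`, each at most `chunk_size` long."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for start in range(0, len(items), chunk_size):
        yield list(items[start : start + chunk_size])


# 2500 placeholder rows split with the insert batch size used above.
rows = [{"content": f"doc {i}"} for i in range(2500)]
batch_sizes = [len(batch) for batch in iter_chunks(rows, 1000)]
print(batch_sizes)  # [1000, 1000, 500]
```

The last batch is simply shorter than the others, so no padding or special-casing is needed when the row count is not a multiple of the batch size.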
@ -0,0 +1,56 @@
"""Test Xata vector store functionality.

Before running this test, please create a Xata database by following
the instructions from:
https://python.langchain.com/docs/integrations/vectorstores/xata
"""
import os

from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.xata import XataVectorStore


class TestXata:
    @classmethod
    def setup_class(cls) -> None:
        assert os.getenv("XATA_API_KEY"), "XATA_API_KEY environment variable is not set"
        assert os.getenv("XATA_DB_URL"), "XATA_DB_URL environment variable is not set"

    def test_similarity_search_without_metadata(
        self, embedding_openai: OpenAIEmbeddings
    ) -> None:
        """Test end to end construction and search without metadata."""
        texts = ["foo", "bar", "baz"]
        docsearch = XataVectorStore.from_texts(
            api_key=os.getenv("XATA_API_KEY"),
            db_url=os.getenv("XATA_DB_URL"),
            texts=texts,
            embedding=embedding_openai,
        )
        docsearch.wait_for_indexing(ndocs=3)

        output = docsearch.similarity_search("foo", k=1)
        assert output == [Document(page_content="foo")]
        docsearch.delete(delete_all=True)

    def test_similarity_search_with_metadata(
        self, embedding_openai: OpenAIEmbeddings
    ) -> None:
        """Test end to end construction and search with a metadata filter.

        This test requires a column named "a" of type integer to be present
        in the Xata table."""
        texts = ["foo", "foo", "foo"]
        metadatas = [{"a": i} for i in range(len(texts))]
        docsearch = XataVectorStore.from_texts(
            api_key=os.getenv("XATA_API_KEY"),
            db_url=os.getenv("XATA_DB_URL"),
            texts=texts,
            embedding=embedding_openai,
            metadatas=metadatas,
        )
        docsearch.wait_for_indexing(ndocs=3)
        output = docsearch.similarity_search("foo", k=1, filter={"a": 1})
        assert output == [Document(page_content="foo", metadata={"a": 1})]
        docsearch.delete(delete_all=True)