diff --git a/docs/extras/integrations/vectorstores/xata.ipynb b/docs/extras/integrations/vectorstores/xata.ipynb new file mode 100644 index 0000000000..601e8599f5 --- /dev/null +++ b/docs/extras/integrations/vectorstores/xata.ipynb @@ -0,0 +1,240 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Xata\n", + "\n", + "> [Xata](https://xata.io) is a serverless data platform based on PostgreSQL. It provides a Python SDK for interacting with your database, and a UI for managing your data.\n", + "> Xata has a native vector type, which can be added to any table, and supports similarity search. LangChain inserts vectors directly into Xata, and queries it for the nearest neighbors of a given vector, so that you can use all the LangChain Embeddings integrations with Xata." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook guides you through using Xata as a VectorStore." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "### Create a database to use as a vector store\n", + "\n", + "In the [Xata UI](https://app.xata.io) create a new database. You can name it whatever you want; in this notebook we'll use `langchain`.\n", + "Create a table; again, you can name it anything, but we will use `vectors`. Add the following columns via the UI:\n", + "\n", + "* `content` of type \"Text\". This is used to store the `Document.pageContent` values.\n", + "* `embedding` of type \"Vector\". Set the dimension to match the embeddings model you plan to use. In this notebook we use OpenAI embeddings, which have 1536 dimensions.\n", + "* `search` of type \"Text\". This is used as a metadata column by this example.\n", + "* Any other columns you want to use as metadata. They are populated from the `Document.metadata` object. For example, if the `Document.metadata` object has a `title` property, you can create a `title` column in the table and it will be populated.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's first install our dependencies:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "!pip install xata==1.0.0a7 openai tiktoken langchain" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's load the OpenAI key into the environment. If you don't have one, you can create an OpenAI account and create a key on this [page](https://platform.openai.com/account/api-keys)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "import os\n", + "import getpass\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Similarly, we need to get the API key and database URL for Xata. You can create a new API key by visiting your [account settings](https://app.xata.io/settings). To find the database URL, go to the Settings page of the database that you have created. The database URL should look something like this: `https://demo-uni3q8.eu-west-1.xata.sh/db/langchain`."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "api_key = getpass.getpass(\"Xata API key: \")\n", + "db_url = input(\"Xata database URL (copy it from your DB settings):\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.text_splitter import CharacterTextSplitter\n", + "from langchain.document_loaders import TextLoader\n", + "from langchain.vectorstores.xata import XataVectorStore\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create the Xata vector store\n", + "Let's import our test dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "loader = TextLoader(\"../../../state_of_the_union.txt\")\n", + "documents = loader.load()\n", + "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", + "docs = text_splitter.split_documents(documents)\n", + "\n", + "embeddings = OpenAIEmbeddings()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now create the actual vector store, backed by the Xata table." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "vector_store = XataVectorStore.from_documents(docs, embeddings, api_key=api_key, db_url=db_url, table_name=\"vectors\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After running the above command, if you go to the Xata UI, you should see the documents loaded together with their embeddings." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Similarity Search" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "found_docs = vector_store.similarity_search(query)\n", + "print(found_docs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Similarity Search with score (vector distance)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "result = vector_store.similarity_search_with_score(query)\n", + "for doc, score in result:\n", + " print(f\"document={doc}, score={score}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/libs/langchain/langchain/vectorstores/xata.py b/libs/langchain/langchain/vectorstores/xata.py new file mode 100644 index 0000000000..ef25bc1a6e --- /dev/null +++ b/libs/langchain/langchain/vectorstores/xata.py @@ -0,0 +1,263 @@ +"""Wrapper around Xata as a vector database.""" + +from __future__ import annotations + +import time +from itertools import repeat +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type + +from langchain.docstore.document import Document +from langchain.embeddings.base import Embeddings +from langchain.vectorstores.base import VectorStore + + +class XataVectorStore(VectorStore): + """VectorStore for a Xata database. Assumes you have a Xata database + created with the right schema. See the guide at: + https://integrations.langchain.com/vectorstores?integration_name=XataVectorStore + + """ + + def __init__( + self, + api_key: str, + db_url: str, + embedding: Embeddings, + table_name: str, + ) -> None: + """Initialize with Xata client.""" + try: + from xata.client import XataClient # noqa: F401 + except ImportError: + raise ValueError( + "Could not import xata python package. " + "Please install it with `pip install xata`." 
+ ) + self._client = XataClient(api_key=api_key, db_url=db_url) + self._embedding: Embeddings = embedding + self._table_name = table_name or "vectors" + + @property + def embeddings(self) -> Embeddings: + return self._embedding + + def add_vectors( + self, + vectors: List[List[float]], + documents: List[Document], + ids: Optional[List[str]] = None, + ) -> List[str]: + return self._add_vectors(vectors, documents, ids) + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[Dict[Any, Any]]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + docs = self._texts_to_documents(texts, metadatas) + + vectors = self._embedding.embed_documents(list(texts)) + return self.add_vectors(vectors, docs, ids) + + def _add_vectors( + self, + vectors: List[List[float]], + documents: List[Document], + ids: Optional[List[str]] = None, + ) -> List[str]: + """Add vectors to the Xata database.""" + + rows: List[Dict[str, Any]] = [] + for idx, embedding in enumerate(vectors): + row = { + "content": documents[idx].page_content, + "embedding": embedding, + } + if ids: + row["id"] = ids[idx] + for key, val in documents[idx].metadata.items(): + if key not in ["id", "content", "embedding"]: + row[key] = val + rows.append(row) + + # XXX: I would have liked to use the BulkProcessor here, but it + # doesn't return the IDs, which we need here. Manual chunking it is. + chunk_size = 1000 + id_list: List[str] = [] + for i in range(0, len(rows), chunk_size): + chunk = rows[i : i + chunk_size] + + r = self._client.records().bulk_insert(self._table_name, {"records": chunk}) + if r.status_code != 200: + raise Exception(f"Error adding vectors to Xata: {r.status_code} {r}") + id_list.extend(r["recordIDs"]) + return id_list + + @staticmethod + def _texts_to_documents( + texts: Iterable[str], + metadatas: Optional[Iterable[Dict[Any, Any]]] = None, + ) -> List[Document]: + """Return list of Documents from list of texts and metadatas.""" + if metadatas is None: + metadatas = repeat({}) + + docs = [ + Document(page_content=text, metadata=metadata) + for text, metadata in zip(texts, metadatas) + ] + + return docs + + @classmethod + def from_texts( + cls: Type["XataVectorStore"], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + api_key: Optional[str] = None, + db_url: Optional[str] = None, + table_name: str = "vectors", + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> "XataVectorStore": + """Return VectorStore initialized from texts and embeddings.""" + + if not api_key or not db_url: + raise ValueError("Xata api_key and db_url must be set.") + + embeddings = embedding.embed_documents(texts) + ids = None # Xata will generate them for us + docs = cls._texts_to_documents(texts, metadatas) + + vector_db = cls( + api_key=api_key, + db_url=db_url, + embedding=embedding, + table_name=table_name, + ) + + vector_db._add_vectors(embeddings, docs, ids) + return vector_db + + def similarity_search( + self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any + ) -> List[Document]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + + Returns: + List of Documents most similar to the query.
+ """ + docs_and_scores = self.similarity_search_with_score(query, k, filter=filter) + documents = [d[0] for d in docs_and_scores] + return documents + + def similarity_search_with_score( + self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any + ) -> List[Tuple[Document, float]]: + """Run similarity search with Chroma with distance. + + Args: + query (str): Query text to search for. + k (int): Number of results to return. Defaults to 4. + filter (Optional[dict]): Filter by metadata. Defaults to None. + + Returns: + List[Tuple[Document, float]]: List of documents most similar to the query + text with distance in float. + """ + embedding = self._embedding.embed_query(query) + payload = { + "queryVector": embedding, + "column": "embedding", + "size": k, + } + if filter: + payload["filter"] = filter + r = self._client.data().vector_search(self._table_name, payload=payload) + if r.status_code != 200: + raise Exception(f"Error running similarity search: {r.status_code} {r}") + hits = r["records"] + docs_and_scores = [ + ( + Document( + page_content=hit["content"], + metadata=self._extractMetadata(hit), + ), + hit["xata"]["score"], + ) + for hit in hits + ] + return docs_and_scores + + def _extractMetadata(self, record: dict) -> dict: + """Extract metadata from a record. Filters out known columns.""" + metadata = {} + for key, val in record.items(): + if key not in ["id", "content", "embedding", "xata"]: + metadata[key] = val + return metadata + + def delete( + self, + ids: Optional[List[str]] = None, + delete_all: Optional[bool] = None, + **kwargs: Any, + ) -> None: + """Delete by vector IDs. + + Args: + ids: List of ids to delete. + delete_all: Delete all records in the table. + """ + if delete_all: + self._delete_all() + self.wait_for_indexing(ndocs=0) + elif ids is not None: + chunk_size = 500 + for i in range(0, len(ids), chunk_size): + chunk = ids[i : i + chunk_size] + operations = [ + {"delete": {"table": self._table_name, "id": id}} for id in chunk + ] + self._client.records().transaction(payload={"operations": operations}) + else: + raise ValueError("Either ids or delete_all must be set.") + + def _delete_all(self) -> None: + """Delete all records in the table.""" + while True: + r = self._client.data().query(self._table_name, payload={"columns": ["id"]}) + if r.status_code != 200: + raise Exception(f"Error running query: {r.status_code} {r}") + ids = [rec["id"] for rec in r["records"]] + if len(ids) == 0: + break + operations = [ + {"delete": {"table": self._table_name, "id": id}} for id in ids + ] + self._client.records().transaction(payload={"operations": operations}) + + def wait_for_indexing(self, timeout: float = 5, ndocs: int = 1) -> None: + """Wait for the search index to contain a certain number of + documents. Useful in tests. + """ + start = time.time() + while True: + r = self._client.data().search_table( + self._table_name, payload={"query": "", "page": {"size": 0}} + ) + if r.status_code != 200: + raise Exception(f"Error running search: {r.status_code} {r}") + if r["totalCount"] == ndocs: + break + if time.time() - start > timeout: + raise Exception("Timed out waiting for indexing to complete.") + time.sleep(0.5) diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index 535393cdb9..9cdde5d680 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. 
+# This file is automatically @generated by Poetry and should not be changed by hand. [[package]] name = "absl-py" @@ -2439,6 +2439,21 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "deprecation" +version = "2.1.0" +description = "A library to handle automated deprecations" +category = "main" +optional = true +python-versions = "*" +files = [ + {file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"}, + {file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"}, +] + +[package.dependencies] +packaging = "*" + [[package]] name = "dill" version = "0.3.6" @@ -4738,7 +4753,6 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, - {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] @@ -12098,7 +12112,7 @@ files = [ ] [package.dependencies] -accelerate = {version = ">=0.20.2", optional = true, markers = "extra == \"accelerate\" or extra == \"torch\""} +accelerate = {version = ">=0.20.2", optional = true, markers = "extra == \"accelerate\""} filelock = "*" huggingface-hub = ">=0.14.1,<1.0" numpy = ">=1.17" @@ -13070,6 +13084,24 @@ files = [ {file = "wrapt-1.15.0.tar.gz", hash = "sha256:d06730c6aed78cee4126234cf2d071e01b44b915e725a6cb439a879ec9754a3a"}, ] +[[package]] +name = "xata" +version = "1.0.0a7" +description = "Python client for Xata.io" +category = "main" +optional = true +python-versions = ">=3.8,<4.0" +files = [ + {file = "xata-1.0.0a7-py3-none-any.whl", hash = "sha256:1427e97bccddfd5fa8fba56ba993b2d78f1dc074e729d06ccc79c48d07bd023a"}, + {file = "xata-1.0.0a7.tar.gz", hash = "sha256:32769ddc22cc091bf133e66b91662185047fff05aa431e7c760b55cd0ddef6c3"}, +] + +[package.dependencies] +deprecation = ">=2.1.0,<3.0.0" +orjson = ">=3.8.1,<4.0.0" +python-dotenv = ">=0.21,<2.0" +requests = ">=2.28.1,<3.0.0" + [[package]] name = "xinference" version = "0.0.6" @@ -13570,15 +13602,15 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["O365", "aleph-alpha-client", "amadeus", "anthropic", "arxiv", "atlassian-python-api", "awadb", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clarifai", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "esprima", "faiss-cpu", "google-api-python-client", "google-auth", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "libdeeplake", "librosa", "lxml", "manifest-ml", "marqo", "momento", "nebula3-python", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "octoai-sdk", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pymongo", "pyowm", "pypdf", "pytesseract", "python-arango", "pyvespa", "qdrant-client", "rdflib", "redis", "requests-toolbelt", "sentence-transformers", "singlestoredb", "spacy", "steamship", "tensorflow-text", "tigrisdb", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha", "xinference"] 
-azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "azure-search-documents", "openai"] +all = ["anthropic", "clarifai", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "marqo", "pymongo", "weaviate-client", "redis", "google-api-python-client", "google-auth", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "libdeeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "langkit", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "steamship", "pdfminer-six", "lxml", "requests-toolbelt", "neo4j", "openlm", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "momento", "singlestoredb", "tigrisdb", "nebula3-python", "awadb", "esprima", "octoai-sdk", "rdflib", "amadeus", "xinference", "librosa", "python-arango"] +azure = ["azure-identity", "azure-cosmos", "openai", "azure-core", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-search-documents"] clarifai = ["clarifai"] cohere = ["cohere"] docarray = ["docarray"] embeddings = ["sentence-transformers"] -extended-testing = ["amazon-textract-caller", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xinference", "zep-python"] +extended-testing = ["amazon-textract-caller", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "jq", "pdfminer-six", "pgvector", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "mwparserfromhell", "mwxml", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "streamlit", "pyspark", "openai", "sympy", "rapidfuzz", "openai", "rank-bm25", "geopandas", "jinja2", "xinference", "gitpython", "newspaper3k", "feedparser", "xata"] javascript = ["esprima"] -llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers", "xinference"] +llms = ["anthropic", "clarifai", "cohere", "openai", "openllm", "openlm", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers", "xinference"] openai = ["openai", "tiktoken"] qdrant = ["qdrant-client"] text-helpers = ["chardet"] @@ -13586,4 +13618,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "39305f23d3d69179d247d643631133ac50f5e944d98518c8a56c5f839b8e7a04" +content-hash = "9c970917244d05f76c8592b986007e689495e94c6c47e2609677e2907dd0a312" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index cc04956ec5..5f88c9b88a 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ 
-131,6 +131,7 @@ librosa = {version="^0.10.0.post2", optional = true } feedparser = {version = "^6.0.10", optional = true} newspaper3k = {version = "^0.2.8", optional = true} amazon-textract-caller = {version = "<2", optional = true} +xata = {version = "^1.0.0a7", optional = true} [tool.poetry.group.test.dependencies] # The only dependencies that should be added are @@ -369,6 +370,7 @@ extended_testing = [ "gitpython", "newspaper3k", "feedparser", + "xata", ] [tool.ruff] diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_xata.py b/libs/langchain/tests/integration_tests/vectorstores/test_xata.py new file mode 100644 index 0000000000..a4aed36654 --- /dev/null +++ b/libs/langchain/tests/integration_tests/vectorstores/test_xata.py @@ -0,0 +1,56 @@ +"""Test Xata vector store functionality. + +Before running this test, please create a Xata database by following +the instructions from: +https://python.langchain.com/docs/integrations/vectorstores/xata +""" +import os + +from langchain.docstore.document import Document +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.vectorstores.xata import XataVectorStore + + +class TestXata: + @classmethod + def setup_class(cls) -> None: + assert os.getenv("XATA_API_KEY"), "XATA_API_KEY environment variable is not set" + assert os.getenv("XATA_DB_URL"), "XATA_DB_URL environment variable is not set" + + def test_similarity_search_without_metadata( + self, embedding_openai: OpenAIEmbeddings + ) -> None: + """Test end to end constructions and search without metadata.""" + texts = ["foo", "bar", "baz"] + docsearch = XataVectorStore.from_texts( + api_key=os.getenv("XATA_API_KEY"), + db_url=os.getenv("XATA_DB_URL"), + texts=texts, + embedding=embedding_openai, + ) + docsearch.wait_for_indexing(ndocs=3) + + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + docsearch.delete(delete_all=True) + + def test_similarity_search_with_metadata( + self, embedding_openai: OpenAIEmbeddings + ) -> None: + """Test end to end construction and search with a metadata filter. + + This test requires a column named "a" of type integer to be present + in the Xata table.""" + texts = ["foo", "foo", "foo"] + metadatas = [{"a": i} for i in range(len(texts))] + docsearch = XataVectorStore.from_texts( + api_key=os.getenv("XATA_API_KEY"), + db_url=os.getenv("XATA_DB_URL"), + texts=texts, + embedding=embedding_openai, + metadatas=metadatas, + ) + docsearch.wait_for_indexing(ndocs=3) + output = docsearch.similarity_search("foo", k=1, filter={"a": 1}) + assert output == [Document(page_content="foo", metadata={"a": 1})] + docsearch.delete(delete_all=True)
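For completeness, here is a hedged sketch of the cleanup paths the tests above rely on, using only methods defined in this PR (`add_texts`, `wait_for_indexing`, `delete`). It assumes the same `XATA_API_KEY` and `XATA_DB_URL` environment variables as the tests, an `OPENAI_API_KEY` for the embeddings, and an initially empty `vectors` table, since `wait_for_indexing` waits for an exact document count.

```python
# Illustrative sketch of the deletion paths in XataVectorStore; assumes
# XATA_API_KEY, XATA_DB_URL, and OPENAI_API_KEY are set and that the
# `vectors` table starts out empty.
import os

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.xata import XataVectorStore

store = XataVectorStore(
    api_key=os.environ["XATA_API_KEY"],
    db_url=os.environ["XATA_DB_URL"],
    embedding=OpenAIEmbeddings(),
    table_name="vectors",
)

# add_texts() returns the record IDs Xata generated for the inserted rows.
ids = store.add_texts(["foo", "bar"])

# Block until the search index reports exactly these two documents;
# wait_for_indexing() raises if the count is not reached within `timeout`.
store.wait_for_indexing(ndocs=2, timeout=10)

# Targeted cleanup: delete only the rows we just inserted (batched into
# transactions of up to 500 IDs internally).
store.delete(ids=ids)

# Or wipe the whole table, as the tests do after each run.
store.delete(delete_all=True)
```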