From 3dafbd852e4f8a31344f6916ec557eb9c56d84f2 Mon Sep 17 00:00:00 2001 From: Pihplipe Oegr Date: Fri, 1 Sep 2023 17:36:34 +0200 Subject: [PATCH] Add sqlite-vss as a vector database (#10047) This adds sqlite-vss as an option for a vector database. Contains the code and a few tests. Tests are passing and the library sqlite-vss is added as optional as explained in the contributing guidelines. I adjusted the code for lint/black/ and mypy. It looks that everything is currently passing. Adding sqlite-vss was mentioned in this issue: https://github.com/langchain-ai/langchain/issues/1019. Also mentioned here in the sqlite-vss repo for the curious: https://github.com/asg017/sqlite-vss/issues/66 Maintainer tag: @baskaryan --------- Co-authored-by: Philippe Oger --- .../langchain/vectorstores/__init__.py | 2 + .../langchain/vectorstores/sqlitevss.py | 222 ++++++++++++++++++ libs/langchain/poetry.lock | 30 ++- libs/langchain/pyproject.toml | 2 + .../vectorstores/test_sqlitevss.py | 58 +++++ 5 files changed, 302 insertions(+), 12 deletions(-) create mode 100644 libs/langchain/langchain/vectorstores/sqlitevss.py create mode 100644 libs/langchain/tests/integration_tests/vectorstores/test_sqlitevss.py diff --git a/libs/langchain/langchain/vectorstores/__init__.py b/libs/langchain/langchain/vectorstores/__init__.py index 94bbc64a90..a130472b93 100644 --- a/libs/langchain/langchain/vectorstores/__init__.py +++ b/libs/langchain/langchain/vectorstores/__init__.py @@ -63,6 +63,7 @@ from langchain.vectorstores.rocksetdb import Rockset from langchain.vectorstores.scann import ScaNN from langchain.vectorstores.singlestoredb import SingleStoreDB from langchain.vectorstores.sklearn import SKLearnVectorStore +from langchain.vectorstores.sqlitevss import SQLiteVSS from langchain.vectorstores.starrocks import StarRocks from langchain.vectorstores.supabase import SupabaseVectorStore from langchain.vectorstores.tair import Tair @@ -125,6 +126,7 @@ __all__ = [ "ScaNN", "SingleStoreDB", "SingleStoreDB", + "SQLiteVSS", "StarRocks", "SupabaseVectorStore", "Tair", diff --git a/libs/langchain/langchain/vectorstores/sqlitevss.py b/libs/langchain/langchain/vectorstores/sqlitevss.py new file mode 100644 index 0000000000..d311ffafcb --- /dev/null +++ b/libs/langchain/langchain/vectorstores/sqlitevss.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +import json +import logging +import sqlite3 +import warnings +from typing import ( + Any, + Iterable, + List, + Optional, + Tuple, + Type, +) + +from langchain.docstore.document import Document +from langchain.embeddings.base import Embeddings +from langchain.vectorstores.base import VectorStore + +logger = logging.getLogger(__name__) + + +class SQLiteVSS(VectorStore): + """Wrapper around SQLite with vss extension as a vector database. + To use, you should have the ``sqlite-vss`` python package installed. + Example: + .. code-block:: python + from langchain.vectorstores import SQLiteVSS + from langchain.embeddings.openai import OpenAIEmbeddings + ... + """ + + def __init__( + self, + table: str, + connection: Optional[sqlite3.Connection], + embedding: Embeddings, + db_file: str = "vss.db", + ): + """Initialize with sqlite client with vss extension.""" + try: + import sqlite_vss # noqa # pylint: disable=unused-import + except ImportError: + raise ImportError( + "Could not import sqlite-vss python package. " + "Please install it with `pip install sqlite-vss`." + ) + + if not connection: + connection = self.create_connection(db_file) + + if not isinstance(embedding, Embeddings): + warnings.warn("embeddings input must be Embeddings object.") + + self._connection = connection + self._table = table + self._embedding = embedding + + self.create_table_if_not_exists() + + def create_table_if_not_exists(self) -> None: + self._connection.execute( + f""" + CREATE TABLE IF NOT EXISTS {self._table} + ( + rowid INTEGER PRIMARY KEY AUTOINCREMENT, + text TEXT, + metadata BLOB, + text_embedding BLOB + ) + ; + """ + ) + self._connection.execute( + f""" + CREATE VIRTUAL TABLE IF NOT EXISTS vss_{self._table} USING vss0( + text_embedding({self.get_dimensionality()}) + ); + """ + ) + self._connection.execute( + f""" + CREATE TRIGGER IF NOT EXISTS embed_text + AFTER INSERT ON {self._table} + BEGIN + INSERT INTO vss_{self._table}(rowid, text_embedding) + VALUES (new.rowid, new.text_embedding) + ; + END; + """ + ) + self._connection.commit() + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ) -> List[str]: + """Add more texts to the vectorstore index. + Args: + texts: Iterable of strings to add to the vectorstore. + metadatas: Optional list of metadatas associated with the texts. + kwargs: vectorstore specific parameters + """ + max_id = self._connection.execute( + f"SELECT max(rowid) as rowid FROM {self._table}" + ).fetchone()["rowid"] + if max_id is None: # no text added yet + max_id = 0 + + embeds = self._embedding.embed_documents(list(texts)) + if not metadatas: + metadatas = [{} for _ in texts] + data_input = [ + (text, json.dumps(metadata), json.dumps(embed)) + for text, metadata, embed in zip(texts, metadatas, embeds) + ] + self._connection.executemany( + f"INSERT INTO {self._table}(text, metadata, text_embedding) " + f"VALUES (?,?,?)", + data_input, + ) + self._connection.commit() + # pulling every ids we just inserted + results = self._connection.execute( + f"SELECT rowid FROM {self._table} WHERE rowid > {max_id}" + ) + return [row["rowid"] for row in results] + + def similarity_search_with_score_by_vector( + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Tuple[Document, float]]: + sql_query = f""" + SELECT + text, + metadata, + distance + FROM {self._table} e + INNER JOIN vss_{self._table} v on v.rowid = e.rowid + WHERE vss_search( + v.text_embedding, + vss_search_params('{json.dumps(embedding)}', {k}) + ) + """ + cursor = self._connection.cursor() + cursor.execute(sql_query) + results = cursor.fetchall() + + documents = [] + for row in results: + metadata = json.loads(row["metadata"]) or {} + doc = Document(page_content=row["text"], metadata=metadata) + documents.append((doc, row["distance"])) + + return documents + + def similarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Document]: + """Return docs most similar to query.""" + embedding = self._embedding.embed_query(query) + documents = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k + ) + return [doc for doc, _ in documents] + + def similarity_search_with_score( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Tuple[Document, float]]: + """Return docs most similar to query.""" + embedding = self._embedding.embed_query(query) + documents = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k + ) + return documents + + def similarity_search_by_vector( + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Document]: + documents = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k + ) + return [doc for doc, _ in documents] + + @classmethod + def from_texts( + cls: Type[SQLiteVSS], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + table: str = "langchain", + db_file: str = "vss.db", + **kwargs: Any, + ) -> SQLiteVSS: + """Return VectorStore initialized from texts and embeddings.""" + connection = cls.create_connection(db_file) + vss = cls( + table=table, connection=connection, db_file=db_file, embedding=embedding + ) + vss.add_texts(texts=texts, metadatas=metadatas) + return vss + + @staticmethod + def create_connection(db_file: str) -> sqlite3.Connection: + import sqlite_vss + + connection = sqlite3.connect(db_file) + connection.row_factory = sqlite3.Row + connection.enable_load_extension(True) + sqlite_vss.load(connection) + connection.enable_load_extension(False) + return connection + + def get_dimensionality(self) -> int: + """ + Function that does a dummy embedding to figure out how many dimensions + this embedding function returns. Needed for the virtual table DDL. + """ + dummy_text = "This is a dummy text" + dummy_embedding = self._embedding.embed_query(dummy_text) + return len(dummy_embedding) diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index e21acf426b..d49de58a10 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -3542,6 +3542,7 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, + {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] @@ -8148,10 +8149,8 @@ description = "Fast and Safe Tensor serialization" optional = true python-versions = "*" files = [ - {file = "safetensors-0.3.2-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:4c7827b64b1da3f082301b5f5a34331b8313104c14f257099a12d32ac621c5cd"}, {file = "safetensors-0.3.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:b6a66989075c2891d743153e8ba9ca84ee7232c8539704488f454199b8b8f84d"}, {file = "safetensors-0.3.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:670d6bc3a3b377278ce2971fa7c36ebc0a35041c4ea23b9df750a39380800195"}, - {file = "safetensors-0.3.2-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:67ef2cc747c88e3a8d8e4628d715874c0366a8ff1e66713a9d42285a429623ad"}, {file = "safetensors-0.3.2-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:564f42838721925b5313ae864ba6caa6f4c80a9fbe63cf24310c3be98ab013cd"}, {file = "safetensors-0.3.2-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:7f80af7e4ab3188daaff12d43d078da3017a90d732d38d7af4eb08b6ca2198a5"}, {file = "safetensors-0.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec30d78f20f1235b252d59cbb9755beb35a1fde8c24c89b3c98e6a1804cfd432"}, @@ -8160,9 +8159,7 @@ files = [ {file = "safetensors-0.3.2-cp310-cp310-win32.whl", hash = "sha256:2961c1243fd0da46aa6a1c835305cc4595486f8ac64632a604d0eb5f2de76175"}, {file = "safetensors-0.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:c813920482c337d1424d306e1b05824a38e3ef94303748a0a287dea7a8c4f805"}, {file = "safetensors-0.3.2-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:707df34bd9b9047e97332136ad98e57028faeccdb9cfe1c3b52aba5964cc24bf"}, - {file = "safetensors-0.3.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:23d1d9f74208c9dfdf852a9f986dac63e40092385f84bf0789d599efa8e6522f"}, {file = "safetensors-0.3.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:becc5bb85b2947eae20ed23b407ebfd5277d9a560f90381fe2c42e6c043677ba"}, - {file = "safetensors-0.3.2-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:c1913c6c549b1805e924f307159f0ee97b73ae3ce150cd2401964da015e0fa0b"}, {file = "safetensors-0.3.2-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:30a75707be5cc9686490bde14b9a371cede4af53244ea72b340cfbabfffdf58a"}, {file = "safetensors-0.3.2-cp311-cp311-macosx_13_0_universal2.whl", hash = "sha256:54ad6af663e15e2b99e2ea3280981b7514485df72ba6d014dc22dae7ba6a5e6c"}, {file = "safetensors-0.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37764b3197656ef507a266c453e909a3477dabc795962b38e3ad28226f53153b"}, @@ -8170,28 +8167,22 @@ files = [ {file = "safetensors-0.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada0fac127ff8fb04834da5c6d85a8077e6a1c9180a11251d96f8068db922a17"}, {file = "safetensors-0.3.2-cp311-cp311-win32.whl", hash = "sha256:155b82dbe2b0ebff18cde3f76b42b6d9470296e92561ef1a282004d449fa2b4c"}, {file = "safetensors-0.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:a86428d196959619ce90197731be9391b5098b35100a7228ef4643957648f7f5"}, - {file = "safetensors-0.3.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:91e796b6e465d9ffaca4c411d749f236c211e257f3a8e9b25a5ffc1a42d3bfa7"}, {file = "safetensors-0.3.2-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:c1f8ab41ed735c5b581f451fd15d9602ff51aa88044bfa933c5fa4b1d0c644d1"}, - {file = "safetensors-0.3.2-cp37-cp37m-macosx_12_0_x86_64.whl", hash = "sha256:e6a8ff5652493598c45cd27f5613c193d3f15e76e0f81613d399c487a7b8cc50"}, {file = "safetensors-0.3.2-cp37-cp37m-macosx_13_0_x86_64.whl", hash = "sha256:bc9cfb3c9ea2aec89685b4d656f9f2296f0f0d67ecf2bebf950870e3be89b3db"}, {file = "safetensors-0.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ace5d471e3d78e0d93f952707d808b5ab5eac77ddb034ceb702e602e9acf2be9"}, {file = "safetensors-0.3.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:de3e20a388b444381bcda1a3193cce51825ddca277e4cf3ed1fe8d9b2d5722cd"}, {file = "safetensors-0.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d7d70d48585fe8df00725aa788f2e64fd24a4c9ae07cd6be34f6859d0f89a9c"}, {file = "safetensors-0.3.2-cp37-cp37m-win32.whl", hash = "sha256:6ff59bc90cdc857f68b1023be9085fda6202bbe7f2fd67d06af8f976d6adcc10"}, {file = "safetensors-0.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:8b05c93da15fa911763a89281906ca333ed800ab0ef1c7ce53317aa1a2322f19"}, - {file = "safetensors-0.3.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:94857abc019b49a22a0065cc7741c48fb788aa7d8f3f4690c092c56090227abe"}, {file = "safetensors-0.3.2-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:8969cfd9e8d904e8d3c67c989e1bd9a95e3cc8980d4f95e4dcd43c299bb94253"}, - {file = "safetensors-0.3.2-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:da482fa011dc88fe7376d8f8b42c0ccef2f260e0cbc847ceca29c708bf75a868"}, {file = "safetensors-0.3.2-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:f54148ac027556eb02187e9bc1556c4d916c99ca3cb34ca36a7d304d675035c1"}, {file = "safetensors-0.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caec25fedbcf73f66c9261984f07885680f71417fc173f52279276c7f8a5edd3"}, {file = "safetensors-0.3.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:50224a1d99927ccf3b75e27c3d412f7043280431ab100b4f08aad470c37cf99a"}, {file = "safetensors-0.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa98f49e95f02eb750d32c4947e7d5aa43883149ebd0414920866446525b70f0"}, {file = "safetensors-0.3.2-cp38-cp38-win32.whl", hash = "sha256:33409df5e28a83dc5cc5547a3ac17c0f1b13a1847b1eb3bc4b3be0df9915171e"}, {file = "safetensors-0.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:e04a7cbbb3856159ab99e3adb14521544f65fcb8548cce773a1435a0f8d78d27"}, - {file = "safetensors-0.3.2-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:f39f3d951543b594c6bc5082149d994c47ca487fd5d55b4ce065ab90441aa334"}, {file = "safetensors-0.3.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:7c864cf5dcbfb608c5378f83319c60cc9c97263343b57c02756b7613cd5ab4dd"}, {file = "safetensors-0.3.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:14e8c19d6dc51d4f70ee33c46aff04c8ba3f95812e74daf8036c24bc86e75cae"}, - {file = "safetensors-0.3.2-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:41b10b0a6dfe8fdfbe4b911d64717d5647e87fbd7377b2eb3d03fb94b59810ea"}, {file = "safetensors-0.3.2-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:042a60f633c3c7009fdf6a7c182b165cb7283649d2a1e9c7a4a1c23454bd9a5b"}, {file = "safetensors-0.3.2-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:fafd95e5ef41e8f312e2a32b7031f7b9b2a621b255f867b221f94bb2e9f51ae8"}, {file = "safetensors-0.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ed77cf358abce2307f03634694e0b2a29822e322a1623e0b1aa4b41e871bf8b"}, @@ -8706,6 +8697,21 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3-binary"] +[[package]] +name = "sqlite-vss" +version = "0.1.2" +description = "" +optional = true +python-versions = ">=3.7" +files = [ + {file = "sqlite_vss-0.1.2-py3-none-macosx_10_6_x86_64.whl", hash = "sha256:9eefa4207f8b522e32b2747fce44422c773e36710bf807613795218c7ba125f0"}, + {file = "sqlite_vss-0.1.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:84994eaf7fe700218b258422358c4536a6aca39b96026c308b28630967f954c4"}, + {file = "sqlite_vss-0.1.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux1_x86_64.whl", hash = "sha256:e44f03bc4cb214bb77b206519abfb623e3e4795967a569218e288927a7715806"}, +] + +[package.extras] +test = ["pytest"] + [[package]] name = "sqlitedict" version = "2.1.0" @@ -10437,7 +10443,7 @@ clarifai = ["clarifai"] cohere = ["cohere"] docarray = ["docarray"] embeddings = ["sentence-transformers"] -extended-testing = ["amazon-textract-caller", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "faiss-cpu", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "markdownify", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "openapi-schema-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xmltodict"] +extended-testing = ["amazon-textract-caller", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "faiss-cpu", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "markdownify", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "openapi-schema-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tqdm", "xata", "xmltodict"] javascript = ["esprima"] llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"] openai = ["openai", "tiktoken"] @@ -10447,4 +10453,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "43a6bd42efc0baf917418087f788aaf3b1bc793cb4aa81de99c52ed6a7d54d26" +content-hash = "1fbea4c22b1df46fd1062b58657321ef49b4c503851859d3ad1109acf427df9d" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 93b78e2d30..bc72febbd1 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -127,6 +127,7 @@ xata = {version = "^1.0.0a7", optional = true} xmltodict = {version = "^0.13.0", optional = true} markdownify = {version = "^0.11.6", optional = true} assemblyai = {version = "^0.17.0", optional = true} +sqlite-vss = {version = "^0.1.2", optional = true} [tool.poetry.group.test.dependencies] @@ -341,6 +342,7 @@ extended_testing = [ "faiss-cpu", "openapi-schema-pydantic", "markdownify", + "sqlite-vss", ] [tool.ruff] diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_sqlitevss.py b/libs/langchain/tests/integration_tests/vectorstores/test_sqlitevss.py new file mode 100644 index 0000000000..0896706d89 --- /dev/null +++ b/libs/langchain/tests/integration_tests/vectorstores/test_sqlitevss.py @@ -0,0 +1,58 @@ +from typing import List, Optional + +import pytest + +from langchain.docstore.document import Document +from langchain.vectorstores import SQLiteVSS +from tests.integration_tests.vectorstores.fake_embeddings import ( + FakeEmbeddings, + fake_texts, +) + + +def _sqlite_vss_from_texts( + metadatas: Optional[List[dict]] = None, drop: bool = True +) -> SQLiteVSS: + return SQLiteVSS.from_texts( + fake_texts, + FakeEmbeddings(), + metadatas=metadatas, + table="test", + db_file=":memory:", + ) + + +@pytest.mark.requires("sqlite-vss") +def test_sqlitevss() -> None: + """Test end to end construction and search.""" + docsearch = _sqlite_vss_from_texts() + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={})] + + +@pytest.mark.requires("sqlite-vss") +def test_sqlitevss_with_score() -> None: + """Test end to end construction and search with scores and IDs.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = _sqlite_vss_from_texts(metadatas=metadatas) + output = docsearch.similarity_search_with_score("foo", k=3) + docs = [o[0] for o in output] + distances = [o[1] for o in output] + assert docs == [ + Document(page_content="foo", metadata={"page": 0}), + Document(page_content="bar", metadata={"page": 1}), + Document(page_content="baz", metadata={"page": 2}), + ] + assert distances[0] < distances[1] < distances[2] + + +@pytest.mark.requires("sqlite-vss") +def test_sqlitevss_add_extra() -> None: + """Test end to end construction and MRR search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = _sqlite_vss_from_texts(metadatas=metadatas) + docsearch.add_texts(texts, metadatas) + output = docsearch.similarity_search("foo", k=10) + assert len(output) == 6