mirror of https://github.com/hwchase17/langchain
Add sqlite-vss as a vector database (#10047)
This adds sqlite-vss as an option for a vector database. It contains the code and a few tests. The tests are passing, and the sqlite-vss library is added as an optional dependency, as explained in the contributing guidelines. I adjusted the code for lint, black, and mypy; it looks like everything is currently passing. Adding sqlite-vss was mentioned in this issue: https://github.com/langchain-ai/langchain/issues/1019. It is also mentioned here in the sqlite-vss repo for the curious: https://github.com/asg017/sqlite-vss/issues/66 Maintainer tag: @baskaryan --------- Co-authored-by: Philippe Oger <philippe.oger@adevinta.com> pull/10098/head
parent
c7a5504789
commit
3dafbd852e
@ -0,0 +1,222 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SQLiteVSS(VectorStore):
    """Wrapper around SQLite with vss extension as a vector database.

    To use, you should have the ``sqlite-vss`` python package installed.

    Rows are read back by column name (``row["text"]`` etc.), so a
    connection passed in by the caller must use ``sqlite3.Row`` as its
    ``row_factory``, as ``create_connection`` does.

    Example:
        .. code-block:: python

            from langchain.vectorstores import SQLiteVSS
            from langchain.embeddings.openai import OpenAIEmbeddings
            ...
    """

    def __init__(
        self,
        table: str,
        connection: Optional[sqlite3.Connection],
        embedding: Embeddings,
        db_file: str = "vss.db",
    ):
        """Initialize with sqlite client with vss extension.

        Args:
            table: Name of the table holding texts, metadata and embeddings.
                It is interpolated into SQL as an identifier.
            connection: Existing SQLite connection. When falsy, a new one is
                opened on ``db_file`` with ``create_connection``.
            embedding: Embedding function used for documents and queries.
            db_file: Database path, used only when ``connection`` is falsy.

        Raises:
            ImportError: If the ``sqlite-vss`` package cannot be imported.
        """
        try:
            import sqlite_vss  # noqa  # pylint: disable=unused-import
        except ImportError as exc:
            raise ImportError(
                "Could not import sqlite-vss python package. "
                "Please install it with `pip install sqlite-vss`."
            ) from exc

        if not connection:
            connection = self.create_connection(db_file)

        if not isinstance(embedding, Embeddings):
            warnings.warn("embeddings input must be Embeddings object.")

        self._connection = connection
        self._table = table
        self._embedding = embedding

        self.create_table_if_not_exists()

    def create_table_if_not_exists(self) -> None:
        """Create the data table, its vss virtual table and the sync trigger.

        Every statement is idempotent, so calling this on an already
        initialised database leaves it unchanged.
        """
        self._connection.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self._table}
            (
                rowid INTEGER PRIMARY KEY AUTOINCREMENT,
                text TEXT,
                metadata BLOB,
                text_embedding BLOB
            )
            ;
            """
        )
        self._connection.execute(
            f"""
            CREATE VIRTUAL TABLE IF NOT EXISTS vss_{self._table} USING vss0(
                text_embedding({self.get_dimensionality()})
            );
            """
        )

        # Trigger names are global to the database. A fixed name would make
        # ``IF NOT EXISTS`` silently skip the trigger for every table after
        # the first one, leaving their vss index empty. Use a per-table name,
        # but do not add a second trigger to a table that already carries the
        # legacy ``embed_text`` one (that would insert each row twice).
        trigger_name = f"embed_text_{self._table}"
        existing_trigger = self._connection.execute(
            "SELECT 1 FROM sqlite_master "
            "WHERE type = 'trigger' "
            "AND lower(tbl_name) = lower(?) "
            "AND name IN ('embed_text', ?)",
            (self._table, trigger_name),
        ).fetchone()
        if existing_trigger is None:
            self._connection.execute(
                f"""
                CREATE TRIGGER IF NOT EXISTS {trigger_name}
                AFTER INSERT ON {self._table}
                BEGIN
                    INSERT INTO vss_{self._table}(rowid, text_embedding)
                    VALUES (new.rowid, new.text_embedding)
                    ;
                END;
                """
            )
        self._connection.commit()

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add more texts to the vectorstore index.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            kwargs: vectorstore specific parameters

        Returns:
            The ``rowid`` values of the rows inserted by this call, exactly
            as SQLite returns them.
        """
        # Materialise once: ``texts`` may be a one-shot iterator and is needed
        # for embedding, for default metadata and for the insert itself.
        text_list = list(texts)

        max_id = self._connection.execute(
            f"SELECT max(rowid) as rowid FROM {self._table}"
        ).fetchone()["rowid"]
        if max_id is None:  # no text added yet
            max_id = 0

        embeds = self._embedding.embed_documents(text_list)
        if not metadatas:
            metadatas = [{} for _ in text_list]
        data_input = [
            (text, json.dumps(metadata), json.dumps(embed))
            for text, metadata, embed in zip(text_list, metadatas, embeds)
        ]
        self._connection.executemany(
            f"INSERT INTO {self._table}(text, metadata, text_embedding) "
            f"VALUES (?,?,?)",
            data_input,
        )
        self._connection.commit()
        # Everything above the previous maximum rowid was inserted just now.
        results = self._connection.execute(
            f"SELECT rowid FROM {self._table} WHERE rowid > ?",
            (max_id,),
        )
        return [row["rowid"] for row in results]

    def similarity_search_with_score_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return the documents closest to ``embedding`` with their distance.

        Args:
            embedding: Query vector.
            k: Number of results requested from ``vss_search``.
            kwargs: Accepted for interface compatibility; not used.

        Returns:
            ``(Document, distance)`` pairs in the order the query yields them.
        """
        # The vector and ``k`` are bound as parameters instead of being
        # formatted into the SQL text. Only the table name, an identifier
        # that cannot be bound, is still interpolated.
        sql_query = f"""
            SELECT
                text,
                metadata,
                distance
            FROM {self._table} e
            INNER JOIN vss_{self._table} v on v.rowid = e.rowid
            WHERE vss_search(
              v.text_embedding,
              vss_search_params(?, ?)
            )
        """
        cursor = self._connection.cursor()
        cursor.execute(sql_query, (json.dumps(embedding), k))
        results = cursor.fetchall()

        documents = []
        for row in results:
            # A stored JSON ``null`` or empty object becomes ``{}``.
            metadata = json.loads(row["metadata"]) or {}
            doc = Document(page_content=row["text"], metadata=metadata)
            documents.append((doc, row["distance"]))

        return documents

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query."""
        embedding = self._embedding.embed_query(query)
        documents = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k
        )
        return [doc for doc, _ in documents]

    def similarity_search_with_score(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query, each paired with its distance."""
        embedding = self._embedding.embed_query(query)
        documents = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k
        )
        return documents

    def similarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to the given embedding vector."""
        documents = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k
        )
        return [doc for doc, _ in documents]

    @classmethod
    def from_texts(
        cls: Type[SQLiteVSS],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        table: str = "langchain",
        db_file: str = "vss.db",
        **kwargs: Any,
    ) -> SQLiteVSS:
        """Return VectorStore initialized from texts and embeddings.

        Args:
            texts: Texts to insert.
            embedding: Embedding function for documents and queries.
            metadatas: Optional metadata dicts, one per text.
            table: Name of the table to create or reuse.
            db_file: Path of the SQLite database to open.
            kwargs: Accepted for interface compatibility; not used.
        """
        connection = cls.create_connection(db_file)
        vss = cls(
            table=table, connection=connection, db_file=db_file, embedding=embedding
        )
        vss.add_texts(texts=texts, metadatas=metadatas)
        return vss

    @staticmethod
    def create_connection(db_file: str) -> sqlite3.Connection:
        """Open ``db_file`` and load the vss extension into the connection.

        The connection uses ``sqlite3.Row`` as row factory. Extension loading
        is switched on only for the duration of ``sqlite_vss.load``. If
        anything fails, the connection is closed before the error propagates.
        """
        import sqlite_vss

        connection = sqlite3.connect(db_file)
        connection.row_factory = sqlite3.Row
        try:
            connection.enable_load_extension(True)
            try:
                sqlite_vss.load(connection)
            finally:
                # Never leave extension loading enabled on the connection.
                connection.enable_load_extension(False)
        except Exception:
            connection.close()
            raise
        return connection

    def get_dimensionality(self) -> int:
        """
        Function that does a dummy embedding to figure out how many dimensions
        this embedding function returns. Needed for the virtual table DDL.
        """
        dummy_text = "This is a dummy text"
        dummy_embedding = self._embedding.embed_query(dummy_text)
        return len(dummy_embedding)
|
@ -0,0 +1,58 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores import SQLiteVSS
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
FakeEmbeddings,
|
||||
fake_texts,
|
||||
)
|
||||
|
||||
|
||||
def _sqlite_vss_from_texts(
    metadatas: Optional[List[dict]] = None, drop: bool = True
) -> SQLiteVSS:
    """Build an in-memory ``SQLiteVSS`` store named ``test`` from ``fake_texts``.

    ``drop`` is accepted but not used by this helper.
    """
    embedding = FakeEmbeddings()
    store = SQLiteVSS.from_texts(
        texts=fake_texts,
        embedding=embedding,
        metadatas=metadatas,
        table="test",
        db_file=":memory:",
    )
    return store
|
||||
|
||||
|
||||
@pytest.mark.requires("sqlite-vss")
def test_sqlitevss() -> None:
    """Build a store and check the single nearest hit for ``"foo"``."""
    store = _sqlite_vss_from_texts()
    expected = [Document(page_content="foo", metadata={})]
    hits = store.similarity_search("foo", k=1)
    assert hits == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("sqlite-vss")
def test_sqlitevss_with_score() -> None:
    """Build a store with page metadata and check scored search results."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    store = _sqlite_vss_from_texts(metadatas=metadatas)

    scored_hits = store.similarity_search_with_score("foo", k=3)
    found_docs = []
    found_distances = []
    for hit in scored_hits:
        found_docs.append(hit[0])
        found_distances.append(hit[1])

    expected_docs = [
        Document(page_content="foo", metadata={"page": 0}),
        Document(page_content="bar", metadata={"page": 1}),
        Document(page_content="baz", metadata={"page": 2}),
    ]
    assert found_docs == expected_docs
    # Distances must be strictly increasing across the three hits.
    assert found_distances[0] < found_distances[1] < found_distances[2]
|
||||
|
||||
|
||||
@pytest.mark.requires("sqlite-vss")
def test_sqlitevss_add_extra() -> None:
    """Add three more texts to a built store and check the search sees six."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    store = _sqlite_vss_from_texts(metadatas=metadatas)

    store.add_texts(texts, metadatas)

    hits = store.similarity_search("foo", k=10)
    assert len(hits) == 6
|
Loading…
Reference in New Issue