mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
7cf2d2759d
Added missing docstrings. Formatted docstrings to a consistent form.
229 lines
7.1 KiB
Python
229 lines
7.1 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import warnings
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Iterable,
|
|
List,
|
|
Optional,
|
|
Tuple,
|
|
Type,
|
|
)
|
|
|
|
from langchain_core.documents import Document
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.vectorstores import VectorStore
|
|
|
|
if TYPE_CHECKING:
|
|
import sqlite3
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SQLiteVSS(VectorStore):
    """SQLite with VSS extension as a vector database.

    To use, you should have the ``sqlite-vss`` python package installed.

    Texts are stored in a regular table (``text``, ``metadata`` and
    ``text_embedding`` columns); an ``AFTER INSERT`` trigger copies each new
    embedding into a companion ``vss0`` virtual table named ``vss_<table>``,
    which is what similarity queries search.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import SQLiteVSS
            from langchain_community.embeddings.openai import OpenAIEmbeddings
            ...
    """

    def __init__(
        self,
        table: str,
        connection: Optional[sqlite3.Connection],
        embedding: Embeddings,
        db_file: str = "vss.db",
    ):
        """Initialize with sqlite client with vss extension.

        Args:
            table: Name of the table holding texts, metadata and embeddings.
                It is interpolated into SQL statements as an identifier.
            connection: An open SQLite connection. When falsy, a new one is
                created from ``db_file`` via ``create_connection``.
            embedding: Embedding function used for documents and queries.
            db_file: Database file used only when ``connection`` is not given.

        Raises:
            ImportError: If the ``sqlite-vss`` package is not installed.
        """
        try:
            import sqlite_vss  # noqa  # pylint: disable=unused-import
        except ImportError as e:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "Could not import sqlite-vss python package. "
                "Please install it with `pip install sqlite-vss`."
            ) from e

        if not connection:
            connection = self.create_connection(db_file)

        # Only a warning (not an error): other objects are still accepted.
        if not isinstance(embedding, Embeddings):
            warnings.warn("embeddings input must be Embeddings object.")

        self._connection = connection
        self._table = table
        self._embedding = embedding

        self.create_table_if_not_exists()

    def create_table_if_not_exists(self) -> None:
        """Create the storage table, the vss virtual table and the sync trigger.

        All three statements use ``IF NOT EXISTS``, so calling this on an
        already-initialized database is harmless. The virtual table's vector
        size comes from ``get_dimensionality``, which performs one embedding
        call.

        NOTE(review): the trigger is always named ``embed_text`` regardless of
        the table name. Because of ``IF NOT EXISTS``, a second table created in
        the same database would apparently not get its own trigger. The name is
        left unchanged here so existing database files keep working.
        """
        self._connection.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self._table}
            (
                rowid INTEGER PRIMARY KEY AUTOINCREMENT,
                text TEXT,
                metadata BLOB,
                text_embedding BLOB
            )
            ;
            """
        )
        self._connection.execute(
            f"""
            CREATE VIRTUAL TABLE IF NOT EXISTS vss_{self._table} USING vss0(
                text_embedding({self.get_dimensionality()})
            );
            """
        )
        self._connection.execute(
            f"""
            CREATE TRIGGER IF NOT EXISTS embed_text
            AFTER INSERT ON {self._table}
            BEGIN
                INSERT INTO vss_{self._table}(rowid, text_embedding)
                VALUES (new.rowid, new.text_embedding)
                ;
            END;
            """
        )
        self._connection.commit()

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add more texts to the vectorstore index.

        Args:
            texts: Iterable of strings to add to the vectorstore. It may be a
                one-shot iterator (e.g. a generator); it is consumed once.
            metadatas: Optional list of metadatas associated with the texts.
                When omitted or empty, every text gets an empty dict.
            kwargs: vectorstore specific parameters

        Returns:
            The ``rowid`` values of the rows present after the insert whose
            id is greater than the previous maximum. These are the integer
            values SQLite returns, despite the ``List[str]`` annotation.
        """
        # Materialize once: ``texts`` is iterated several times below, and a
        # generator would be exhausted after the first pass, so nothing would
        # be inserted.
        texts = list(texts)

        max_id = self._connection.execute(
            f"SELECT max(rowid) as rowid FROM {self._table}"
        ).fetchone()["rowid"]
        if max_id is None:  # no text added yet
            max_id = 0

        embeds = self._embedding.embed_documents(texts)
        if not metadatas:
            metadatas = [{} for _ in texts]
        # Metadata and embeddings are both stored as JSON text.
        data_input = [
            (text, json.dumps(metadata), json.dumps(embed))
            for text, metadata, embed in zip(texts, metadatas, embeds)
        ]
        self._connection.executemany(
            f"INSERT INTO {self._table}(text, metadata, text_embedding) "
            f"VALUES (?,?,?)",
            data_input,
        )
        self._connection.commit()
        # pulling every ids we just inserted
        results = self._connection.execute(
            f"SELECT rowid FROM {self._table} WHERE rowid > ?",
            (max_id,),
        )
        return [row["rowid"] for row in results]

    def similarity_search_with_score_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return the stored documents closest to ``embedding`` with distances.

        Args:
            embedding: Query vector.
            k: Number of results requested from ``vss_search``.
            kwargs: Accepted for interface compatibility; not used.

        Returns:
            A list of ``(Document, distance)`` tuples in the order returned by
            the query.
        """
        # The query vector (as a JSON literal) and ``k`` are interpolated
        # directly into the SQL text, as is the table name.
        sql_query = f"""
            SELECT
                text,
                metadata,
                distance
            FROM {self._table} e
            INNER JOIN vss_{self._table} v on v.rowid = e.rowid
            WHERE vss_search(
                v.text_embedding,
                vss_search_params('{json.dumps(embedding)}', {k})
            )
        """
        cursor = self._connection.cursor()
        cursor.execute(sql_query)
        results = cursor.fetchall()

        documents = []
        for row in results:
            # Stored metadata may decode to a falsy value (e.g. null).
            metadata = json.loads(row["metadata"]) or {}
            doc = Document(page_content=row["text"], metadata=metadata)
            documents.append((doc, row["distance"]))

        return documents

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to embed and search for.
            k: Number of documents to return.
            kwargs: Accepted for interface compatibility; not used.

        Returns:
            The matching documents, without their distances.
        """
        embedding = self._embedding.embed_query(query)
        documents = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k
        )
        return [doc for doc, _ in documents]

    def similarity_search_with_score(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query, each paired with its distance.

        Args:
            query: Text to embed and search for.
            k: Number of documents to return.
            kwargs: Accepted for interface compatibility; not used.

        Returns:
            A list of ``(Document, distance)`` tuples.
        """
        embedding = self._embedding.embed_query(query)
        documents = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k
        )
        return documents

    def similarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to the given embedding vector.

        Args:
            embedding: Query vector.
            k: Number of documents to return.
            kwargs: Accepted for interface compatibility; not used.

        Returns:
            The matching documents, without their distances.
        """
        documents = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k
        )
        return [doc for doc, _ in documents]

    @classmethod
    def from_texts(
        cls: Type[SQLiteVSS],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        table: str = "langchain",
        db_file: str = "vss.db",
        **kwargs: Any,
    ) -> SQLiteVSS:
        """Return VectorStore initialized from texts and embeddings.

        Args:
            texts: Texts to insert into the new store.
            embedding: Embedding function used for documents and queries.
            metadatas: Optional list of metadatas associated with the texts.
            table: Name of the table to create or reuse.
            db_file: SQLite database file to connect to.
            kwargs: Accepted for interface compatibility; not used.

        Returns:
            A ``SQLiteVSS`` instance with ``texts`` already added.
        """
        connection = cls.create_connection(db_file)
        vss = cls(
            table=table, connection=connection, db_file=db_file, embedding=embedding
        )
        vss.add_texts(texts=texts, metadatas=metadatas)
        return vss

    @staticmethod
    def create_connection(db_file: str) -> sqlite3.Connection:
        """Open ``db_file`` and load the sqlite-vss extension into the connection.

        The connection uses ``sqlite3.Row`` as its row factory, which the other
        methods rely on for access by column name. Extension loading is enabled
        only for the duration of the ``sqlite_vss.load`` call.

        Args:
            db_file: Path passed to ``sqlite3.connect``.

        Returns:
            The configured connection.
        """
        import sqlite3

        import sqlite_vss

        connection = sqlite3.connect(db_file)
        connection.row_factory = sqlite3.Row
        connection.enable_load_extension(True)
        sqlite_vss.load(connection)
        connection.enable_load_extension(False)
        return connection

    def get_dimensionality(self) -> int:
        """Return the number of dimensions produced by the embedding function.

        Function that does a dummy embedding to figure out how many dimensions
        this embedding function returns. Needed for the virtual table DDL.
        """
        dummy_text = "This is a dummy text"
        dummy_embedding = self._embedding.embed_query(dummy_text)
        return len(dummy_embedding)
|