mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
060987d755
- **Description:** Add LSH-based indexing to the Yellowbrick vector store module - **Twitter handle:** @markcusack --------- Co-authored-by: markcusack <markcusack@markcusacksmac.lan> Co-authored-by: markcusack <markcusack@Mark-Cusack-sMac.local> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
974 lines
34 KiB
Python
974 lines
34 KiB
Python
from __future__ import annotations
|
|
|
|
import atexit
|
|
import csv
|
|
import enum
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from contextlib import contextmanager
|
|
from io import StringIO
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Dict,
|
|
Generator,
|
|
Iterable,
|
|
List,
|
|
Optional,
|
|
Tuple,
|
|
Type,
|
|
)
|
|
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.vectorstores import VectorStore
|
|
|
|
from langchain_community.docstore.document import Document
|
|
|
|
if TYPE_CHECKING:
|
|
from psycopg2.extensions import connection as PgConnection
|
|
from psycopg2.extensions import cursor as PgCursor
|
|
|
|
|
|
class Yellowbrick(VectorStore):
|
|
"""Yellowbrick as a vector database.
|
|
Example:
|
|
.. code-block:: python
|
|
from langchain_community.vectorstores import Yellowbrick
|
|
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
|
...
|
|
"""
|
|
|
|
class IndexType(str, enum.Enum):
|
|
"""Enumerator for the supported Index types within Yellowbrick."""
|
|
|
|
NONE = "none"
|
|
LSH = "lsh"
|
|
|
|
class IndexParams:
|
|
"""Parameters for configuring a Yellowbrick index."""
|
|
|
|
def __init__(
|
|
self,
|
|
index_type: Optional["Yellowbrick.IndexType"] = None,
|
|
params: Optional[Dict[str, Any]] = None,
|
|
):
|
|
if index_type is None:
|
|
index_type = Yellowbrick.IndexType.NONE
|
|
self.index_type = index_type
|
|
self.params = params or {}
|
|
|
|
def get_param(self, key: str, default: Any = None) -> Any:
|
|
return self.params.get(key, default)
|
|
|
|
def __init__(
|
|
self,
|
|
embedding: Embeddings,
|
|
connection_string: str,
|
|
table: str,
|
|
*,
|
|
schema: Optional[str] = None,
|
|
logger: Optional[logging.Logger] = None,
|
|
drop: bool = False,
|
|
) -> None:
|
|
"""Initialize with yellowbrick client.
|
|
Args:
|
|
embedding: Embedding operator
|
|
connection_string: Format 'postgres://username:password@host:port/database'
|
|
table: Table used to store / retrieve embeddings from
|
|
"""
|
|
from psycopg2 import extras
|
|
|
|
extras.register_uuid()
|
|
|
|
if logger:
|
|
self.logger = logger
|
|
else:
|
|
self.logger = logging.getLogger(__name__)
|
|
self.logger.setLevel(logging.ERROR)
|
|
handler = logging.StreamHandler()
|
|
handler.setLevel(logging.DEBUG)
|
|
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
|
handler.setFormatter(formatter)
|
|
self.logger.addHandler(handler)
|
|
|
|
if not isinstance(embedding, Embeddings):
|
|
self.logger.error("embeddings input must be Embeddings object.")
|
|
return
|
|
|
|
self.LSH_INDEX_TABLE: str = "_lsh_index"
|
|
self.LSH_HYPERPLANE_TABLE: str = "_lsh_hyperplane"
|
|
self.CONTENT_TABLE: str = "_content"
|
|
|
|
self.connection_string = connection_string
|
|
self.connection = Yellowbrick.DatabaseConnection(connection_string, self.logger)
|
|
atexit.register(self.connection.close_connection)
|
|
|
|
self._schema = schema
|
|
self._table = table
|
|
self._embedding = embedding
|
|
self._max_embedding_len = None
|
|
self._check_database_utf8()
|
|
|
|
with self.connection.get_cursor() as cursor:
|
|
if drop:
|
|
self.drop(table=self._table, schema=self._schema, cursor=cursor)
|
|
self.drop(
|
|
table=self._table + self.CONTENT_TABLE,
|
|
schema=self._schema,
|
|
cursor=cursor,
|
|
)
|
|
self._drop_lsh_index_tables(cursor)
|
|
|
|
self._create_schema(cursor)
|
|
self._create_table(cursor)
|
|
|
|
class DatabaseConnection:
|
|
_instance = None
|
|
_connection_string: str
|
|
_connection: Optional["PgConnection"] = None
|
|
_logger: logging.Logger
|
|
|
|
def __new__(
|
|
cls, connection_string: str, logger: logging.Logger
|
|
) -> "Yellowbrick.DatabaseConnection":
|
|
if cls._instance is None:
|
|
cls._instance = super().__new__(cls)
|
|
cls._instance._connection_string = connection_string
|
|
cls._instance._logger = logger
|
|
return cls._instance
|
|
|
|
def close_connection(self) -> None:
|
|
if self._connection and not self._connection.closed:
|
|
self._connection.close()
|
|
self._connection = None
|
|
|
|
def get_connection(self) -> "PgConnection":
|
|
import psycopg2
|
|
|
|
if not self._connection or self._connection.closed:
|
|
self._connection = psycopg2.connect(self._connection_string)
|
|
self._connection.autocommit = False
|
|
|
|
return self._connection
|
|
|
|
@contextmanager
|
|
def get_managed_connection(self) -> Generator["PgConnection", None, None]:
|
|
from psycopg2 import DatabaseError
|
|
|
|
conn = self.get_connection()
|
|
try:
|
|
yield conn
|
|
except DatabaseError as e:
|
|
conn.rollback()
|
|
self._logger.error(
|
|
"Database error occurred, rolling back transaction.", exc_info=True
|
|
)
|
|
raise RuntimeError("Database transaction failed.") from e
|
|
else:
|
|
conn.commit()
|
|
|
|
@contextmanager
|
|
def get_cursor(self) -> Generator["PgCursor", None, None]:
|
|
with self.get_managed_connection() as conn:
|
|
cursor = conn.cursor()
|
|
try:
|
|
yield cursor
|
|
finally:
|
|
cursor.close()
|
|
|
|
def _create_schema(self, cursor: "PgCursor") -> None:
|
|
"""
|
|
Helper function: create schema if not exists
|
|
"""
|
|
from psycopg2 import sql
|
|
|
|
if self._schema:
|
|
cursor.execute(
|
|
sql.SQL(
|
|
"""
|
|
CREATE SCHEMA IF NOT EXISTS {s}
|
|
"""
|
|
).format(
|
|
s=sql.Identifier(self._schema),
|
|
)
|
|
)
|
|
|
|
def _create_table(self, cursor: "PgCursor") -> None:
|
|
"""
|
|
Helper function: create table if not exists
|
|
"""
|
|
from psycopg2 import sql
|
|
|
|
schema_prefix = (self._schema,) if self._schema else ()
|
|
t = sql.Identifier(*schema_prefix, self._table + self.CONTENT_TABLE)
|
|
c = sql.Identifier(self._table + self.CONTENT_TABLE + "_pk_doc_id")
|
|
cursor.execute(
|
|
sql.SQL(
|
|
"""
|
|
CREATE TABLE IF NOT EXISTS {t} (
|
|
doc_id UUID NOT NULL,
|
|
text VARCHAR(60000) NOT NULL,
|
|
metadata VARCHAR(1024) NOT NULL,
|
|
CONSTRAINT {c} PRIMARY KEY (doc_id))
|
|
DISTRIBUTE ON (doc_id) SORT ON (doc_id)
|
|
"""
|
|
).format(
|
|
t=t,
|
|
c=c,
|
|
)
|
|
)
|
|
|
|
schema_prefix = (self._schema,) if self._schema else ()
|
|
t1 = sql.Identifier(*schema_prefix, self._table)
|
|
t2 = sql.Identifier(*schema_prefix, self._table + self.CONTENT_TABLE)
|
|
c1 = sql.Identifier(
|
|
self._table + self.CONTENT_TABLE + "_pk_doc_id_embedding_id"
|
|
)
|
|
c2 = sql.Identifier(self._table + self.CONTENT_TABLE + "_fk_doc_id")
|
|
cursor.execute(
|
|
sql.SQL(
|
|
"""
|
|
CREATE TABLE IF NOT EXISTS {t1} (
|
|
doc_id UUID NOT NULL,
|
|
embedding_id SMALLINT NOT NULL,
|
|
embedding FLOAT NOT NULL,
|
|
CONSTRAINT {c1} PRIMARY KEY (doc_id, embedding_id),
|
|
CONSTRAINT {c2} FOREIGN KEY (doc_id) REFERENCES {t2}(doc_id))
|
|
DISTRIBUTE ON (doc_id) SORT ON (doc_id)
|
|
"""
|
|
).format(
|
|
t1=t1,
|
|
t2=t2,
|
|
c1=c1,
|
|
c2=c2,
|
|
)
|
|
)
|
|
|
|
def drop(
|
|
self,
|
|
table: str,
|
|
schema: Optional[str] = None,
|
|
cursor: Optional["PgCursor"] = None,
|
|
) -> None:
|
|
"""
|
|
Helper function: Drop data. If a cursor is provided, use it;
|
|
otherwise, obtain a new cursor for the operation.
|
|
"""
|
|
if cursor is None:
|
|
with self.connection.get_cursor() as cursor:
|
|
self._drop_table(cursor, table, schema=schema)
|
|
else:
|
|
self._drop_table(cursor, table, schema=schema)
|
|
|
|
def _drop_table(
|
|
self,
|
|
cursor: "PgCursor",
|
|
table: str,
|
|
schema: Optional[str] = None,
|
|
) -> None:
|
|
"""
|
|
Executes the drop table command using the given cursor.
|
|
"""
|
|
from psycopg2 import sql
|
|
|
|
if schema:
|
|
table_name = sql.Identifier(schema, table)
|
|
else:
|
|
table_name = sql.Identifier(table)
|
|
|
|
drop_table_query = sql.SQL(
|
|
"""
|
|
DROP TABLE IF EXISTS {} CASCADE
|
|
"""
|
|
).format(table_name)
|
|
cursor.execute(drop_table_query)
|
|
|
|
def _check_database_utf8(self) -> bool:
|
|
"""
|
|
Helper function: Test the database is UTF-8 encoded
|
|
"""
|
|
with self.connection.get_cursor() as cursor:
|
|
query = """
|
|
SELECT pg_encoding_to_char(encoding)
|
|
FROM pg_database
|
|
WHERE datname = current_database();
|
|
"""
|
|
cursor.execute(query)
|
|
encoding = cursor.fetchone()[0]
|
|
|
|
if encoding.lower() == "utf8" or encoding.lower() == "utf-8":
|
|
return True
|
|
else:
|
|
raise Exception("Database encoding is not UTF-8")
|
|
|
|
return False
|
|
|
|
def add_texts(
|
|
self,
|
|
texts: Iterable[str],
|
|
metadatas: Optional[List[dict]] = None,
|
|
**kwargs: Any,
|
|
) -> List[str]:
|
|
batch_size = 10000
|
|
|
|
texts = list(texts)
|
|
embeddings = self._embedding.embed_documents(list(texts))
|
|
results = []
|
|
if not metadatas:
|
|
metadatas = [{} for _ in texts]
|
|
|
|
index_params = kwargs.get("index_params") or Yellowbrick.IndexParams()
|
|
|
|
with self.connection.get_cursor() as cursor:
|
|
content_io = StringIO()
|
|
embeddings_io = StringIO()
|
|
content_writer = csv.writer(
|
|
content_io, delimiter="\t", quotechar='"', quoting=csv.QUOTE_MINIMAL
|
|
)
|
|
embeddings_writer = csv.writer(
|
|
embeddings_io, delimiter="\t", quotechar='"', quoting=csv.QUOTE_MINIMAL
|
|
)
|
|
current_batch_size = 0
|
|
|
|
for i, text in enumerate(texts):
|
|
doc_uuid = str(uuid.uuid4())
|
|
results.append(doc_uuid)
|
|
|
|
content_writer.writerow([doc_uuid, text, json.dumps(metadatas[i])])
|
|
|
|
for embedding_id, embedding in enumerate(embeddings[i]):
|
|
embeddings_writer.writerow([doc_uuid, embedding_id, embedding])
|
|
|
|
current_batch_size += 1
|
|
|
|
if current_batch_size >= batch_size:
|
|
self._copy_to_db(cursor, content_io, embeddings_io)
|
|
|
|
content_io.seek(0)
|
|
content_io.truncate(0)
|
|
embeddings_io.seek(0)
|
|
embeddings_io.truncate(0)
|
|
current_batch_size = 0
|
|
|
|
if current_batch_size > 0:
|
|
self._copy_to_db(cursor, content_io, embeddings_io)
|
|
|
|
if index_params.index_type == Yellowbrick.IndexType.LSH:
|
|
self._update_index(index_params, uuid.UUID(doc_uuid))
|
|
|
|
return results
|
|
|
|
def _copy_to_db(
|
|
self, cursor: "PgCursor", content_io: StringIO, embeddings_io: StringIO
|
|
) -> None:
|
|
content_io.seek(0)
|
|
embeddings_io.seek(0)
|
|
|
|
from psycopg2 import sql
|
|
|
|
schema_prefix = (self._schema,) if self._schema else ()
|
|
table = sql.Identifier(*schema_prefix, self._table + self.CONTENT_TABLE)
|
|
content_copy_query = sql.SQL(
|
|
"""
|
|
COPY {table} (doc_id, text, metadata) FROM
|
|
STDIN WITH (FORMAT CSV, DELIMITER E'\\t', QUOTE '\"')
|
|
"""
|
|
).format(table=table)
|
|
cursor.copy_expert(content_copy_query, content_io)
|
|
|
|
schema_prefix = (self._schema,) if self._schema else ()
|
|
table = sql.Identifier(*schema_prefix, self._table)
|
|
embeddings_copy_query = sql.SQL(
|
|
"""
|
|
COPY {table} (doc_id, embedding_id, embedding) FROM
|
|
STDIN WITH (FORMAT CSV, DELIMITER E'\\t', QUOTE '\"')
|
|
"""
|
|
).format(table=table)
|
|
cursor.copy_expert(embeddings_copy_query, embeddings_io)
|
|
|
|
@classmethod
|
|
def from_texts(
|
|
cls: Type[Yellowbrick],
|
|
texts: List[str],
|
|
embedding: Embeddings,
|
|
metadatas: Optional[List[dict]] = None,
|
|
connection_string: str = "",
|
|
table: str = "langchain",
|
|
schema: str = "public",
|
|
drop: bool = False,
|
|
**kwargs: Any,
|
|
) -> Yellowbrick:
|
|
"""Add texts to the vectorstore index.
|
|
Args:
|
|
texts: Iterable of strings to add to the vectorstore.
|
|
metadatas: Optional list of metadatas associated with the texts.
|
|
connection_string: URI to Yellowbrick instance
|
|
embedding: Embedding function
|
|
table: table to store embeddings
|
|
kwargs: vectorstore specific parameters
|
|
"""
|
|
vss = cls(
|
|
embedding=embedding,
|
|
connection_string=connection_string,
|
|
table=table,
|
|
schema=schema,
|
|
drop=drop,
|
|
)
|
|
vss.add_texts(texts=texts, metadatas=metadatas, **kwargs)
|
|
return vss
|
|
|
|
def delete(
|
|
self,
|
|
ids: Optional[List[str]] = None,
|
|
delete_all: Optional[bool] = None,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
"""Delete vectors by uuids.
|
|
|
|
Args:
|
|
ids: List of ids to delete, where each id is a uuid string.
|
|
"""
|
|
from psycopg2 import sql
|
|
|
|
if delete_all:
|
|
where_sql = sql.SQL(
|
|
"""
|
|
WHERE 1=1
|
|
"""
|
|
)
|
|
elif ids is not None:
|
|
uuids = tuple(sql.Literal(id) for id in ids)
|
|
ids_formatted = sql.SQL(", ").join(uuids)
|
|
where_sql = sql.SQL(
|
|
"""
|
|
WHERE doc_id IN ({ids})
|
|
"""
|
|
).format(
|
|
ids=ids_formatted,
|
|
)
|
|
else:
|
|
raise ValueError("Either ids or delete_all must be provided.")
|
|
|
|
schema_prefix = (self._schema,) if self._schema else ()
|
|
with self.connection.get_cursor() as cursor:
|
|
table_identifier = sql.Identifier(
|
|
*schema_prefix, self._table + self.CONTENT_TABLE
|
|
)
|
|
query = sql.SQL("DELETE FROM {table} {where_sql}").format(
|
|
table=table_identifier, where_sql=where_sql
|
|
)
|
|
cursor.execute(query)
|
|
|
|
table_identifier = sql.Identifier(*schema_prefix, self._table)
|
|
query = sql.SQL("DELETE FROM {table} {where_sql}").format(
|
|
table=table_identifier, where_sql=where_sql
|
|
)
|
|
cursor.execute(query)
|
|
|
|
if self._table_exists(
|
|
cursor, self._table + self.LSH_INDEX_TABLE, *schema_prefix
|
|
):
|
|
table_identifier = sql.Identifier(
|
|
*schema_prefix, self._table + self.LSH_INDEX_TABLE
|
|
)
|
|
query = sql.SQL("DELETE FROM {table} {where_sql}").format(
|
|
table=table_identifier, where_sql=where_sql
|
|
)
|
|
cursor.execute(query)
|
|
|
|
return None
|
|
|
|
def _table_exists(
|
|
self, cursor: "PgCursor", table_name: str, schema: str = "public"
|
|
) -> bool:
|
|
"""
|
|
Checks if a table exists in the given schema
|
|
"""
|
|
from psycopg2 import sql
|
|
|
|
schema = sql.Literal(schema)
|
|
table_name = sql.Literal(table_name)
|
|
cursor.execute(
|
|
sql.SQL(
|
|
"""
|
|
SELECT COUNT(*)
|
|
FROM sys.table t INNER JOIN sys.schema s ON t.schema_id = s.schema_id
|
|
WHERE s.name = {schema} AND t.name = {table_name}
|
|
"""
|
|
).format(
|
|
schema=schema,
|
|
table_name=table_name,
|
|
)
|
|
)
|
|
return cursor.fetchone()[0] > 0
|
|
|
|
def _generate_vector_uuid(self, vector: List[float]) -> uuid.UUID:
|
|
import hashlib
|
|
|
|
vector_str = ",".join(map(str, vector))
|
|
hash_object = hashlib.sha1(vector_str.encode())
|
|
hash_digest = hash_object.digest()
|
|
vector_uuid = uuid.UUID(bytes=hash_digest[:16])
|
|
return vector_uuid
|
|
|
|
def similarity_search_with_score_by_vector(
|
|
self, embedding: List[float], k: int = 4, **kwargs: Any
|
|
) -> List[Tuple[Document, float]]:
|
|
"""Perform a similarity search with Yellowbrick with vector
|
|
|
|
Args:
|
|
embedding (List[float]): query embedding
|
|
k (int, optional): Top K neighbors to retrieve. Defaults to 4.
|
|
|
|
NOTE: Please do not let end-user fill this and always be aware
|
|
of SQL injection.
|
|
|
|
Returns:
|
|
List[Document, float]: List of Documents and scores
|
|
"""
|
|
from psycopg2 import sql
|
|
from psycopg2.extras import execute_values
|
|
|
|
index_params = kwargs.get("index_params") or Yellowbrick.IndexParams()
|
|
|
|
with self.connection.get_cursor() as cursor:
|
|
tmp_embeddings_table = "tmp_" + self._table
|
|
tmp_doc_id = self._generate_vector_uuid(embedding)
|
|
create_table_query = sql.SQL(
|
|
"""
|
|
CREATE TEMPORARY TABLE {} (
|
|
doc_id UUID,
|
|
embedding_id SMALLINT,
|
|
embedding FLOAT)
|
|
ON COMMIT DROP
|
|
DISTRIBUTE REPLICATE
|
|
"""
|
|
).format(sql.Identifier(tmp_embeddings_table))
|
|
cursor.execute(create_table_query)
|
|
data_input = [
|
|
(str(tmp_doc_id), embedding_id, embedding_value)
|
|
for embedding_id, embedding_value in enumerate(embedding)
|
|
]
|
|
insert_query = sql.SQL(
|
|
"INSERT INTO {} (doc_id, embedding_id, embedding) VALUES %s"
|
|
).format(sql.Identifier(tmp_embeddings_table))
|
|
execute_values(cursor, insert_query, data_input)
|
|
|
|
v1 = sql.Identifier(tmp_embeddings_table)
|
|
schema_prefix = (self._schema,) if self._schema else ()
|
|
v2 = sql.Identifier(*schema_prefix, self._table)
|
|
v3 = sql.Identifier(*schema_prefix, self._table + self.CONTENT_TABLE)
|
|
if index_params.index_type == Yellowbrick.IndexType.LSH:
|
|
tmp_hash_table = self._table + "_tmp_hash"
|
|
self._generate_tmp_lsh_hashes(
|
|
cursor,
|
|
tmp_embeddings_table,
|
|
tmp_hash_table,
|
|
)
|
|
|
|
schema_prefix = (self._schema,) if self._schema else ()
|
|
lsh_index = sql.Identifier(
|
|
*schema_prefix, self._table + self.LSH_INDEX_TABLE
|
|
)
|
|
input_hash_table = sql.Identifier(tmp_hash_table)
|
|
sql_query = sql.SQL(
|
|
"""
|
|
WITH index_docs AS (
|
|
SELECT
|
|
t1.doc_id,
|
|
SUM(ABS(t1.hash-t2.hash)) as hamming_distance
|
|
FROM
|
|
{lsh_index} t1
|
|
INNER JOIN
|
|
{input_hash_table} t2
|
|
ON t1.hash_index = t2.hash_index
|
|
GROUP BY t1.doc_id
|
|
HAVING hamming_distance <= {hamming_distance}
|
|
)
|
|
SELECT
|
|
text,
|
|
metadata,
|
|
SUM(v1.embedding * v2.embedding) /
|
|
(SQRT(SUM(v1.embedding * v1.embedding)) *
|
|
SQRT(SUM(v2.embedding * v2.embedding))) AS score
|
|
FROM
|
|
{v1} v1
|
|
INNER JOIN
|
|
{v2} v2
|
|
ON v1.embedding_id = v2.embedding_id
|
|
INNER JOIN
|
|
{v3} v3
|
|
ON v2.doc_id = v3.doc_id
|
|
INNER JOIN
|
|
index_docs v4
|
|
ON v2.doc_id = v4.doc_id
|
|
GROUP BY v3.doc_id, v3.text, v3.metadata
|
|
ORDER BY score DESC
|
|
LIMIT %s
|
|
"""
|
|
).format(
|
|
v1=v1,
|
|
v2=v2,
|
|
v3=v3,
|
|
lsh_index=lsh_index,
|
|
input_hash_table=input_hash_table,
|
|
hamming_distance=sql.Literal(
|
|
index_params.get_param("hamming_distance", 0)
|
|
),
|
|
)
|
|
cursor.execute(
|
|
sql_query,
|
|
(k,),
|
|
)
|
|
results = cursor.fetchall()
|
|
else:
|
|
sql_query = sql.SQL(
|
|
"""
|
|
SELECT
|
|
text,
|
|
metadata,
|
|
score
|
|
FROM
|
|
(SELECT
|
|
v2.doc_id doc_id,
|
|
SUM(v1.embedding * v2.embedding) /
|
|
(SQRT(SUM(v1.embedding * v1.embedding)) *
|
|
SQRT(SUM(v2.embedding * v2.embedding))) AS score
|
|
FROM
|
|
{v1} v1
|
|
INNER JOIN
|
|
{v2} v2
|
|
ON v1.embedding_id = v2.embedding_id
|
|
GROUP BY v2.doc_id
|
|
ORDER BY score DESC LIMIT %s
|
|
) v4
|
|
INNER JOIN
|
|
{v3} v3
|
|
ON v4.doc_id = v3.doc_id
|
|
ORDER BY score DESC
|
|
"""
|
|
).format(
|
|
v1=v1,
|
|
v2=v2,
|
|
v3=v3,
|
|
)
|
|
cursor.execute(sql_query, (k,))
|
|
results = cursor.fetchall()
|
|
|
|
documents: List[Tuple[Document, float]] = []
|
|
for result in results:
|
|
metadata = json.loads(result[1]) or {}
|
|
doc = Document(page_content=result[0], metadata=metadata)
|
|
documents.append((doc, result[2]))
|
|
|
|
return documents
|
|
|
|
def similarity_search(
|
|
self, query: str, k: int = 4, **kwargs: Any
|
|
) -> List[Document]:
|
|
"""Perform a similarity search with Yellowbrick
|
|
|
|
Args:
|
|
query (str): query string
|
|
k (int, optional): Top K neighbors to retrieve. Defaults to 4.
|
|
|
|
NOTE: Please do not let end-user fill this and always be aware
|
|
of SQL injection.
|
|
|
|
Returns:
|
|
List[Document]: List of Documents
|
|
"""
|
|
embedding = self._embedding.embed_query(query)
|
|
documents = self.similarity_search_with_score_by_vector(
|
|
embedding=embedding, k=k, **kwargs
|
|
)
|
|
return [doc for doc, _ in documents]
|
|
|
|
def similarity_search_with_score(
|
|
self, query: str, k: int = 4, **kwargs: Any
|
|
) -> List[Tuple[Document, float]]:
|
|
"""Perform a similarity search with Yellowbrick
|
|
|
|
Args:
|
|
query (str): query string
|
|
k (int, optional): Top K neighbors to retrieve. Defaults to 4.
|
|
|
|
NOTE: Please do not let end-user fill this and always be aware
|
|
of SQL injection.
|
|
|
|
Returns:
|
|
List[Document]: List of (Document, similarity)
|
|
"""
|
|
embedding = self._embedding.embed_query(query)
|
|
documents = self.similarity_search_with_score_by_vector(
|
|
embedding=embedding, k=k, **kwargs
|
|
)
|
|
return documents
|
|
|
|
def similarity_search_by_vector(
|
|
self, embedding: List[float], k: int = 4, **kwargs: Any
|
|
) -> List[Document]:
|
|
"""Perform a similarity search with Yellowbrick by vectors
|
|
|
|
Args:
|
|
embedding (List[float]): query embedding
|
|
k (int, optional): Top K neighbors to retrieve. Defaults to 4.
|
|
|
|
NOTE: Please do not let end-user fill this and always be aware
|
|
of SQL injection.
|
|
|
|
Returns:
|
|
List[Document]: List of documents
|
|
"""
|
|
documents = self.similarity_search_with_score_by_vector(
|
|
embedding=embedding, k=k, **kwargs
|
|
)
|
|
return [doc for doc, _ in documents]
|
|
|
|
def _update_lsh_hashes(
|
|
self,
|
|
cursor: "PgCursor",
|
|
doc_id: Optional[uuid.UUID] = None,
|
|
) -> None:
|
|
"""Add hashes to LSH index"""
|
|
from psycopg2 import sql
|
|
|
|
schema_prefix = (self._schema,) if self._schema else ()
|
|
lsh_hyperplane_table = sql.Identifier(
|
|
*schema_prefix, self._table + self.LSH_HYPERPLANE_TABLE
|
|
)
|
|
lsh_index_table_id = sql.Identifier(
|
|
*schema_prefix, self._table + self.LSH_INDEX_TABLE
|
|
)
|
|
embedding_table_id = sql.Identifier(*schema_prefix, self._table)
|
|
query_prefix_id = sql.SQL("INSERT INTO {}").format(lsh_index_table_id)
|
|
condition = (
|
|
sql.SQL("WHERE e.doc_id = {doc_id}").format(doc_id=sql.Literal(str(doc_id)))
|
|
if doc_id
|
|
else sql.SQL("")
|
|
)
|
|
group_by = sql.SQL("GROUP BY 1, 2")
|
|
|
|
input_query = sql.SQL(
|
|
"""
|
|
{query_prefix}
|
|
SELECT
|
|
e.doc_id as doc_id,
|
|
h.id as hash_index,
|
|
CASE WHEN SUM(e.embedding * h.hyperplane) > 0 THEN 1 ELSE 0 END as hash
|
|
FROM {embedding_table} e
|
|
INNER JOIN {hyperplanes} h ON e.embedding_id = h.hyperplane_id
|
|
{condition}
|
|
{group_by}
|
|
"""
|
|
).format(
|
|
query_prefix=query_prefix_id,
|
|
embedding_table=embedding_table_id,
|
|
hyperplanes=lsh_hyperplane_table,
|
|
condition=condition,
|
|
group_by=group_by,
|
|
)
|
|
cursor.execute(input_query)
|
|
|
|
def _generate_tmp_lsh_hashes(
|
|
self, cursor: "PgCursor", tmp_embedding_table: str, tmp_hash_table: str
|
|
) -> None:
|
|
"""Generate temp LSH"""
|
|
from psycopg2 import sql
|
|
|
|
schema_prefix = (self._schema,) if self._schema else ()
|
|
lsh_hyperplane_table = sql.Identifier(
|
|
*schema_prefix, self._table + self.LSH_HYPERPLANE_TABLE
|
|
)
|
|
tmp_embedding_table_id = sql.Identifier(tmp_embedding_table)
|
|
tmp_hash_table_id = sql.Identifier(tmp_hash_table)
|
|
query_prefix = sql.SQL("CREATE TEMPORARY TABLE {} ON COMMIT DROP AS").format(
|
|
tmp_hash_table_id
|
|
)
|
|
group_by = sql.SQL("GROUP BY 1")
|
|
|
|
input_query = sql.SQL(
|
|
"""
|
|
{query_prefix}
|
|
SELECT
|
|
h.id as hash_index,
|
|
CASE WHEN SUM(e.embedding * h.hyperplane) > 0 THEN 1 ELSE 0 END as hash
|
|
FROM {embedding_table} e
|
|
INNER JOIN {hyperplanes} h ON e.embedding_id = h.hyperplane_id
|
|
{group_by}
|
|
DISTRIBUTE REPLICATE
|
|
"""
|
|
).format(
|
|
query_prefix=query_prefix,
|
|
embedding_table=tmp_embedding_table_id,
|
|
hyperplanes=lsh_hyperplane_table,
|
|
group_by=group_by,
|
|
)
|
|
cursor.execute(input_query)
|
|
|
|
def _populate_hyperplanes(self, cursor: "PgCursor", num_hyperplanes: int) -> None:
|
|
"""Generate random hyperplanes and store in Yellowbrick"""
|
|
from psycopg2 import sql
|
|
|
|
schema_prefix = (self._schema,) if self._schema else ()
|
|
hyperplanes_table = sql.Identifier(
|
|
*schema_prefix, self._table + self.LSH_HYPERPLANE_TABLE
|
|
)
|
|
cursor.execute(sql.SQL("SELECT COUNT(*) FROM {t}").format(t=hyperplanes_table))
|
|
if cursor.fetchone()[0] > 0:
|
|
return
|
|
|
|
t = sql.Identifier(*schema_prefix, self._table)
|
|
cursor.execute(sql.SQL("SELECT MAX(embedding_id) FROM {t}").format(t=t))
|
|
num_dimensions = cursor.fetchone()[0]
|
|
num_dimensions += 1
|
|
|
|
insert_query = sql.SQL(
|
|
"""
|
|
WITH parameters AS (
|
|
SELECT {num_hyperplanes} AS num_hyperplanes,
|
|
{dims_per_hyperplane} AS dims_per_hyperplane
|
|
)
|
|
INSERT INTO {hyperplanes_table} (id, hyperplane_id, hyperplane)
|
|
SELECT id, hyperplane_id, (random() * 2 - 1) AS hyperplane
|
|
FROM
|
|
(SELECT range-1 id FROM sys.rowgenerator
|
|
WHERE range BETWEEN 1 AND
|
|
(SELECT num_hyperplanes FROM parameters) AND
|
|
worker_lid = 0 AND thread_id = 0) a,
|
|
(SELECT range-1 hyperplane_id FROM sys.rowgenerator
|
|
WHERE range BETWEEN 1 AND
|
|
(SELECT dims_per_hyperplane FROM parameters) AND
|
|
worker_lid = 0 AND thread_id = 0) b
|
|
"""
|
|
).format(
|
|
num_hyperplanes=sql.Literal(num_hyperplanes),
|
|
dims_per_hyperplane=sql.Literal(num_dimensions),
|
|
hyperplanes_table=hyperplanes_table,
|
|
)
|
|
cursor.execute(insert_query)
|
|
|
|
def _create_lsh_index_tables(self, cursor: "PgCursor") -> None:
|
|
"""Create LSH index and hyperplane tables"""
|
|
from psycopg2 import sql
|
|
|
|
schema_prefix = (self._schema,) if self._schema else ()
|
|
t1 = sql.Identifier(*schema_prefix, self._table + self.LSH_INDEX_TABLE)
|
|
t2 = sql.Identifier(*schema_prefix, self._table + self.CONTENT_TABLE)
|
|
c1 = sql.Identifier(self._table + self.LSH_INDEX_TABLE + "_pk_doc_id")
|
|
c2 = sql.Identifier(self._table + self.LSH_INDEX_TABLE + "_fk_doc_id")
|
|
cursor.execute(
|
|
sql.SQL(
|
|
"""
|
|
CREATE TABLE IF NOT EXISTS {t1} (
|
|
doc_id UUID NOT NULL,
|
|
hash_index SMALLINT NOT NULL,
|
|
hash SMALLINT NOT NULL,
|
|
CONSTRAINT {c1} PRIMARY KEY (doc_id, hash_index),
|
|
CONSTRAINT {c2} FOREIGN KEY (doc_id) REFERENCES {t2}(doc_id))
|
|
DISTRIBUTE ON (doc_id) SORT ON (doc_id)
|
|
"""
|
|
).format(
|
|
t1=t1,
|
|
t2=t2,
|
|
c1=c1,
|
|
c2=c2,
|
|
)
|
|
)
|
|
|
|
schema_prefix = (self._schema,) if self._schema else ()
|
|
t = sql.Identifier(*schema_prefix, self._table + self.LSH_HYPERPLANE_TABLE)
|
|
c = sql.Identifier(self._table + self.LSH_HYPERPLANE_TABLE + "_pk_id_hp_id")
|
|
cursor.execute(
|
|
sql.SQL(
|
|
"""
|
|
CREATE TABLE IF NOT EXISTS {t} (
|
|
id SMALLINT NOT NULL,
|
|
hyperplane_id SMALLINT NOT NULL,
|
|
hyperplane FLOAT NOT NULL,
|
|
CONSTRAINT {c} PRIMARY KEY (id, hyperplane_id))
|
|
DISTRIBUTE REPLICATE SORT ON (id)
|
|
"""
|
|
).format(
|
|
t=t,
|
|
c=c,
|
|
)
|
|
)
|
|
|
|
def _drop_lsh_index_tables(self, cursor: "PgCursor") -> None:
|
|
"""Drop LSH index tables"""
|
|
self.drop(
|
|
schema=self._schema, table=self._table + self.LSH_INDEX_TABLE, cursor=cursor
|
|
)
|
|
self.drop(
|
|
schema=self._schema,
|
|
table=self._table + self.LSH_HYPERPLANE_TABLE,
|
|
cursor=cursor,
|
|
)
|
|
|
|
def create_index(self, index_params: Yellowbrick.IndexParams) -> None:
|
|
"""Create index from existing vectors"""
|
|
if index_params.index_type == Yellowbrick.IndexType.LSH:
|
|
with self.connection.get_cursor() as cursor:
|
|
self._drop_lsh_index_tables(cursor)
|
|
self._create_lsh_index_tables(cursor)
|
|
self._populate_hyperplanes(
|
|
cursor, index_params.get_param("num_hyperplanes", 128)
|
|
)
|
|
self._update_lsh_hashes(cursor)
|
|
|
|
def drop_index(self, index_params: Yellowbrick.IndexParams) -> None:
|
|
"""Drop an index"""
|
|
if index_params.index_type == Yellowbrick.IndexType.LSH:
|
|
with self.connection.get_cursor() as cursor:
|
|
self._drop_lsh_index_tables(cursor)
|
|
|
|
def _update_index(
|
|
self, index_params: Yellowbrick.IndexParams, doc_id: uuid.UUID
|
|
) -> None:
|
|
"""Update an index with a new or modified embedding in the embeddings table"""
|
|
if index_params.index_type == Yellowbrick.IndexType.LSH:
|
|
with self.connection.get_cursor() as cursor:
|
|
self._update_lsh_hashes(cursor, doc_id)
|
|
|
|
def migrate_schema_v1_to_v2(self) -> None:
|
|
from psycopg2 import sql
|
|
|
|
try:
|
|
with self.connection.get_cursor() as cursor:
|
|
schema_prefix = (self._schema,) if self._schema else ()
|
|
embeddings = sql.Identifier(*schema_prefix, self._table)
|
|
old_embeddings = sql.Identifier(*schema_prefix, self._table + "_v1")
|
|
content = sql.Identifier(
|
|
*schema_prefix, self._table + self.CONTENT_TABLE
|
|
)
|
|
alter_table_query = sql.SQL("ALTER TABLE {t1} RENAME TO {t2}").format(
|
|
t1=embeddings,
|
|
t2=old_embeddings,
|
|
)
|
|
cursor.execute(alter_table_query)
|
|
|
|
self._create_table(cursor)
|
|
|
|
insert_query = sql.SQL(
|
|
"""
|
|
INSERT INTO {t1} (doc_id, embedding_id, embedding)
|
|
SELECT id, embedding_id, embedding FROM {t2}
|
|
"""
|
|
).format(
|
|
t1=embeddings,
|
|
t2=old_embeddings,
|
|
)
|
|
cursor.execute(insert_query)
|
|
|
|
insert_content_query = sql.SQL(
|
|
"""
|
|
INSERT INTO {t1} (doc_id, text, metadata)
|
|
SELECT DISTINCT id, text, metadata FROM {t2}
|
|
"""
|
|
).format(t1=content, t2=old_embeddings)
|
|
cursor.execute(insert_content_query)
|
|
except Exception as e:
|
|
raise RuntimeError(f"Failed to migrate schema: {e}") from e
|