mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
32ec56194b
Currently, the SQL used to delete vector docs from MyScale is as follows: ```sql DELETE FROM collection WHERE id = '1' AND id = '2' AND id = '3' ``` but the expected query is ```sql DELETE FROM collection WHERE id IN ('1', '2', '3') ```
616 lines
22 KiB
Python
616 lines
22 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from hashlib import sha1
|
|
from threading import Thread
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|
|
|
from langchain_core.documents import Document
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.pydantic_v1 import BaseSettings
|
|
from langchain_core.vectorstores import VectorStore
|
|
|
|
# Module-level logger: records propagate to the root logger's handlers,
# but can be filtered/levelled independently via this module's name.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def has_mul_sub_str(s: str, *args: Any) -> bool:
    """Tell whether every one of several substrings occurs in a string.

    Args:
        s: the string to search in.
        *args: the substrings that must all be present in ``s``.

    Returns:
        True when each substring in ``args`` is contained in ``s`` (also when
        no substrings are given), False as soon as one of them is missing.
    """
    return all(needle in s for needle in args)
|
|
|
|
|
|
class MyScaleSettings(BaseSettings):
    """MyScale client configuration.

    Besides passing values to the constructor, fields can be supplied through
    environment variables carrying the ``myscale_`` prefix (for example
    ``MYSCALE_HOST``) or through a ``.env`` file; see the inner ``Config``.

    Attributes:
        host (str) : An URL to connect to MyScale backend.
            Defaults to 'localhost'.
        port (int) : URL port to connect with HTTP. Defaults to 8443.
        username (str) : Username to login. Defaults to None.
        password (str) : Password to login. Defaults to None.
        index_type (str): index type string. Defaults to 'MSTG'.
        index_param (dict): index build parameter. Defaults to None.
        database (str) : Database name to find the table. Defaults to 'default'.
        table (str) : Table name to operate on.
            Defaults to 'langchain'.
        metric (str) : Metric to compute distance,
            supported are ('L2', 'Cosine', 'IP'). Defaults to 'Cosine'.
        column_map (Dict) : Column type map to project column name onto langchain
            semantics. Must have keys: `text`, `id`, `vector`, `metadata`.
            For example:

            .. code-block:: python

                {
                    'id': 'text_id',
                    'vector': 'text_embedding',
                    'text': 'text_plain',
                    'metadata': 'metadata_dictionary_in_json',
                }

            Defaults to identity map.
    """

    host: str = "localhost"
    port: int = 8443

    username: Optional[str] = None
    password: Optional[str] = None

    index_type: str = "MSTG"
    index_param: Optional[Dict[str, str]] = None

    column_map: Dict[str, str] = {
        "id": "id",
        "text": "text",
        "vector": "vector",
        "metadata": "metadata",
    }

    database: str = "default"
    table: str = "langchain"
    metric: str = "Cosine"

    def __getitem__(self, item: str) -> Any:
        # Allow dict-style access, e.g. settings["table"].
        return getattr(self, item)

    class Config:
        env_file = ".env"
        env_prefix = "myscale_"
        env_file_encoding = "utf-8"
|
|
|
|
|
|
class MyScale(VectorStore):
    """`MyScale` vector store.

    You need a `clickhouse-connect` python package, and a valid account
    to connect to MyScale.

    MyScale can not only search with simple vector indexes.
    It also supports a complex query with multiple conditions,
    constraints and even sub-queries.

    For more information, please visit
        [myscale official site](https://docs.myscale.com/en/overview/)
    """

    def __init__(
        self,
        embedding: Embeddings,
        config: Optional[MyScaleSettings] = None,
        **kwargs: Any,
    ) -> None:
        """MyScale Wrapper to LangChain

        Connects to MyScale and creates the target table if it does not exist
        yet, sizing the vector column from one probe embedding.

        Args:
            embedding (Embeddings): Embedding model used for texts and queries.
            config (MyScaleSettings): Configuration to MyScale Client.
                Defaults to ``MyScaleSettings()``.
            Other keyword arguments will pass into
                [clickhouse-connect](https://docs.myscale.com/)
        """
        try:
            from clickhouse_connect import get_client
        except ImportError:
            raise ImportError(
                "Could not import clickhouse connect python package. "
                "Please install it with `pip install clickhouse-connect`."
            )
        try:
            from tqdm import tqdm

            self.pgbar = tqdm
        except ImportError:
            # Just in case if tqdm is not installed. The fallback must accept
            # the keyword arguments (`desc`, `total`) that `add_texts` passes,
            # otherwise every insertion fails with a TypeError.
            self.pgbar = lambda x, *args, **kwargs: x
        super().__init__()
        if config is not None:
            self.config = config
        else:
            self.config = MyScaleSettings()
        assert self.config
        assert self.config.host and self.config.port
        assert (
            self.config.column_map
            and self.config.database
            and self.config.table
            and self.config.metric
        )
        for k in ["id", "vector", "text", "metadata"]:
            assert k in self.config.column_map
        assert self.config.metric.upper() in ["IP", "COSINE", "L2"]
        if self.config.metric in ["ip", "cosine", "l2"]:
            logger.warning(
                "Lower case metric types will be deprecated in "
                "the future. Please use one of ('IP', 'Cosine', 'L2')"
            )

        # initialize the schema
        # The table constrains the vector length, so probe the embedding
        # model once to learn its output dimension.
        dim = len(embedding.embed_query("try this out"))

        index_params = (
            ", " + ",".join([f"'{k}={v}'" for k, v in self.config.index_param.items()])
            if self.config.index_param
            else ""
        )
        schema_ = f"""
            CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
                {self.config.column_map['id']} String,
                {self.config.column_map['text']} String,
                {self.config.column_map['vector']} Array(Float32),
                {self.config.column_map['metadata']} JSON,
                CONSTRAINT cons_vec_len CHECK length(\
                    {self.config.column_map['vector']}) = {dim},
                VECTOR INDEX vidx {self.config.column_map['vector']} \
                    TYPE {self.config.index_type}(\
                        'metric_type={self.config.metric}'{index_params})
            ) ENGINE = MergeTree ORDER BY {self.config.column_map['id']}
        """
        self.dim = dim
        self.BS = "\\"
        self.must_escape = ("\\", "'")
        self._embeddings = embedding
        # Results are ordered ascending for Cosine/L2 and descending for IP,
        # i.e. under IP a larger value ranks a document higher.
        self.dist_order = (
            "ASC" if self.config.metric.upper() in ["COSINE", "L2"] else "DESC"
        )

        # Create a connection to myscale
        self.client = get_client(
            host=self.config.host,
            port=self.config.port,
            username=self.config.username,
            password=self.config.password,
            **kwargs,
        )
        self.client.command("SET allow_experimental_object_type=1")
        self.client.command(schema_)

    @property
    def embeddings(self) -> Embeddings:
        return self._embeddings

    def escape_str(self, value: str) -> str:
        """Backslash-escape every backslash and single quote in ``value``."""
        return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)

    def _build_istr(self, transac: Iterable, column_names: Iterable[str]) -> str:
        """Render an INSERT statement for ``transac`` rows.

        Every value is stringified, escaped and single-quoted.
        """
        ks = ",".join(column_names)
        rows = []
        for row in transac:
            values = ",".join([f"'{self.escape_str(str(v))}'" for v in row])
            rows.append(f"({values})")
        i_str = f"""
                INSERT INTO TABLE
                    {self.config.database}.{self.config.table}({ks})
                VALUES
                {','.join(rows)}
                """
        return i_str

    def _insert(self, transac: Iterable, column_names: Iterable[str]) -> None:
        _i_str = self._build_istr(transac, column_names)
        self.client.command(_i_str)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        batch_size: int = 32,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional metadata dicts (JSON-serialized) per text.
            batch_size: Batch size of insertion
            ids: Optional list of ids to associate with the texts. Defaults
                to the SHA-1 hex digest of each text.

        Returns:
            List of ids from adding the texts into the vectorstore, or an
            empty list if anything failed (the error is logged).
        """
        from concurrent.futures import Future, ThreadPoolExecutor

        # Materialize once: texts (and ids) are read several times below
        # (hashing, embedding, row building), which would silently exhaust
        # a generator after its first pass.
        texts = list(texts)
        ids = (
            list(ids) if ids else [sha1(t.encode("utf-8")).hexdigest() for t in texts]
        )
        colmap_ = self.config.column_map
        metadatas = metadatas or [{} for _ in texts]

        column_names = {
            colmap_["id"]: ids,
            colmap_["text"]: texts,
            colmap_["vector"]: map(self._embeddings.embed_query, texts),
            colmap_["metadata"]: map(json.dumps, metadatas),
        }
        keys, values = zip(*column_names.items())
        vector_pos = keys.index(colmap_["vector"])

        transac: List[tuple] = []
        try:
            # One worker lets the next batch be embedded while the previous
            # one is inserted; `result()` re-raises insert errors instead of
            # losing them inside a background thread.
            with ThreadPoolExecutor(max_workers=1) as pool:
                pending: Optional[Future] = None
                for v in self.pgbar(
                    zip(*values), desc="Inserting data...", total=len(metadatas)
                ):
                    assert len(v[vector_pos]) == self.dim
                    transac.append(v)
                    if len(transac) == batch_size:
                        if pending is not None:
                            pending.result()
                        pending = pool.submit(self._insert, transac, keys)
                        transac = []
                # Await the last background batch even when the row count is
                # an exact multiple of batch_size.
                if pending is not None:
                    pending.result()
            if transac:
                self._insert(transac, keys)
            return ids
        except Exception as e:
            logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        embedding: Embeddings,
        metadatas: Optional[List[Dict[Any, Any]]] = None,
        config: Optional[MyScaleSettings] = None,
        text_ids: Optional[Iterable[str]] = None,
        batch_size: int = 32,
        **kwargs: Any,
    ) -> MyScale:
        """Create Myscale wrapper with existing texts

        Args:
            texts (Iterable[str]): List or tuple of strings to be added
            embedding (Embeddings): Function to extract text embedding
            metadatas (List[dict], optional): metadata to texts. Defaults to None.
            config (MyScaleSettings, Optional): Myscale configuration
            text_ids (Optional[Iterable], optional): IDs for the texts.
                Defaults to None.
            batch_size (int, optional): Batchsize when transmitting data to MyScale.
                Defaults to 32.
            Other keyword arguments will pass into
                [clickhouse-connect](https://clickhouse.com/docs/en/integrations/python#clickhouse-connect-driver-api)
        Returns:
            MyScale Index
        """
        ctx = cls(embedding, config, **kwargs)
        ctx.add_texts(texts, ids=text_ids, batch_size=batch_size, metadatas=metadatas)
        return ctx

    def __repr__(self) -> str:
        """Text representation for myscale, prints backends, username and schemas.
            Easy to use with `str(Myscale())`

        Returns:
            repr: string to show connection info and data schema
        """
        _repr = f"\033[92m\033[1m{self.config.database}.{self.config.table} @ "
        _repr += f"{self.config.host}:{self.config.port}\033[0m\n\n"
        _repr += f"\033[1musername: {self.config.username}\033[0m\n\nTable Schema:\n"
        _repr += "-" * 51 + "\n"
        for r in self.client.query(
            f"DESC {self.config.database}.{self.config.table}"
        ).named_results():
            _repr += (
                f"|\033[94m{r['name']:24s}\033[0m|\033[96m{r['type']:24s}\033[0m|\n"
            )
        _repr += "-" * 51 + "\n"
        return _repr

    def _build_qstr(
        self, q_emb: List[float], topk: int, where_str: Optional[str] = None
    ) -> str:
        """Render the top-k nearest-neighbour SELECT for ``q_emb``."""
        q_emb_str = ",".join(map(str, q_emb))
        if where_str:
            where_str = f"PREWHERE {where_str}"
        else:
            where_str = ""

        q_str = f"""
            SELECT {self.config.column_map['text']},
                {self.config.column_map['metadata']}, dist
            FROM {self.config.database}.{self.config.table}
            {where_str}
            ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
                AS dist {self.dist_order}
            LIMIT {topk}
            """
        return q_str

    def similarity_search(
        self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
    ) -> List[Document]:
        """Perform a similarity search with MyScale

        Args:
            query (str): query string
            k (int, optional): Top K neighbors to retrieve. Defaults to 4.
            where_str (Optional[str], optional): where condition string.
                                                 Defaults to None.

            NOTE: Please do not let end-user to fill this and always be aware
                  of SQL injection. When dealing with metadatas, remember to
                  use `{self.metadata_column}.attribute` instead of `attribute`
                  alone. The default name for it is `metadata`.

        Returns:
            List[Document]: List of Documents
        """
        return self.similarity_search_by_vector(
            self._embeddings.embed_query(query), k, where_str, **kwargs
        )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search with MyScale by vectors

        Args:
            embedding (List[float]): query embedding vector
            k (int, optional): Top K neighbors to retrieve. Defaults to 4.
            where_str (Optional[str], optional): where condition string.
                                                 Defaults to None.

            NOTE: Please do not let end-user to fill this and always be aware
                  of SQL injection. When dealing with metadatas, remember to
                  use `{self.metadata_column}.attribute` instead of `attribute`
                  alone. The default name for it is `metadata`.

        Returns:
            List[Document]: List of Documents; empty if the query failed
            (the error is logged).
        """
        q_str = self._build_qstr(embedding, k, where_str)
        try:
            return [
                Document(
                    page_content=r[self.config.column_map["text"]],
                    metadata=r[self.config.column_map["metadata"]],
                )
                for r in self.client.query(q_str).named_results()
            ]
        except Exception as e:
            logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []

    def similarity_search_with_relevance_scores(
        self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Perform a similarity search with MyScale

        Args:
            query (str): query string
            k (int, optional): Top K neighbors to retrieve. Defaults to 4.
            where_str (Optional[str], optional): where condition string.
                                                 Defaults to None.

            NOTE: Please do not let end-user to fill this and always be aware
                  of SQL injection. When dealing with metadatas, remember to
                  use `{self.metadata_column}.attribute` instead of `attribute`
                  alone. The default name for it is `metadata`.

        Returns:
            List[Tuple[Document, float]]: Documents most similar to the query
            text, each with its raw distance under the configured metric.
            Results are ranked ascending for Cosine/L2 (lower is more
            similar) and descending for IP.
        """
        q_str = self._build_qstr(self._embeddings.embed_query(query), k, where_str)
        try:
            return [
                (
                    Document(
                        page_content=r[self.config.column_map["text"]],
                        metadata=r[self.config.column_map["metadata"]],
                    ),
                    r["dist"],
                )
                for r in self.client.query(q_str).named_results()
            ]
        except Exception as e:
            logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []

    def drop(self) -> None:
        """
        Helper function: Drop data
        """
        self.client.command(
            f"DROP TABLE IF EXISTS {self.config.database}.{self.config.table}"
        )

    def delete(
        self,
        ids: Optional[List[str]] = None,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> Optional[bool]:
        """Delete by vector ID or other criteria.

        Args:
            ids: List of ids to delete.
            where_str: Extra SQL condition; combined with the id filter by AND.
                Never build it from end-user input (SQL injection).
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            Optional[bool]: True if deletion is successful,
            False otherwise, None if not implemented.
        """
        assert not (
            ids is None and where_str is None
        ), "You need to specify where to be deleted! Either with `ids` or `where_str`"
        conds = []
        if ids:
            # Escape ids exactly like inserted values (`_build_istr`), so a
            # quote or backslash inside an id cannot break the statement.
            id_list = ", ".join(f"'{self.escape_str(str(i))}'" for i in ids)
            conds.append(f"{self.config.column_map['id']} IN ({id_list})")
        if where_str:
            conds.append(where_str)
        assert len(conds) > 0
        where_str_final = " AND ".join(conds)
        qstr = (
            f"DELETE FROM {self.config.database}.{self.config.table} "
            f"WHERE {where_str_final}"
        )
        try:
            self.client.command(qstr)
            return True
        except Exception as e:
            logger.error(str(e))
            return False

    @property
    def metadata_column(self) -> str:
        return self.config.column_map["metadata"]
|
|
|
|
|
|
class MyScaleWithoutJSON(MyScale):
    """MyScale vector store without metadata column

    This is super handy if you are working to a SQL-native table
    """

    def __init__(
        self,
        embedding: Embeddings,
        config: Optional[MyScaleSettings] = None,
        must_have_cols: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Building a myscale vector store without metadata column

        Args:
            embedding (Embeddings): embedding model
            config (MyScaleSettings): Configuration to MyScale Client
            must_have_cols (List[str]): column names to be included in query;
                their values become each returned Document's metadata.
                Defaults to no extra columns.
            Other keyword arguments will pass into
                [clickhouse-connect](https://docs.myscale.com/)
        """
        super().__init__(embedding, config, **kwargs)
        # Copy so instances never share a default list or alias the caller's.
        self.must_have_cols: List[str] = list(must_have_cols or [])

    def _build_qstr(
        self, q_emb: List[float], topk: int, where_str: Optional[str] = None
    ) -> str:
        """Render the top-k SELECT returning text, dist and must_have_cols."""
        q_emb_str = ",".join(map(str, q_emb))
        where_clause = f"PREWHERE {where_str}" if where_str else ""
        # Build the select list explicitly so an empty `must_have_cols`
        # does not leave a dangling comma before FROM.
        select_cols = ", ".join(
            [self.config.column_map["text"], "dist", *self.must_have_cols]
        )

        q_str = f"""
            SELECT {select_cols}
            FROM {self.config.database}.{self.config.table}
            {where_clause}
            ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
                AS dist {self.dist_order}
            LIMIT {topk}
            """
        return q_str

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search with MyScale by vectors

        Args:
            embedding (List[float]): query embedding vector
            k (int, optional): Top K neighbors to retrieve. Defaults to 4.
            where_str (Optional[str], optional): where condition string.
                                                 Defaults to None.

            NOTE: Please do not let end-user to fill this and always be aware
                  of SQL injection. This store has no JSON metadata column, so
                  refer to table columns by their plain names.

        Returns:
            List[Document]: List of Documents whose metadata holds the values
            of `must_have_cols`; empty if the query failed (error is logged).
        """
        q_str = self._build_qstr(embedding, k, where_str)
        try:
            return [
                Document(
                    page_content=r[self.config.column_map["text"]],
                    metadata={col: r[col] for col in self.must_have_cols},
                )
                for r in self.client.query(q_str).named_results()
            ]
        except Exception as e:
            logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []

    def similarity_search_with_relevance_scores(
        self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Perform a similarity search with MyScale

        Args:
            query (str): query string
            k (int, optional): Top K neighbors to retrieve. Defaults to 4.
            where_str (Optional[str], optional): where condition string.
                                                 Defaults to None.

            NOTE: Please do not let end-user to fill this and always be aware
                  of SQL injection. This store has no JSON metadata column, so
                  refer to table columns by their plain names.

        Returns:
            List[Tuple[Document, float]]: Documents most similar to the query
            text, each with its raw distance under the configured metric.
            Results are ranked ascending for Cosine/L2 (lower is more
            similar) and descending for IP.
        """
        q_str = self._build_qstr(self._embeddings.embed_query(query), k, where_str)
        try:
            return [
                (
                    Document(
                        page_content=r[self.config.column_map["text"]],
                        metadata={col: r[col] for col in self.must_have_cols},
                    ),
                    r["dist"],
                )
                for r in self.client.query(q_str).named_results()
            ]
        except Exception as e:
            logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []

    @property
    def metadata_column(self) -> str:
        # No JSON metadata column in this variant.
        return ""
|