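# NOTE: the lines below are residue from the code-hosting web page this file
# was copied from; they are kept as comments so that the module can parse.
#   "You cannot select more than 25 topics. Topics must start with a letter or
#   number, can include dashes ('-') and can be up to 35 characters long."
#   File: langchain/libs/community/langchain_community/vectorstores/apache_doris.py
#   481 lines, 17 KiB, Python
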
from __future__ import annotations
import json
import logging
from hashlib import sha1
from threading import Thread
from typing import Any, Dict, Iterable, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseSettings
from langchain_core.vectorstores import VectorStore
# Use a module-scoped logger rather than the root logger so applications can
# filter or configure this module's messages independently. Records still
# propagate to handlers installed on the root logger.
logger = logging.getLogger(__name__)

# When True, `_debug_output` prints settings, generated SQL and query results.
DEBUG = False
class ApacheDorisSettings(BaseSettings):
    """Apache Doris client configuration.

    The nested ``Config`` sets ``env_file = ".env"`` and
    ``env_prefix = "apache_doris_"``, so each field can also be supplied from
    the environment under that prefix (for example ``apache_doris_host``).

    Attributes:
        host (str) : Host name of the Apache Doris frontend to connect to.
                     Defaults to 'localhost'.
        port (int) : Port of the frontend that `ApacheDoris` hands to
                     ``pymysql.connect``. Defaults to 9030.
        username (str) : Username to login. Defaults to 'root'.
        password (str) : Password to login. Defaults to '' (empty string).
        column_map (Dict) : Maps the logical langchain column roles onto the
                            physical column names of the table.
                            Must have keys: `id`, `document`, `embedding`,
                            `metadata`. For example:

                            .. code-block:: python

                                {
                                    'id': 'text_id',
                                    'embedding': 'text_embedding',
                                    'document': 'text_plain',
                                    'metadata': 'metadata_dictionary_in_json',
                                }

                            Defaults to identity map.
        database (str) : Database name to find the table. Defaults to 'default'.
        table (str) : Table name to operate on.
                      Defaults to 'langchain'.
    """

    host: str = "localhost"
    port: int = 9030

    username: str = "root"
    password: str = ""

    column_map: Dict[str, str] = {
        "id": "id",
        "document": "document",
        "embedding": "embedding",
        "metadata": "metadata",
    }

    database: str = "default"
    table: str = "langchain"

    def __getitem__(self, item: str) -> Any:
        """Allow dictionary-style access, e.g. ``settings["host"]``."""
        return getattr(self, item)

    class Config:
        env_file = ".env"
        env_prefix = "apache_doris_"
        env_file_encoding = "utf-8"
class ApacheDoris(VectorStore):
    """`Apache Doris` vector store.

    You need a `pymysql` python package, and a valid account
    to connect to Apache Doris.

    For more information, please visit
        [Apache Doris official site](https://doris.apache.org/)
        [Apache Doris github](https://github.com/apache/doris)
    """

    def __init__(
        self,
        embedding: Embeddings,
        *,
        config: Optional[ApacheDorisSettings] = None,
        **kwargs: Any,
    ) -> None:
        """Constructor for Apache Doris.

        Besides storing the configuration this embeds the probe string
        ``"test"`` to learn the embedding dimension, opens a `pymysql`
        connection and executes ``CREATE TABLE IF NOT EXISTS`` for the
        configured table.

        Args:
            embedding (Embeddings): Text embedding model.
            config (ApacheDorisSettings): Apache Doris client configuration
                information. A default `ApacheDorisSettings()` is built when
                omitted.
            **kwargs: Extra keyword arguments forwarded to ``pymysql.connect``.

        Raises:
            ImportError: If `pymysql` is not installed.
            AssertionError: If the configuration is incomplete.
        """
        try:
            import pymysql  # type: ignore[import]
        except ImportError as import_error:
            raise ImportError(
                "Could not import pymysql python package. "
                "Please install it with `pip install pymysql`."
            ) from import_error
        try:
            from tqdm import tqdm

            self.pgbar = tqdm
        except ImportError:
            # Just in case if tqdm is not installed
            self.pgbar = lambda x, **kwargs: x
        super().__init__()
        self.config = config if config is not None else ApacheDorisSettings()
        assert self.config
        assert self.config.host and self.config.port
        assert self.config.column_map and self.config.database and self.config.table
        for k in ["id", "embedding", "document", "metadata"]:
            assert k in self.config.column_map, f"column_map is missing key '{k}'"

        # initialize the schema
        dim = len(embedding.embed_query("test"))
        colmap = self.config.column_map

        # The key / distribution column must be the *mapped* id column; a
        # literal `id` only works when the default identity column map is used.
        self.schema = f"""\
CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
    {colmap['id']} varchar(50),
    {colmap['document']} string,
    {colmap['embedding']} array<float>,
    {colmap['metadata']} string
) ENGINE = OLAP UNIQUE KEY({colmap['id']}) DISTRIBUTED BY HASH({colmap['id']}) \
  PROPERTIES ("replication_allocation" = "tag.location.default: 1")\
"""
        self.dim = dim
        self.BS = "\\"
        self.must_escape = ("\\", "'")
        self._embedding = embedding
        # Queries rank rows by `cosine_distance`, where a smaller value means a
        # closer vector, so the nearest neighbours come first in ascending order.
        self.dist_order = "ASC"
        _debug_output(self.config)

        # Create a connection to Apache Doris
        self.connection = pymysql.connect(
            host=self.config.host,
            port=self.config.port,
            user=self.config.username,
            password=self.config.password,
            database=self.config.database,
            **kwargs,
        )

        _debug_output(self.schema)
        _get_named_result(self.connection, self.schema)

    def escape_str(self, value: str) -> str:
        """Backslash-escape every backslash and single quote in ``value``."""
        return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)

    @property
    def embeddings(self) -> Embeddings:
        """The embedding model this store was constructed with."""
        return self._embedding

    def _build_insert_sql(self, transac: Iterable, column_names: Iterable[str]) -> str:
        """Render one multi-row ``INSERT`` statement.

        Args:
            transac: Rows to insert; each row is a sequence of values ordered
                like ``column_names``.
            column_names: Physical column names, in row order.

        Returns:
            The SQL text. Embedding values are rendered as ``array<float>``
            literals, every other value as an escaped quoted string.
        """
        # Materialise once: the names are needed both joined and for `.index`.
        columns = tuple(column_names)
        ks = ",".join(columns)
        embed_tuple_index = columns.index(self.config.column_map["embedding"])
        _data = []
        for row in transac:
            rendered = ",".join(
                [
                    (
                        f"'{self.escape_str(str(_n))}'"
                        if idx != embed_tuple_index
                        else f"array<float>{str(_n)}"
                    )
                    for (idx, _n) in enumerate(row)
                ]
            )
            _data.append(f"({rendered})")
        i_str = f"""
                INSERT INTO
                    {self.config.database}.{self.config.table}({ks})
                VALUES
                {','.join(_data)}
                """
        return i_str

    def _insert(self, transac: Iterable, column_names: Iterable[str]) -> None:
        """Build the ``INSERT`` for ``transac`` and execute it on the connection."""
        _insert_query = self._build_insert_sql(transac, column_names)
        _debug_output(_insert_query)
        _get_named_result(self.connection, _insert_query)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        batch_size: int = 32,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Insert more texts through the embeddings and add to the VectorStore.

        Full batches are inserted on a background thread (one at a time) while
        the next batch is being assembled; a trailing partial batch is inserted
        synchronously.

        Args:
            texts: Iterable of strings to add to the VectorStore.
            metadatas: Optional list of metadata dicts, one per text. Each is
                stored as a JSON string.
            batch_size: Batch size of insertion
            ids: Optional list of ids to associate with the texts. Defaults to
                the SHA-1 hex digest of each text.

        Returns:
            List of ids from adding the texts into the VectorStore, or an empty
            list if an insertion step failed (the error is logged).
        """
        # `texts` and `ids` are each traversed several times below, so
        # one-shot iterables (generators) must be materialised first.
        texts = list(texts)
        ids = list(ids) if ids is not None else []
        if not ids:
            ids = [sha1(t.encode("utf-8")).hexdigest() for t in texts]

        # Embed and create the documents
        colmap_ = self.config.column_map
        column_names = {
            colmap_["id"]: ids,
            colmap_["document"]: texts,
            colmap_["embedding"]: self._embedding.embed_documents(texts),
        }
        metadatas = metadatas or [{} for _ in texts]
        column_names[colmap_["metadata"]] = map(json.dumps, metadatas)

        keys, values = zip(*column_names.items())
        embedding_index = keys.index(colmap_["embedding"])

        # Failures raised on the worker thread are collected here so they can
        # be re-raised on the calling thread instead of being lost.
        errors: List[BaseException] = []

        def _insert_batch(batch: List[Any]) -> None:
            try:
                self._insert(batch, keys)
            except Exception as exc:
                errors.append(exc)

        worker: Optional[Thread] = None
        try:
            transac: List[Any] = []
            for v in self.pgbar(
                zip(*values), desc="Inserting data...", total=len(metadatas)
            ):
                assert len(v[embedding_index]) == self.dim
                transac.append(v)
                if len(transac) == batch_size:
                    # Only one insert may be in flight on the shared connection.
                    if worker:
                        worker.join()
                        if errors:
                            raise errors[0]
                    worker = Thread(target=_insert_batch, args=[transac])
                    worker.start()
                    transac = []
            # Always wait for the last background batch, even when no partial
            # batch is left over, so the data is written before returning.
            if worker:
                worker.join()
            if errors:
                raise errors[0]
            if len(transac) > 0:
                self._insert(transac, keys)
            return list(ids)
        except Exception as e:
            logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []
        finally:
            # Never leave a background insert running on the connection.
            if worker is not None and worker.is_alive():
                worker.join()

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[Dict[Any, Any]]] = None,
        config: Optional[ApacheDorisSettings] = None,
        text_ids: Optional[Iterable[str]] = None,
        batch_size: int = 32,
        **kwargs: Any,
    ) -> ApacheDoris:
        """Create Apache Doris wrapper with existing texts

        Args:
            texts (Iterable[str]): List or tuple of strings to be added
            embedding (Embeddings): Model used to extract text embeddings
            metadatas (List[dict], optional): metadata to texts. Defaults to None.
            config (ApacheDorisSettings, Optional): Apache Doris configuration
            text_ids (Optional[Iterable], optional): IDs for the texts.
                                                     Defaults to None.
            batch_size (int, optional): BatchSize when transmitting data to Apache
                                        Doris. Defaults to 32.
            **kwargs: Forwarded to the constructor.

        Returns:
            Apache Doris Index
        """
        ctx = cls(embedding, config=config, **kwargs)
        ctx.add_texts(texts, ids=text_ids, batch_size=batch_size, metadatas=metadatas)
        return ctx

    def __repr__(self) -> str:
        """Text representation for Apache Doris Vector Store, prints frontends, username
            and schemas. Easy to use with `str(ApacheDoris())`

        Note that this runs a ``DESC`` query against the live connection.

        Returns:
            repr: string to show connection info and data schema
        """
        _repr = f"\033[92m\033[1m{self.config.database}.{self.config.table} @ "
        _repr += f"{self.config.host}:{self.config.port}\033[0m\n\n"
        _repr += f"\033[1musername: {self.config.username}\033[0m\n\nTable Schema:\n"
        width = 25
        fields = 3
        _repr += "-" * (width * fields + 1) + "\n"
        columns = ["name", "type", "key"]
        _repr += f"|\033[94m{columns[0]:24s}\033[0m|\033[96m{columns[1]:24s}"
        _repr += f"\033[0m|\033[96m{columns[2]:24s}\033[0m|\n"
        _repr += "-" * (width * fields + 1) + "\n"
        q_str = f"DESC {self.config.database}.{self.config.table}"
        _debug_output(q_str)
        rs = _get_named_result(self.connection, q_str)
        for r in rs:
            _repr += f"|\033[94m{r['Field']:24s}\033[0m|\033[96m{r['Type']:24s}"
            _repr += f"\033[0m|\033[96m{r['Key']:24s}\033[0m|\n"
        _repr += "-" * (width * fields + 1) + "\n"
        return _repr

    def _build_query_sql(
        self, q_emb: List[float], topk: int, where_str: Optional[str] = None
    ) -> str:
        """Render the top-k query for the embedding ``q_emb``.

        Args:
            q_emb: Query embedding.
            topk: Value of the ``LIMIT`` clause.
            where_str: Optional raw SQL condition placed after ``WHERE``. It is
                inserted verbatim, so it must come from a trusted source.

        Returns:
            SQL selecting the document, the metadata and ``dist`` (the
            ``cosine_distance`` to ``q_emb``), ordered by ``self.dist_order``.
        """
        q_emb_str = ",".join(map(str, q_emb))
        if where_str:
            where_str = f"WHERE {where_str}"
        else:
            where_str = ""

        q_str = f"""
            SELECT {self.config.column_map['document']},
                {self.config.column_map['metadata']},
                cosine_distance(array<float>[{q_emb_str}],
                  {self.config.column_map['embedding']}) as dist
            FROM {self.config.database}.{self.config.table}
            {where_str}
            ORDER BY dist {self.dist_order}
            LIMIT {topk}
            """

        _debug_output(q_str)
        return q_str

    def similarity_search(
        self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
    ) -> List[Document]:
        """Perform a similarity search with Apache Doris

        Args:
            query (str): query string
            k (int, optional): Top K neighbors to retrieve. Defaults to 4.
            where_str (Optional[str], optional): where condition string.
                                                 Defaults to None.

            NOTE: Please do not let end-user to fill this and always be aware
                  of SQL injection. When dealing with metadatas, remember to
                  use `{self.metadata_column}.attribute` instead of `attribute`
                  alone. The default name for it is `metadata`.

        Returns:
            List[Document]: List of Documents
        """
        return self.similarity_search_by_vector(
            self._embedding.embed_query(query), k, where_str, **kwargs
        )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search with Apache Doris by vectors

        Args:
            embedding (List[float]): query embedding
            k (int, optional): Top K neighbors to retrieve. Defaults to 4.
            where_str (Optional[str], optional): where condition string.
                                                 Defaults to None.

            NOTE: Please do not let end-user to fill this and always be aware
                  of SQL injection. When dealing with metadatas, remember to
                  use `{self.metadata_column}.attribute` instead of `attribute`
                  alone. The default name for it is `metadata`.

        Returns:
            List[Document]: List of Documents. An empty list is returned (and
            the error logged) if the query or result decoding fails.
        """
        q_str = self._build_query_sql(embedding, k, where_str)
        try:
            return [
                Document(
                    page_content=r[self.config.column_map["document"]],
                    metadata=json.loads(r[self.config.column_map["metadata"]]),
                )
                for r in _get_named_result(self.connection, q_str)
            ]
        except Exception as e:
            logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []

    def similarity_search_with_relevance_scores(
        self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Perform a similarity search with Apache Doris

        Args:
            query (str): query string
            k (int, optional): Top K neighbors to retrieve. Defaults to 4.
            where_str (Optional[str], optional): where condition string.
                                                 Defaults to None.

            NOTE: Please do not let end-user to fill this and always be aware
                  of SQL injection. When dealing with metadatas, remember to
                  use `{self.metadata_column}.attribute` instead of `attribute`
                  alone. The default name for it is `metadata`.

        Returns:
            List[Tuple[Document, float]]: List of (Document, score) pairs. The
            score is the ``dist`` column of the query, i.e. the raw
            ``cosine_distance`` value, not a normalised relevance.
            An empty list is returned (and the error logged) on failure.
        """
        q_str = self._build_query_sql(self._embedding.embed_query(query), k, where_str)
        try:
            return [
                (
                    Document(
                        page_content=r[self.config.column_map["document"]],
                        metadata=json.loads(r[self.config.column_map["metadata"]]),
                    ),
                    r["dist"],
                )
                for r in _get_named_result(self.connection, q_str)
            ]
        except Exception as e:
            logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []

    def drop(self) -> None:
        """
        Helper function: Drop data (executes ``DROP TABLE IF EXISTS`` for the
        configured table).
        """
        _get_named_result(
            self.connection,
            f"DROP TABLE IF EXISTS {self.config.database}.{self.config.table}",
        )

    @property
    def metadata_column(self) -> str:
        """Physical name of the metadata column, taken from the column map."""
        return self.config.column_map["metadata"]
def _has_mul_sub_str(s: str, *args: Any) -> bool:
    """Check if a string has multiple substrings.

    Args:
        s: The string to check
        *args: The substrings to check for in the string

    Returns:
        bool: True if all substrings are present in the string, False otherwise.
        With no substrings given the result is True.
    """
    # `all` short-circuits on the first missing substring, like the explicit
    # loop it replaces.
    return all(a in s for a in args)
def _debug_output(s: Any) -> None:
    """Write ``s`` to stdout, but only while the module-level ``DEBUG`` flag is set.

    Args:
        s: The object to print; it is handed to ``print`` unchanged.
    """
    if not DEBUG:
        return
    print(s)  # noqa: T201
def _get_named_result(connection: Any, query: str) -> List[Dict[str, Any]]:
    """Get a named result from a query.

    Args:
        connection: The connection to the database; only its ``cursor()``
            method is used.
        query: The query to execute

    Returns:
        List[Dict[str, Any]]: One dict per fetched row, keyed by the column
        name taken from the first element of each ``cursor.description`` entry.
        Empty when the cursor yields no rows.
    """
    cursor = connection.cursor()
    try:
        cursor.execute(query)
        columns = cursor.description
        result = [
            {columns[idx][0]: datum for idx, datum in enumerate(value)}
            for value in cursor.fetchall()
        ]
    finally:
        # Release the cursor even when execute/fetch raises.
        cursor.close()
    _debug_output(result)
    return result