[Feature][VectorStore] Support StarRocks as vector db (#6119)

<!--
Thank you for contributing to LangChain! Your PR will appear in our
release under the title you set. Please make sure it highlights your
valuable contribution.

Replace this with a description of the change, the issue it fixes (if
applicable), and relevant context. List any dependencies required for
this change.

After you're done, someone will review your PR. They may suggest
improvements. If no one reviews your PR within a few days, feel free to
@-mention the same people again, as notifications can get lost.

Finally, we'd love to show appreciation for your contribution - if you'd
like us to shout you out on Twitter, please also include your handle!
-->

<!-- Remove if not applicable -->

Fixes # (issue)

#### Before submitting

<!-- If you're adding a new integration, please include:

1. a test for the integration - favor unit tests that does not rely on
network access.
2. an example notebook showing its use


See contribution guidelines for more information on how to write tests,
lint
etc:


https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
-->

Here are some examples to use StarRocks as vectordb

```
from langchain.vectorstores import StarRocks
from langchain.vectorstores.starrocks import StarRocksSettings

embeddings = OpenAIEmbeddings()

# conifgure starrocks settings
settings = StarRocksSettings()
settings.port = 41003
settings.host = '127.0.0.1'
settings.username = 'root'
settings.password = ''
settings.database = 'zya'

# to fill new embeddings
docsearch = StarRocks.from_documents(split_docs, embeddings, config = settings)   


# or to use already-built embeddings in database.
docsearch = StarRocks(embeddings, settings)
```

#### Who can review?

Tag maintainers/contributors who might be interested:

@dev2049 

<!-- For a quicker response, figure out the right person to tag with @

  @hwchase17 - project lead

  Tracing / Callbacks
  - @agola11

  Async
  - @agola11

  DataLoaders
  - @eyurtsev

  Models
  - @hwchase17
  - @agola11

  Agents / Tools / Toolkits
  - @hwchase17

  VectorStores / Retrievers / Memory
  - @dev2049

 -->

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
multi_strategy_parser
dirtysalt 11 months ago committed by GitHub
parent 7a4ff424fc
commit 57cc3d1d3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,313 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "59723cea",
"metadata": {},
"source": [
"# StarRocks\n",
"\n",
"[StarRocks | A High-Performance Analytical Database](https://www.starrocks.io/)\n",
"\n",
"StarRocks is a next-gen sub-second MPP database for full analytics scenarios, including multi-dimensional analytics, real-time analytics and ad-hoc query.\n",
"\n",
"Usually StarRocks is categorized into OLAP, and it has showed excellent performance in [ClickBench — a Benchmark For Analytical DBMS](https://benchmark.clickhouse.com/). Since it has a super-fast vectorized execution engine, it could also be used as a fast vectordb.\n",
"\n",
"Here we'll show how to use the StarRocks Vector Store."
]
},
{
"cell_type": "markdown",
"id": "1685854f",
"metadata": {},
"source": [
"\n",
"## Import all used modules"
]
},
{
"cell_type": "markdown",
"id": "2c891bba",
"metadata": {},
"source": [
"Set `update_vectordb = False` at the beginning. If there is no docs updated, then we don't need to rebuild the embeddings of docs"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "3c85fb93",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/dirlt/utils/py3env/lib/python3.9/site-packages/requests/__init__.py:102: RequestsDependencyWarning: urllib3 (1.26.7) or chardet (5.1.0)/charset_normalizer (2.0.9) doesn't match a supported version!\n",
" warnings.warn(\"urllib3 ({}) or chardet ({})/charset_normalizer ({}) doesn't match a supported \"\n"
]
}
],
"source": [
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.vectorstores import StarRocks\n",
"from langchain.vectorstores.starrocks import StarRocksSettings\n",
"from langchain.vectorstores import Chroma\n",
"from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter\n",
"from langchain import OpenAI,VectorDBQA\n",
"from langchain.document_loaders import DirectoryLoader\n",
"from langchain.chains import RetrievalQA\n",
"from langchain.document_loaders import TextLoader, UnstructuredMarkdownLoader\n",
"\n",
"update_vectordb = False"
]
},
{
"cell_type": "markdown",
"id": "ee821c00",
"metadata": {},
"source": [
"## Load docs and split them into tokens"
]
},
{
"cell_type": "markdown",
"id": "34ba0cfd",
"metadata": {},
"source": [
"Load all markdown files under the `docs` directory\n",
"\n",
"for starrocks documents, you can clone repo from https://github.com/StarRocks/starrocks, and there is `docs` directory in it."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "85912696",
"metadata": {},
"outputs": [],
"source": [
"loader = DirectoryLoader('./docs', glob='**/*.md', loader_cls=UnstructuredMarkdownLoader)\n",
"documents = loader.load()"
]
},
{
"cell_type": "markdown",
"id": "b415fe2a",
"metadata": {},
"source": [
"Split docs into tokens, and set `update_vectordb = True` because there are new docs/tokens."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "07e8acff",
"metadata": {},
"outputs": [],
"source": [
"# load text splitter and split docs into snippets of text\n",
"text_splitter = TokenTextSplitter(chunk_size=400, chunk_overlap=50)\n",
"split_docs = text_splitter.split_documents(documents)\n",
"\n",
"# tell vectordb to update text embeddings\n",
"update_vectordb = True"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1f365370",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Document(page_content='Compile StarRocks with Docker\\n\\nThis topic describes how to compile StarRocks using Docker.\\n\\nOverview\\n\\nStarRocks provides development environment images for both Ubuntu 22.04 and CentOS 7.9. With the image, you can launch a Docker container and compile StarRocks in the container.\\n\\nStarRocks version and DEV ENV image\\n\\nDifferent branches of StarRocks correspond to different development environment images provided on StarRocks Docker Hub.\\n\\nFor Ubuntu 22.04:\\n\\n| Branch name | Image name |\\n | --------------- | ----------------------------------- |\\n | main | starrocks/dev-env-ubuntu:latest |\\n | branch-3.0 | starrocks/dev-env-ubuntu:3.0-latest |\\n | branch-2.5 | starrocks/dev-env-ubuntu:2.5-latest |\\n\\nFor CentOS 7.9:\\n\\n| Branch name | Image name |\\n | --------------- | ------------------------------------ |\\n | main | starrocks/dev-env-centos7:latest |\\n | branch-3.0 | starrocks/dev-env-centos7:3.0-latest |\\n | branch-2.5 | starrocks/dev-env-centos7:2.5-latest |\\n\\nPrerequisites\\n\\nBefore compiling StarRocks, make sure the following requirements are satisfied:\\n\\nHardware\\n\\n', metadata={'source': 'docs/developers/build-starrocks/Build_in_docker.md'})"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"split_docs[-20]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "50012b29",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# docs = 657, # splits = 2802\n"
]
}
],
"source": [
"print('# docs = %d, # splits = %d' % (len(documents), len(split_docs)))"
]
},
{
"cell_type": "markdown",
"id": "5371f152",
"metadata": {},
"source": [
"## Create vectordb instance"
]
},
{
"cell_type": "markdown",
"id": "15702d9c",
"metadata": {},
"source": [
"### Use StarRocks as vectordb"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ced7dbe1",
"metadata": {},
"outputs": [],
"source": [
"def gen_starrocks(update_vectordb, embeddings, settings):\n",
" if update_vectordb:\n",
" docsearch = StarRocks.from_documents(split_docs, embeddings, config = settings) \n",
" else:\n",
" docsearch = StarRocks(embeddings, settings) \n",
" return docsearch\n"
]
},
{
"cell_type": "markdown",
"id": "15d86fda",
"metadata": {},
"source": [
"## Convert tokens into embeddings and put them into vectordb"
]
},
{
"cell_type": "markdown",
"id": "ff1322ea",
"metadata": {},
"source": [
"Here we use StarRocks as vectordb, you can configure StarRocks instance via `StarRocksSettings`.\n",
"\n",
"Configuring StarRocks instance is pretty much like configuring mysql instance. You need to specify:\n",
"1. host/port\n",
"2. username(default: 'root')\n",
"3. password(default: '')\n",
"4. database(default: 'default')\n",
"5. table(default: 'langchain')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "26410d9b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Inserting data...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2802/2802 [02:26<00:00, 19.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[92m\u001b[1mzya.langchain @ 127.0.0.1:41003\u001b[0m\n",
"\n",
"\u001b[1musername: root\u001b[0m\n",
"\n",
"Table Schema:\n",
"----------------------------------------------------------------------------\n",
"|\u001b[94mname \u001b[0m|\u001b[96mtype \u001b[0m|\u001b[96mkey \u001b[0m|\n",
"----------------------------------------------------------------------------\n",
"|\u001b[94mid \u001b[0m|\u001b[96mvarchar(65533) \u001b[0m|\u001b[96mtrue \u001b[0m|\n",
"|\u001b[94mdocument \u001b[0m|\u001b[96mvarchar(65533) \u001b[0m|\u001b[96mfalse \u001b[0m|\n",
"|\u001b[94membedding \u001b[0m|\u001b[96marray<float> \u001b[0m|\u001b[96mfalse \u001b[0m|\n",
"|\u001b[94mmetadata \u001b[0m|\u001b[96mvarchar(65533) \u001b[0m|\u001b[96mfalse \u001b[0m|\n",
"----------------------------------------------------------------------------\n",
"\n"
]
}
],
"source": [
"embeddings = OpenAIEmbeddings()\n",
"\n",
"# configure starrocks settings(host/port/user/pw/db)\n",
"settings = StarRocksSettings()\n",
"settings.port = 41003\n",
"settings.host = '127.0.0.1'\n",
"settings.username = 'root'\n",
"settings.password = ''\n",
"settings.database = 'zya'\n",
"docsearch = gen_starrocks(update_vectordb, embeddings, settings)\n",
"\n",
"print(docsearch)\n",
"\n",
"update_vectordb = False"
]
},
{
"cell_type": "markdown",
"id": "bde66626",
"metadata": {},
"source": [
"## Build QA and ask question to it"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "84921814",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" No, profile is not enabled by default. To enable profile, set the variable `enable_profile` to `true` using the command `set enable_profile = true;`\n"
]
}
],
"source": [
"llm = OpenAI()\n",
"qa = RetrievalQA.from_chain_type(llm=llm, chain_type=\"stuff\", retriever=docsearch.as_retriever())\n",
"query = \"is profile enabled by default? if not, how to enable profile?\"\n",
"resp = qa.run(query)\n",
"print(resp)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -29,6 +29,7 @@ from langchain.vectorstores.redis import Redis
from langchain.vectorstores.rocksetdb import Rockset
from langchain.vectorstores.singlestoredb import SingleStoreDB
from langchain.vectorstores.sklearn import SKLearnVectorStore
from langchain.vectorstores.starrocks import StarRocks
from langchain.vectorstores.supabase import SupabaseVectorStore
from langchain.vectorstores.tair import Tair
from langchain.vectorstores.tigris import Tigris
@ -68,6 +69,7 @@ __all__ = [
"Rockset",
"SKLearnVectorStore",
"SingleStoreDB",
"StarRocks",
"SupabaseVectorStore",
"Tair",
"Tigris",

@ -0,0 +1,458 @@
"""Wrapper around open source StarRocks VectorSearch capability."""
from __future__ import annotations
import json
import logging
from hashlib import sha1
from threading import Thread
from typing import Any, Dict, Iterable, List, Optional, Tuple
from pydantic import BaseSettings
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore
logger = logging.getLogger()
DEBUG = False
def has_mul_sub_str(s: str, *args: Any) -> bool:
for a in args:
if a not in s:
return False
return True
def debug_output(s: Any) -> None:
if DEBUG:
print(s)
def get_named_result(connection: Any, query: str) -> List[dict[str, Any]]:
cursor = connection.cursor()
cursor.execute(query)
columns = cursor.description
result = []
for value in cursor.fetchall():
r = {}
for idx, datum in enumerate(value):
k = columns[idx][0]
r[k] = datum
result.append(r)
debug_output(result)
cursor.close()
return result
class StarRocksSettings(BaseSettings):
"""StarRocks Client Configuration
Attribute:
StarRocks_host (str) : An URL to connect to MyScale backend.
Defaults to 'localhost'.
StarRocks_port (int) : URL port to connect with HTTP. Defaults to 8443.
username (str) : Username to login. Defaults to None.
password (str) : Password to login. Defaults to None.
database (str) : Database name to find the table. Defaults to 'default'.
table (str) : Table name to operate on.
Defaults to 'vector_table'.
column_map (Dict) : Column type map to project column name onto langchain
semantics. Must have keys: `text`, `id`, `vector`,
must be same size to number of columns. For example:
.. code-block:: python
{
'id': 'text_id',
'embedding': 'text_embedding',
'document': 'text_plain',
'metadata': 'metadata_dictionary_in_json',
}
Defaults to identity map.
"""
host: str = "localhost"
port: int = 9030
username: str = "root"
password: str = ""
column_map: Dict[str, str] = {
"id": "id",
"document": "document",
"embedding": "embedding",
"metadata": "metadata",
}
database: str = "default"
table: str = "langchain"
def __getitem__(self, item: str) -> Any:
return getattr(self, item)
class Config:
env_file = ".env"
env_prefix = "starrocks_"
env_file_encoding = "utf-8"
class StarRocks(VectorStore):
"""Wrapper around StarRocks vector database
You need a `pymysql` python package, and a valid account
to connect to StarRocks.
Right now StarRocks has only implemented `cosine_similarity` function to
compute distance between two vectors. And there is no vector inside right now,
so we have to iterate all vectors and compute spatial distance.
For more information, please visit
[StarRocks official site](https://www.starrocks.io/)
[StarRocks github](https://github.com/StarRocks/starrocks)
"""
def __init__(
self,
embedding: Embeddings,
config: Optional[StarRocksSettings] = None,
**kwargs: Any,
) -> None:
"""StarRocks Wrapper to LangChain
embedding_function (Embeddings):
config (StarRocksSettings): Configuration to StarRocks Client
"""
try:
import pymysql # type: ignore[import]
except ImportError:
raise ImportError(
"Could not import pymysql python package. "
"Please install it with `pip install pymysql`."
)
try:
from tqdm import tqdm
self.pgbar = tqdm
except ImportError:
# Just in case if tqdm is not installed
self.pgbar = lambda x, **kwargs: x
super().__init__()
if config is not None:
self.config = config
else:
self.config = StarRocksSettings()
assert self.config
assert self.config.host and self.config.port
assert self.config.column_map and self.config.database and self.config.table
for k in ["id", "embedding", "document", "metadata"]:
assert k in self.config.column_map
# initialize the schema
dim = len(embedding.embed_query("test"))
self.schema = f"""\
CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
{self.config.column_map['id']} string,
{self.config.column_map['document']} string,
{self.config.column_map['embedding']} array<float>,
{self.config.column_map['metadata']} string
) ENGINE = OLAP PRIMARY KEY(id) DISTRIBUTED BY HASH(id) \
PROPERTIES ("replication_num" = "1")\
"""
self.dim = dim
self.BS = "\\"
self.must_escape = ("\\", "'")
self.embedding_function = embedding
self.dist_order = "DESC"
debug_output(self.config)
# Create a connection to StarRocks
self.connection = pymysql.connect(
host=self.config.host,
port=self.config.port,
user=self.config.username,
password=self.config.password,
database=self.config.database,
**kwargs,
)
debug_output(self.schema)
get_named_result(self.connection, self.schema)
def escape_str(self, value: str) -> str:
return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)
def _build_insert_sql(self, transac: Iterable, column_names: Iterable[str]) -> str:
ks = ",".join(column_names)
embed_tuple_index = tuple(column_names).index(
self.config.column_map["embedding"]
)
_data = []
for n in transac:
n = ",".join(
[
f"'{self.escape_str(str(_n))}'"
if idx != embed_tuple_index
else f"array<float>{str(_n)}"
for (idx, _n) in enumerate(n)
]
)
_data.append(f"({n})")
i_str = f"""
INSERT INTO
{self.config.database}.{self.config.table}({ks})
VALUES
{','.join(_data)}
"""
return i_str
def _insert(self, transac: Iterable, column_names: Iterable[str]) -> None:
_insert_query = self._build_insert_sql(transac, column_names)
debug_output(_insert_query)
get_named_result(self.connection, _insert_query)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
batch_size: int = 32,
ids: Optional[Iterable[str]] = None,
**kwargs: Any,
) -> List[str]:
"""Insert more texts through the embeddings and add to the VectorStore.
Args:
texts: Iterable of strings to add to the VectorStore.
ids: Optional list of ids to associate with the texts.
batch_size: Batch size of insertion
metadata: Optional column data to be inserted
Returns:
List of ids from adding the texts into the VectorStore.
"""
# Embed and create the documents
ids = ids or [sha1(t.encode("utf-8")).hexdigest() for t in texts]
colmap_ = self.config.column_map
transac = []
column_names = {
colmap_["id"]: ids,
colmap_["document"]: texts,
colmap_["embedding"]: self.embedding_function.embed_documents(list(texts)),
}
metadatas = metadatas or [{} for _ in texts]
column_names[colmap_["metadata"]] = map(json.dumps, metadatas)
assert len(set(colmap_) - set(column_names)) >= 0
keys, values = zip(*column_names.items())
try:
t = None
for v in self.pgbar(
zip(*values), desc="Inserting data...", total=len(metadatas)
):
assert (
len(v[keys.index(self.config.column_map["embedding"])]) == self.dim
)
transac.append(v)
if len(transac) == batch_size:
if t:
t.join()
t = Thread(target=self._insert, args=[transac, keys])
t.start()
transac = []
if len(transac) > 0:
if t:
t.join()
self._insert(transac, keys)
return [i for i in ids]
except Exception as e:
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
return []
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[Dict[Any, Any]]] = None,
config: Optional[StarRocksSettings] = None,
text_ids: Optional[Iterable[str]] = None,
batch_size: int = 32,
**kwargs: Any,
) -> StarRocks:
"""Create StarRocks wrapper with existing texts
Args:
embedding_function (Embeddings): Function to extract text embedding
texts (Iterable[str]): List or tuple of strings to be added
config (StarRocksSettings, Optional): StarRocks configuration
text_ids (Optional[Iterable], optional): IDs for the texts.
Defaults to None.
batch_size (int, optional): Batchsize when transmitting data to StarRocks.
Defaults to 32.
metadata (List[dict], optional): metadata to texts. Defaults to None.
Returns:
StarRocks Index
"""
ctx = cls(embedding, config, **kwargs)
ctx.add_texts(texts, ids=text_ids, batch_size=batch_size, metadatas=metadatas)
return ctx
def __repr__(self) -> str:
"""Text representation for StarRocks Vector Store, prints backends, username
and schemas. Easy to use with `str(StarRocks())`
Returns:
repr: string to show connection info and data schema
"""
_repr = f"\033[92m\033[1m{self.config.database}.{self.config.table} @ "
_repr += f"{self.config.host}:{self.config.port}\033[0m\n\n"
_repr += f"\033[1musername: {self.config.username}\033[0m\n\nTable Schema:\n"
width = 25
fields = 3
_repr += "-" * (width * fields + 1) + "\n"
columns = ["name", "type", "key"]
_repr += f"|\033[94m{columns[0]:24s}\033[0m|\033[96m{columns[1]:24s}"
_repr += f"\033[0m|\033[96m{columns[2]:24s}\033[0m|\n"
_repr += "-" * (width * fields + 1) + "\n"
q_str = f"DESC {self.config.database}.{self.config.table}"
debug_output(q_str)
rs = get_named_result(self.connection, q_str)
for r in rs:
_repr += f"|\033[94m{r['Field']:24s}\033[0m|\033[96m{r['Type']:24s}"
_repr += f"\033[0m|\033[96m{r['Key']:24s}\033[0m|\n"
_repr += "-" * (width * fields + 1) + "\n"
return _repr
def _build_query_sql(
self, q_emb: List[float], topk: int, where_str: Optional[str] = None
) -> str:
q_emb_str = ",".join(map(str, q_emb))
if where_str:
where_str = f"WHERE {where_str}"
else:
where_str = ""
q_str = f"""
SELECT {self.config.column_map['document']},
{self.config.column_map['metadata']},
cosine_similarity_norm(array<float>[{q_emb_str}],
{self.config.column_map['embedding']}) as dist
FROM {self.config.database}.{self.config.table}
{where_str}
ORDER BY dist {self.dist_order}
LIMIT {topk}
"""
debug_output(q_str)
return q_str
def similarity_search(
self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
) -> List[Document]:
"""Perform a similarity search with StarRocks
Args:
query (str): query string
k (int, optional): Top K neighbors to retrieve. Defaults to 4.
where_str (Optional[str], optional): where condition string.
Defaults to None.
NOTE: Please do not let end-user to fill this and always be aware
of SQL injection. When dealing with metadatas, remember to
use `{self.metadata_column}.attribute` instead of `attribute`
alone. The default name for it is `metadata`.
Returns:
List[Document]: List of Documents
"""
return self.similarity_search_by_vector(
self.embedding_function.embed_query(query), k, where_str, **kwargs
)
def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
where_str: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a similarity search with StarRocks by vectors
Args:
query (str): query string
k (int, optional): Top K neighbors to retrieve. Defaults to 4.
where_str (Optional[str], optional): where condition string.
Defaults to None.
NOTE: Please do not let end-user to fill this and always be aware
of SQL injection. When dealing with metadatas, remember to
use `{self.metadata_column}.attribute` instead of `attribute`
alone. The default name for it is `metadata`.
Returns:
List[Document]: List of (Document, similarity)
"""
q_str = self._build_query_sql(embedding, k, where_str)
try:
return [
Document(
page_content=r[self.config.column_map["document"]],
metadata=json.loads(r[self.config.column_map["metadata"]]),
)
for r in get_named_result(self.connection, q_str)
]
except Exception as e:
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
return []
def similarity_search_with_relevance_scores(
self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""Perform a similarity search with StarRocks
Args:
query (str): query string
k (int, optional): Top K neighbors to retrieve. Defaults to 4.
where_str (Optional[str], optional): where condition string.
Defaults to None.
NOTE: Please do not let end-user to fill this and always be aware
of SQL injection. When dealing with metadatas, remember to
use `{self.metadata_column}.attribute` instead of `attribute`
alone. The default name for it is `metadata`.
Returns:
List[Document]: List of documents
"""
q_str = self._build_query_sql(
self.embedding_function.embed_query(query), k, where_str
)
try:
return [
(
Document(
page_content=r[self.config.column_map["document"]],
metadata=json.loads(r[self.config.column_map["metadata"]]),
),
r["dist"],
)
for r in get_named_result(self.connection, q_str)
]
except Exception as e:
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
return []
def drop(self) -> None:
"""
Helper function: Drop data
"""
get_named_result(
self.connection,
f"DROP TABLE IF EXISTS {self.config.database}.{self.config.table}",
)
@property
def metadata_column(self) -> str:
return self.config.column_map["metadata"]
Loading…
Cancel
Save