mirror of https://github.com/hwchase17/langchain
[Feature][VectorStore] Support StarRocks as vector db (#6119)
Here are some examples of using StarRocks as a vectordb:

```
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import StarRocks
from langchain.vectorstores.starrocks import StarRocksSettings

embeddings = OpenAIEmbeddings()

# configure StarRocks settings
settings = StarRocksSettings()
settings.port = 41003
settings.host = '127.0.0.1'
settings.username = 'root'
settings.password = ''
settings.database = 'zya'

# to fill in new embeddings
docsearch = StarRocks.from_documents(split_docs, embeddings, config = settings)

# or to use already-built embeddings in the database
docsearch = StarRocks(embeddings, settings)
```

#### Who can review?

Tag maintainers/contributors who might be interested: @dev2049

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
pull/6555/head
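A note on re-running the first variant: `add_texts` in this change derives a default ID for each text from its SHA-1 digest, and the table is created with `PRIMARY KEY(id)`. The sketch below reproduces only that ID step with the standard library. `default_ids` is a hypothetical helper name used for illustration, not something the PR defines. Identical texts map to the same ID, so they would presumably end up as a single row rather than as duplicates; how StarRocks handles the repeated key is an assumption here, not something the PR states.

```python
from hashlib import sha1
from typing import Iterable, List


def default_ids(texts: Iterable[str]) -> List[str]:
    """Hypothetical helper: same expression add_texts uses when no ids are given."""
    return [sha1(t.encode("utf-8")).hexdigest() for t in texts]


ids = default_ids(["hello", "world", "hello"])
# The first and third texts are identical, so their IDs are identical too.
print(ids[0] == ids[2], ids[0] != ids[1], len(ids[0]))
```

If distinct rows are needed for repeated texts, explicit IDs can be passed through the `ids` argument of `add_texts` (or `text_ids` in `from_texts`).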
parent
7a4ff424fc
commit
57cc3d1d3d
@ -0,0 +1,313 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "59723cea",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# StarRocks\n",
|
||||
"\n",
|
||||
"[StarRocks | A High-Performance Analytical Database](https://www.starrocks.io/)\n",
|
||||
"\n",
"StarRocks is a next-gen sub-second MPP database for full analytics scenarios, including multi-dimensional analytics, real-time analytics and ad-hoc queries.\n",
|
||||
"\n",
"StarRocks is usually categorized as an OLAP database, and it has shown excellent performance in [ClickBench — a Benchmark For Analytical DBMS](https://benchmark.clickhouse.com/). Since it has a super-fast vectorized execution engine, it can also be used as a fast vectordb.\n",
|
||||
"\n",
|
||||
"Here we'll show how to use the StarRocks Vector Store."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1685854f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"## Import all used modules"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2c891bba",
|
||||
"metadata": {},
|
||||
"source": [
"Set `update_vectordb = False` at the beginning. If no docs have been updated, we don't need to rebuild their embeddings."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "3c85fb93",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/dirlt/utils/py3env/lib/python3.9/site-packages/requests/__init__.py:102: RequestsDependencyWarning: urllib3 (1.26.7) or chardet (5.1.0)/charset_normalizer (2.0.9) doesn't match a supported version!\n",
|
||||
" warnings.warn(\"urllib3 ({}) or chardet ({})/charset_normalizer ({}) doesn't match a supported \"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
"from langchain.vectorstores import StarRocks\n",
|
||||
"from langchain.vectorstores.starrocks import StarRocksSettings\n",
|
||||
"from langchain.vectorstores import Chroma\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter\n",
|
||||
"from langchain import OpenAI,VectorDBQA\n",
|
||||
"from langchain.document_loaders import DirectoryLoader\n",
|
||||
"from langchain.chains import RetrievalQA\n",
|
||||
"from langchain.document_loaders import TextLoader, UnstructuredMarkdownLoader\n",
|
||||
"\n",
|
||||
"update_vectordb = False"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ee821c00",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load docs and split them into tokens"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "34ba0cfd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Load all markdown files under the `docs` directory\n",
|
||||
"\n",
"For the StarRocks documents, you can clone the repo from https://github.com/StarRocks/starrocks; it contains a `docs` directory."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "85912696",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader = DirectoryLoader('./docs', glob='**/*.md', loader_cls=UnstructuredMarkdownLoader)\n",
|
||||
"documents = loader.load()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b415fe2a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Split docs into tokens, and set `update_vectordb = True` because there are new docs/tokens."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "07e8acff",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# load text splitter and split docs into snippets of text\n",
|
||||
"text_splitter = TokenTextSplitter(chunk_size=400, chunk_overlap=50)\n",
|
||||
"split_docs = text_splitter.split_documents(documents)\n",
|
||||
"\n",
|
||||
"# tell vectordb to update text embeddings\n",
|
||||
"update_vectordb = True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "1f365370",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Document(page_content='Compile StarRocks with Docker\\n\\nThis topic describes how to compile StarRocks using Docker.\\n\\nOverview\\n\\nStarRocks provides development environment images for both Ubuntu 22.04 and CentOS 7.9. With the image, you can launch a Docker container and compile StarRocks in the container.\\n\\nStarRocks version and DEV ENV image\\n\\nDifferent branches of StarRocks correspond to different development environment images provided on StarRocks Docker Hub.\\n\\nFor Ubuntu 22.04:\\n\\n| Branch name | Image name |\\n | --------------- | ----------------------------------- |\\n | main | starrocks/dev-env-ubuntu:latest |\\n | branch-3.0 | starrocks/dev-env-ubuntu:3.0-latest |\\n | branch-2.5 | starrocks/dev-env-ubuntu:2.5-latest |\\n\\nFor CentOS 7.9:\\n\\n| Branch name | Image name |\\n | --------------- | ------------------------------------ |\\n | main | starrocks/dev-env-centos7:latest |\\n | branch-3.0 | starrocks/dev-env-centos7:3.0-latest |\\n | branch-2.5 | starrocks/dev-env-centos7:2.5-latest |\\n\\nPrerequisites\\n\\nBefore compiling StarRocks, make sure the following requirements are satisfied:\\n\\nHardware\\n\\n', metadata={'source': 'docs/developers/build-starrocks/Build_in_docker.md'})"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"split_docs[-20]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "50012b29",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"# docs = 657, # splits = 2802\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print('# docs = %d, # splits = %d' % (len(documents), len(split_docs)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5371f152",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create vectordb instance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "15702d9c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Use StarRocks as vectordb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "ced7dbe1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def gen_starrocks(update_vectordb, embeddings, settings):\n",
|
||||
" if update_vectordb:\n",
|
||||
" docsearch = StarRocks.from_documents(split_docs, embeddings, config = settings) \n",
|
||||
" else:\n",
|
||||
" docsearch = StarRocks(embeddings, settings) \n",
|
||||
" return docsearch\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "15d86fda",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Convert tokens into embeddings and put them into vectordb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ff1322ea",
|
||||
"metadata": {},
|
||||
"source": [
"Here we use StarRocks as the vectordb. You can configure the StarRocks instance via `StarRocksSettings`.\n",
|
||||
"\n",
"Configuring a StarRocks instance is much like configuring a MySQL instance. You need to specify:\n",
|
||||
"1. host/port\n",
|
||||
"2. username(default: 'root')\n",
|
||||
"3. password(default: '')\n",
|
||||
"4. database(default: 'default')\n",
|
||||
"5. table(default: 'langchain')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "26410d9b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Inserting data...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2802/2802 [02:26<00:00, 19.11it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[92m\u001b[1mzya.langchain @ 127.0.0.1:41003\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1musername: root\u001b[0m\n",
|
||||
"\n",
|
||||
"Table Schema:\n",
|
||||
"----------------------------------------------------------------------------\n",
|
||||
"|\u001b[94mname \u001b[0m|\u001b[96mtype \u001b[0m|\u001b[96mkey \u001b[0m|\n",
|
||||
"----------------------------------------------------------------------------\n",
|
||||
"|\u001b[94mid \u001b[0m|\u001b[96mvarchar(65533) \u001b[0m|\u001b[96mtrue \u001b[0m|\n",
|
||||
"|\u001b[94mdocument \u001b[0m|\u001b[96mvarchar(65533) \u001b[0m|\u001b[96mfalse \u001b[0m|\n",
|
||||
"|\u001b[94membedding \u001b[0m|\u001b[96marray<float> \u001b[0m|\u001b[96mfalse \u001b[0m|\n",
|
||||
"|\u001b[94mmetadata \u001b[0m|\u001b[96mvarchar(65533) \u001b[0m|\u001b[96mfalse \u001b[0m|\n",
|
||||
"----------------------------------------------------------------------------\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"embeddings = OpenAIEmbeddings()\n",
|
||||
"\n",
|
||||
"# configure starrocks settings(host/port/user/pw/db)\n",
|
||||
"settings = StarRocksSettings()\n",
|
||||
"settings.port = 41003\n",
|
||||
"settings.host = '127.0.0.1'\n",
|
||||
"settings.username = 'root'\n",
|
||||
"settings.password = ''\n",
|
||||
"settings.database = 'zya'\n",
|
||||
"docsearch = gen_starrocks(update_vectordb, embeddings, settings)\n",
|
||||
"\n",
|
||||
"print(docsearch)\n",
|
||||
"\n",
|
||||
"update_vectordb = False"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bde66626",
|
||||
"metadata": {},
|
||||
"source": [
"## Build a QA chain and ask it a question"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "84921814",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" No, profile is not enabled by default. To enable profile, set the variable `enable_profile` to `true` using the command `set enable_profile = true;`\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm = OpenAI()\n",
|
||||
"qa = RetrievalQA.from_chain_type(llm=llm, chain_type=\"stuff\", retriever=docsearch.as_retriever())\n",
|
||||
"query = \"is profile enabled by default? if not, how to enable profile?\"\n",
|
||||
"resp = qa.run(query)\n",
|
||||
"print(resp)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
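The retriever in the notebook above ends up in `similarity_search`, which the accompanying `starrocks.py` implements as a full table scan ordered by `cosine_similarity_norm(...)` in descending order with a `LIMIT`. A rough pure-Python picture of that ranking follows. It assumes plain cosine similarity; whether `cosine_similarity_norm` expects normalized inputs is not stated in the PR. The names `cosine_similarity` and `top_k` and the toy 2-d vectors are hypothetical and only for illustration.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(
    query: Sequence[float], rows: List[Tuple[str, Sequence[float]]], k: int = 4
) -> List[Tuple[str, float]]:
    """Score every row against the query and keep the k most similar."""
    scored = [(doc, cosine_similarity(query, emb)) for doc, emb in rows]
    # Higher similarity first, mirroring ORDER BY dist DESC in the generated SQL.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy data standing in for (document, embedding) rows in the table.
rows = [
    ("doc about profiles", [1.0, 0.0]),
    ("doc about docker", [0.0, 1.0]),
    ("doc about both", [1.0, 1.0]),
]
best = top_k([1.0, 0.1], rows, k=2)
print([doc for doc, _ in best])
```

Every row is scored on each query, so the cost grows linearly with the table size, which matches the caveat in the `StarRocks` class docstring about iterating over all vectors.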
|
@ -0,0 +1,458 @@
|
||||
"""Wrapper around open source StarRocks VectorSearch capability."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from hashlib import sha1
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseSettings
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
logger = logging.getLogger()
|
||||
DEBUG = False
|
||||
|
||||
|
||||
def has_mul_sub_str(s: str, *args: Any) -> bool:
|
||||
for a in args:
|
||||
if a not in s:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def debug_output(s: Any) -> None:
|
||||
if DEBUG:
|
||||
print(s)
|
||||
|
||||
|
||||
def get_named_result(connection: Any, query: str) -> List[Dict[str, Any]]:
|
||||
cursor = connection.cursor()
|
||||
cursor.execute(query)
|
||||
columns = cursor.description
|
||||
result = []
|
||||
for value in cursor.fetchall():
|
||||
r = {}
|
||||
for idx, datum in enumerate(value):
|
||||
k = columns[idx][0]
|
||||
r[k] = datum
|
||||
result.append(r)
|
||||
debug_output(result)
|
||||
cursor.close()
|
||||
return result
|
||||
|
||||
|
||||
class StarRocksSettings(BaseSettings):
|
||||
"""StarRocks Client Configuration
|
||||
|
||||
Attributes:
host (str) : Host of the StarRocks server to connect to.
Defaults to 'localhost'.
port (int) : Port for the MySQL-protocol connection. Defaults to 9030.
username (str) : Username to login. Defaults to 'root'.
password (str) : Password to login. Defaults to ''.
database (str) : Database name to find the table. Defaults to 'default'.
table (str) : Table name to operate on.
Defaults to 'langchain'.

column_map (Dict) : Column map to project column names onto langchain
semantics. Must have keys: `id`, `document`, `embedding`,
`metadata`, one entry per table column. For example:
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'id': 'text_id',
|
||||
'embedding': 'text_embedding',
|
||||
'document': 'text_plain',
|
||||
'metadata': 'metadata_dictionary_in_json',
|
||||
}
|
||||
|
||||
Defaults to identity map.
|
||||
"""
|
||||
|
||||
host: str = "localhost"
|
||||
port: int = 9030
|
||||
username: str = "root"
|
||||
password: str = ""
|
||||
|
||||
column_map: Dict[str, str] = {
|
||||
"id": "id",
|
||||
"document": "document",
|
||||
"embedding": "embedding",
|
||||
"metadata": "metadata",
|
||||
}
|
||||
|
||||
database: str = "default"
|
||||
table: str = "langchain"
|
||||
|
||||
def __getitem__(self, item: str) -> Any:
|
||||
return getattr(self, item)
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_prefix = "starrocks_"
|
||||
env_file_encoding = "utf-8"
|
||||
|
||||
|
||||
class StarRocks(VectorStore):
|
||||
"""Wrapper around StarRocks vector database
|
||||
|
||||
You need a `pymysql` python package, and a valid account
|
||||
to connect to StarRocks.
|
||||
|
||||
Right now StarRocks has only implemented the `cosine_similarity` function to
compute the distance between two vectors. There is no vector index right now,
so we have to iterate over all vectors and compute the distance to each one.
|
||||
|
||||
For more information, please visit
|
||||
[StarRocks official site](https://www.starrocks.io/)
|
||||
[StarRocks github](https://github.com/StarRocks/starrocks)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding: Embeddings,
|
||||
config: Optional[StarRocksSettings] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""StarRocks Wrapper to LangChain
|
||||
|
||||
embedding (Embeddings): Embedding model used for texts and queries
config (StarRocksSettings): Configuration to StarRocks Client
|
||||
"""
|
||||
try:
|
||||
import pymysql # type: ignore[import]
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import pymysql python package. "
|
||||
"Please install it with `pip install pymysql`."
|
||||
)
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
self.pgbar = tqdm
|
||||
except ImportError:
|
||||
# Just in case if tqdm is not installed
|
||||
self.pgbar = lambda x, **kwargs: x
|
||||
super().__init__()
|
||||
if config is not None:
|
||||
self.config = config
|
||||
else:
|
||||
self.config = StarRocksSettings()
|
||||
assert self.config
|
||||
assert self.config.host and self.config.port
|
||||
assert self.config.column_map and self.config.database and self.config.table
|
||||
for k in ["id", "embedding", "document", "metadata"]:
|
||||
assert k in self.config.column_map
|
||||
|
||||
# initialize the schema
|
||||
dim = len(embedding.embed_query("test"))
|
||||
|
||||
self.schema = f"""\
|
||||
CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
|
||||
{self.config.column_map['id']} string,
|
||||
{self.config.column_map['document']} string,
|
||||
{self.config.column_map['embedding']} array<float>,
|
||||
{self.config.column_map['metadata']} string
|
||||
) ENGINE = OLAP PRIMARY KEY({self.config.column_map['id']}) DISTRIBUTED BY HASH({self.config.column_map['id']}) \
|
||||
PROPERTIES ("replication_num" = "1")\
|
||||
"""
|
||||
self.dim = dim
|
||||
self.BS = "\\"
|
||||
self.must_escape = ("\\", "'")
|
||||
self.embedding_function = embedding
|
||||
self.dist_order = "DESC"
|
||||
debug_output(self.config)
|
||||
|
||||
# Create a connection to StarRocks
|
||||
self.connection = pymysql.connect(
|
||||
host=self.config.host,
|
||||
port=self.config.port,
|
||||
user=self.config.username,
|
||||
password=self.config.password,
|
||||
database=self.config.database,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
debug_output(self.schema)
|
||||
get_named_result(self.connection, self.schema)
|
||||
|
||||
def escape_str(self, value: str) -> str:
|
||||
return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)
|
||||
|
||||
def _build_insert_sql(self, transac: Iterable, column_names: Iterable[str]) -> str:
|
||||
ks = ",".join(column_names)
|
||||
embed_tuple_index = tuple(column_names).index(
|
||||
self.config.column_map["embedding"]
|
||||
)
|
||||
_data = []
|
||||
for n in transac:
|
||||
n = ",".join(
|
||||
[
|
||||
f"'{self.escape_str(str(_n))}'"
|
||||
if idx != embed_tuple_index
|
||||
else f"array<float>{str(_n)}"
|
||||
for (idx, _n) in enumerate(n)
|
||||
]
|
||||
)
|
||||
_data.append(f"({n})")
|
||||
i_str = f"""
|
||||
INSERT INTO
|
||||
{self.config.database}.{self.config.table}({ks})
|
||||
VALUES
|
||||
{','.join(_data)}
|
||||
"""
|
||||
return i_str
|
||||
|
||||
def _insert(self, transac: Iterable, column_names: Iterable[str]) -> None:
|
||||
_insert_query = self._build_insert_sql(transac, column_names)
|
||||
debug_output(_insert_query)
|
||||
get_named_result(self.connection, _insert_query)
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
batch_size: int = 32,
|
||||
ids: Optional[Iterable[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Insert more texts through the embeddings and add to the VectorStore.
|
||||
|
||||
Args:
|
||||
texts: Iterable of strings to add to the VectorStore.
|
||||
ids: Optional list of ids to associate with the texts.
|
||||
batch_size: Batch size of insertion
|
||||
metadatas: Optional list of metadata dicts, one per text
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the VectorStore.
|
||||
|
||||
"""
|
||||
# Embed and create the documents
|
||||
ids = ids or [sha1(t.encode("utf-8")).hexdigest() for t in texts]
|
||||
colmap_ = self.config.column_map
|
||||
transac = []
|
||||
column_names = {
|
||||
colmap_["id"]: ids,
|
||||
colmap_["document"]: texts,
|
||||
colmap_["embedding"]: self.embedding_function.embed_documents(list(texts)),
|
||||
}
|
||||
metadatas = metadatas or [{} for _ in texts]
|
||||
column_names[colmap_["metadata"]] = map(json.dumps, metadatas)
|
||||
assert len(set(colmap_) - set(column_names)) >= 0
|
||||
keys, values = zip(*column_names.items())
|
||||
try:
|
||||
t = None
|
||||
for v in self.pgbar(
|
||||
zip(*values), desc="Inserting data...", total=len(metadatas)
|
||||
):
|
||||
assert (
|
||||
len(v[keys.index(self.config.column_map["embedding"])]) == self.dim
|
||||
)
|
||||
transac.append(v)
|
||||
if len(transac) == batch_size:
|
||||
if t:
|
||||
t.join()
|
||||
t = Thread(target=self._insert, args=[transac, keys])
|
||||
t.start()
|
||||
transac = []
|
||||
if len(transac) > 0:
|
||||
if t:
|
||||
t.join()
|
||||
self._insert(transac, keys)
|
||||
return [i for i in ids]
|
||||
except Exception as e:
|
||||
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[Dict[Any, Any]]] = None,
|
||||
config: Optional[StarRocksSettings] = None,
|
||||
text_ids: Optional[Iterable[str]] = None,
|
||||
batch_size: int = 32,
|
||||
**kwargs: Any,
|
||||
) -> StarRocks:
|
||||
"""Create StarRocks wrapper with existing texts
|
||||
|
||||
Args:
|
||||
embedding (Embeddings): Function to extract text embedding
texts (Iterable[str]): List or tuple of strings to be added
config (StarRocksSettings, Optional): StarRocks configuration
text_ids (Optional[Iterable], optional): IDs for the texts.
Defaults to None.
batch_size (int, optional): Batch size when transmitting data to StarRocks.
Defaults to 32.
metadatas (List[dict], optional): metadata to texts. Defaults to None.
|
||||
Returns:
|
||||
StarRocks Index
|
||||
"""
|
||||
ctx = cls(embedding, config, **kwargs)
|
||||
ctx.add_texts(texts, ids=text_ids, batch_size=batch_size, metadatas=metadatas)
|
||||
return ctx
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Text representation for StarRocks Vector Store, prints backends, username
|
||||
and schemas. Easy to use with `str(StarRocks())`
|
||||
|
||||
Returns:
|
||||
repr: string to show connection info and data schema
|
||||
"""
|
||||
_repr = f"\033[92m\033[1m{self.config.database}.{self.config.table} @ "
|
||||
_repr += f"{self.config.host}:{self.config.port}\033[0m\n\n"
|
||||
_repr += f"\033[1musername: {self.config.username}\033[0m\n\nTable Schema:\n"
|
||||
width = 25
|
||||
fields = 3
|
||||
_repr += "-" * (width * fields + 1) + "\n"
|
||||
columns = ["name", "type", "key"]
|
||||
_repr += f"|\033[94m{columns[0]:24s}\033[0m|\033[96m{columns[1]:24s}"
|
||||
_repr += f"\033[0m|\033[96m{columns[2]:24s}\033[0m|\n"
|
||||
_repr += "-" * (width * fields + 1) + "\n"
|
||||
q_str = f"DESC {self.config.database}.{self.config.table}"
|
||||
debug_output(q_str)
|
||||
rs = get_named_result(self.connection, q_str)
|
||||
for r in rs:
|
||||
_repr += f"|\033[94m{r['Field']:24s}\033[0m|\033[96m{r['Type']:24s}"
|
||||
_repr += f"\033[0m|\033[96m{r['Key']:24s}\033[0m|\n"
|
||||
_repr += "-" * (width * fields + 1) + "\n"
|
||||
return _repr
|
||||
|
||||
def _build_query_sql(
|
||||
self, q_emb: List[float], topk: int, where_str: Optional[str] = None
|
||||
) -> str:
|
||||
q_emb_str = ",".join(map(str, q_emb))
|
||||
if where_str:
|
||||
where_str = f"WHERE {where_str}"
|
||||
else:
|
||||
where_str = ""
|
||||
|
||||
q_str = f"""
|
||||
SELECT {self.config.column_map['document']},
|
||||
{self.config.column_map['metadata']},
|
||||
cosine_similarity_norm(array<float>[{q_emb_str}],
|
||||
{self.config.column_map['embedding']}) as dist
|
||||
FROM {self.config.database}.{self.config.table}
|
||||
{where_str}
|
||||
ORDER BY dist {self.dist_order}
|
||||
LIMIT {topk}
|
||||
"""
|
||||
|
||||
debug_output(q_str)
|
||||
return q_str
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Perform a similarity search with StarRocks
|
||||
|
||||
Args:
|
||||
query (str): query string
|
||||
k (int, optional): Top K neighbors to retrieve. Defaults to 4.
|
||||
where_str (Optional[str], optional): where condition string.
|
||||
Defaults to None.
|
||||
|
||||
NOTE: Do not let end users fill this in, and always be aware
of SQL injection. When dealing with metadata, remember to
use `{self.metadata_column}.attribute` instead of `attribute`
alone. The default name for it is `metadata`.
|
||||
|
||||
Returns:
|
||||
List[Document]: List of Documents
|
||||
"""
|
||||
return self.similarity_search_by_vector(
|
||||
self.embedding_function.embed_query(query), k, where_str, **kwargs
|
||||
)
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
where_str: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Perform a similarity search with StarRocks by vectors
|
||||
|
||||
Args:
|
||||
embedding (List[float]): query embedding vector
k (int, optional): Top K neighbors to retrieve. Defaults to 4.
where_str (Optional[str], optional): where condition string.
Defaults to None.

NOTE: Do not let end users fill this in, and always be aware
of SQL injection. When dealing with metadata, remember to
use `{self.metadata_column}.attribute` instead of `attribute`
alone. The default name for it is `metadata`.

Returns:
List[Document]: List of Documents
|
||||
"""
|
||||
q_str = self._build_query_sql(embedding, k, where_str)
|
||||
try:
|
||||
return [
|
||||
Document(
|
||||
page_content=r[self.config.column_map["document"]],
|
||||
metadata=json.loads(r[self.config.column_map["metadata"]]),
|
||||
)
|
||||
for r in get_named_result(self.connection, q_str)
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
|
||||
return []
|
||||
|
||||
def similarity_search_with_relevance_scores(
|
||||
self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Perform a similarity search with StarRocks
|
||||
|
||||
Args:
|
||||
query (str): query string
|
||||
k (int, optional): Top K neighbors to retrieve. Defaults to 4.
|
||||
where_str (Optional[str], optional): where condition string.
|
||||
Defaults to None.
|
||||
|
||||
NOTE: Do not let end users fill this in, and always be aware
of SQL injection. When dealing with metadata, remember to
use `{self.metadata_column}.attribute` instead of `attribute`
alone. The default name for it is `metadata`.

Returns:
List[Tuple[Document, float]]: List of (Document, similarity)
|
||||
"""
|
||||
q_str = self._build_query_sql(
|
||||
self.embedding_function.embed_query(query), k, where_str
|
||||
)
|
||||
try:
|
||||
return [
|
||||
(
|
||||
Document(
|
||||
page_content=r[self.config.column_map["document"]],
|
||||
metadata=json.loads(r[self.config.column_map["metadata"]]),
|
||||
),
|
||||
r["dist"],
|
||||
)
|
||||
for r in get_named_result(self.connection, q_str)
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
|
||||
return []
|
||||
|
||||
def drop(self) -> None:
|
||||
"""
|
||||
Helper function: Drop data
|
||||
"""
|
||||
get_named_result(
|
||||
self.connection,
|
||||
f"DROP TABLE IF EXISTS {self.config.database}.{self.config.table}",
|
||||
)
|
||||
|
||||
@property
|
||||
def metadata_column(self) -> str:
|
||||
return self.config.column_map["metadata"]
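To make the insert path in `starrocks.py` easier to review, here is a simplified, standalone sketch of what `escape_str` and `_build_insert_sql` do to one row: string cells are quoted with backslashes and single quotes escaped, and the embedding cell is rendered as an `array<float>[...]` literal. The free functions `escape_str` and `build_values` and the sample row are hypothetical stand-ins, not the PR's methods, and the sketch only covers the `VALUES` part of the statement.

```python
from typing import Iterable, List, Sequence

# Characters that get a backslash prefix, as in the PR's must_escape tuple.
MUST_ESCAPE = ("\\", "'")


def escape_str(value: str) -> str:
    """Prefix backslashes and single quotes with a backslash."""
    return "".join("\\" + c if c in MUST_ESCAPE else c for c in value)


def build_values(rows: Iterable[Sequence], embed_index: int) -> str:
    """Render rows as a SQL VALUES list; the embedding becomes an array literal."""
    rendered: List[str] = []
    for row in rows:
        cells = [
            f"array<float>{list(cell)}"
            if idx == embed_index
            else f"'{escape_str(str(cell))}'"
            for idx, cell in enumerate(row)
        ]
        rendered.append("(" + ",".join(cells) + ")")
    return ",".join(rendered)


# One hypothetical row: id, document, embedding, metadata (as JSON text).
sql_values = build_values([("id-1", "it's a test", [0.5, 0.25], "{}")], embed_index=2)
print(sql_values)
```

Hand-rolled escaping like this is easy to get wrong, which is presumably why the docstrings warn about SQL injection for `where_str`; driver-level parameter binding would be the more conservative choice where the statement shape allows it.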
|