mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
3b52ee05d1
fix small bugs in vectorstore/baiduvectordb
438 lines
15 KiB
Python
438 lines
15 KiB
Python
"""Wrapper around the Baidu vector database."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import time
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
from langchain_core.documents import Document
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.utils import guard_import
|
|
from langchain_core.vectorstores import VectorStore
|
|
|
|
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ConnectionParams:
    """Connection settings for a Baidu VectorDB server.

    Reference documentation:
    https://cloud.baidu.com/doc/VDB/s/6lrsob0wy

    Attributes:
        endpoint (str): Access address of the vector database server that
            the client connects to.
        api_key (str): API key used to authenticate the client against the
            vector database server.
        account (str): Account the client uses to access the server.
            Defaults to ``"root"``.
        connection_timeout_in_mills (int): Request timeout. Defaults to
            ``50000`` (the name suggests the unit is milliseconds).
    """

    def __init__(
        self,
        endpoint: str,
        api_key: str,
        account: str = "root",
        connection_timeout_in_mills: int = 50_000,
    ) -> None:
        """Store the given connection settings unchanged on the instance."""
        # Where to connect.
        self.endpoint: str = endpoint
        # Who is connecting.
        self.api_key: str = api_key
        self.account: str = account
        # How long a request may take.
        self.connection_timeout_in_mills: int = connection_timeout_in_mills
|
|
|
|
|
|
class TableParams:
    """Table-creation settings for Baidu VectorDB.

    Reference documentation:
    https://cloud.baidu.com/doc/VDB/s/mlrsob0p6

    Attributes:
        dimension (int): Dimension of the stored vectors.
        replication (int): Replication setting of the table. Defaults to 3.
        partition (int): Partition setting of the table. Defaults to 1.
        index_type (str): Name of the vector index type. Defaults to
            ``"HNSW"``.
        metric_type (str): Name of the distance metric. Defaults to ``"L2"``.
        params (Optional[Dict]): Extra index parameters, or ``None`` to use
            the defaults chosen by the consumer of this object.
    """

    def __init__(
        self,
        dimension: int,
        replication: int = 3,
        partition: int = 1,
        index_type: str = "HNSW",
        metric_type: str = "L2",
        params: Optional[Dict] = None,
    ) -> None:
        """Store the given table settings unchanged on the instance."""
        # Vector shape and physical layout.
        self.dimension: int = dimension
        self.replication: int = replication
        self.partition: int = partition
        # Index configuration.
        self.index_type: str = index_type
        self.metric_type: str = metric_type
        self.params: Optional[Dict] = params
|
|
|
|
|
|
class BaiduVectorDB(VectorStore):
    """Baidu VectorDB as a vector store.

    In order to use this you need to have a database instance.
    See the following documentation for details:
    https://cloud.baidu.com/doc/VDB/index.html
    """

    # Column names of the backing table.
    field_id: str = "id"
    field_vector: str = "vector"
    field_text: str = "text"
    field_metadata: str = "metadata"

    # Name of the vector index built on ``field_vector``.
    index_vector: str = "vector_idx"

    def __init__(
        self,
        embedding: Embeddings,
        connection_params: ConnectionParams,
        table_params: Optional[TableParams] = None,
        database_name: str = "LangChainDatabase",
        table_name: str = "LangChainTable",
        drop_old: Optional[bool] = False,
    ):
        """Connect to Baidu VectorDB and make sure the database/table exist.

        Args:
            embedding: Embedding model used for documents and queries.
            connection_params: Endpoint and credentials of the server.
            table_params: Settings used if the table has to be created.
                Defaults to a fresh ``TableParams(128)``.
            database_name: Database to use; created when it does not exist.
            table_name: Table to use; created when it does not exist.
            drop_old: When truthy, an existing table is dropped and
                re-created.
        """
        pymochow = guard_import("pymochow")
        configuration = guard_import("pymochow.configuration")
        auth = guard_import("pymochow.auth.bce_credentials")
        self.mochowtable = guard_import("pymochow.model.table")
        self.mochowenum = guard_import("pymochow.model.enum")
        self.embedding_func = embedding
        # Build the default per instance: a default evaluated in the signature
        # would be one shared, mutable object for every BaiduVectorDB.
        if table_params is None:
            table_params = TableParams(128)
        self.table_params = table_params

        config = configuration.Configuration(
            credentials=auth.BceCredentials(
                connection_params.account, connection_params.api_key
            ),
            endpoint=connection_params.endpoint,
            connection_timeout_in_mills=connection_params.connection_timeout_in_mills,
        )
        self.vdb_client = pymochow.MochowClient(config)

        db_exist = any(
            database_name == db.database_name
            for db in self.vdb_client.list_databases()
        )
        if db_exist:
            self.database = self.vdb_client.database(database_name)
        else:
            self.database = self.vdb_client.create_database(database_name)

        try:
            self.table = self.database.describe_table(table_name)
            if drop_old:
                self.database.drop_table(table_name)
                self._create_table(table_name)
        except pymochow.exception.ServerError:
            # A server error is treated as "table is not there yet".
            # NOTE(review): this also retries creation when the drop or the
            # first create attempt above raised ServerError — confirm intended.
            self._create_table(table_name)

    def _create_table(self, table_name: str) -> None:
        """Create ``table_name`` with the id/vector/text/metadata schema.

        Blocks, polling once per second, until the server reports the table
        state as ``NORMAL``.

        Raises:
            ValueError: If ``table_params.index_type`` or
                ``table_params.metric_type`` is not a member name of the
                corresponding pymochow enum.
        """
        schema = guard_import("pymochow.model.schema")
        field_type = self.mochowenum.FieldType

        # Resolve the configured names to pymochow enum members.
        index_type = self.mochowenum.IndexType.__members__.get(
            self.table_params.index_type
        )
        if index_type is None:
            raise ValueError("unsupported index_type")
        metric_type = self.mochowenum.MetricType.__members__.get(
            self.table_params.metric_type
        )
        if metric_type is None:
            raise ValueError("unsupported metric_type")

        # HNSW parameters are always used, whatever index_type was resolved.
        index_options = self.table_params.params or {}
        params = schema.HNSWParams(
            m=index_options.get("M", 16),
            efconstruction=index_options.get("efConstruction", 200),
        )

        fields = [
            schema.Field(
                self.field_id,
                field_type.STRING,
                primary_key=True,
                partition_key=True,
                auto_increment=False,
                not_null=True,
            ),
            schema.Field(
                self.field_vector,
                field_type.FLOAT_VECTOR,
                dimension=self.table_params.dimension,
                not_null=True,
            ),
            schema.Field(self.field_text, field_type.STRING),
            schema.Field(self.field_metadata, field_type.STRING),
        ]
        indexes = [
            schema.VectorIndex(
                index_name=self.index_vector,
                index_type=index_type,
                field=self.field_vector,
                metric_type=metric_type,
                params=params,
            )
        ]

        self.table = self.database.create_table(
            table_name=table_name,
            replication=self.table_params.replication,
            partition=self.mochowtable.Partition(
                partition_num=self.table_params.partition
            ),
            schema=schema.Schema(fields=fields, indexes=indexes),
        )

        # Wait until the new table is usable. There is no timeout here.
        while True:
            time.sleep(1)
            table = self.database.describe_table(table_name)
            if table.state == self.mochowenum.TableState.NORMAL:
                break

    @property
    def embeddings(self) -> Embeddings:
        """Return the embedding model this store was created with."""
        return self.embedding_func

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        connection_params: Optional[ConnectionParams] = None,
        table_params: Optional[TableParams] = None,
        database_name: str = "LangChainDatabase",
        table_name: str = "LangChainTable",
        drop_old: Optional[bool] = False,
        **kwargs: Any,
    ) -> BaiduVectorDB:
        """Create a table, indexes it with HNSW, and insert data.

        The vector dimension is taken from the embedding of the first text
        and overrides any ``dimension`` set on ``table_params``.

        Raises:
            ValueError: If ``texts`` is empty or ``connection_params`` is
                missing.
        """
        import copy

        if len(texts) == 0:
            raise ValueError("texts is empty")
        if connection_params is None:
            raise ValueError("connection_params is empty")

        # Embed one sample to learn the vector dimension.
        try:
            embeddings = embedding.embed_documents(texts[0:1])
        except NotImplementedError:
            embeddings = [embedding.embed_query(texts[0])]
        dimension = len(embeddings[0])

        if table_params is None:
            table_params = TableParams(dimension=dimension)
        else:
            # Work on a shallow copy so the caller's object is not modified.
            table_params = copy.copy(table_params)
            table_params.dimension = dimension

        vector_db = cls(
            embedding=embedding,
            connection_params=connection_params,
            table_params=table_params,
            database_name=database_name,
            table_name=table_name,
            drop_old=drop_old,
        )
        vector_db.add_texts(texts=texts, metadatas=metadatas)
        return vector_db

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        batch_size: int = 1000,
        **kwargs: Any,
    ) -> List[str]:
        """Insert text data into Baidu VectorDB.

        Args:
            texts: Texts to embed and store.
            metadatas: Optional per-text metadata, stored as JSON strings.
                Must have an entry for every text when given.
            batch_size: Maximum number of rows sent per upsert call.

        Returns:
            The primary keys of the inserted rows, in input order.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")
        texts = list(texts)
        if not texts:
            # Nothing to embed, so do not call the embedding model at all.
            logger.debug("Nothing to insert, skipping.")
            return []
        try:
            embeddings = self.embedding_func.embed_documents(texts)
        except NotImplementedError:
            embeddings = [self.embedding_func.embed_query(x) for x in texts]
        if len(embeddings) == 0:
            logger.debug("Nothing to insert, skipping.")
            return []

        pks: List[str] = []
        total_count = len(embeddings)
        for start in range(0, total_count, batch_size):
            end = min(start + batch_size, total_count)
            rows = []
            for idx in range(start, end):
                metadata = "{}"
                if metadatas is not None:
                    metadata = json.dumps(metadatas[idx])
                pk = "{}-{}-{}".format(time.time_ns(), hash(texts[idx]), idx)
                rows.append(
                    self.mochowtable.Row(
                        id=pk,
                        vector=[float(num) for num in embeddings[idx]],
                        text=texts[idx],
                        metadata=metadata,
                    )
                )
                # Report the key that was actually written, not the position.
                pks.append(pk)
            self.table.upsert(rows=rows)

        # need rebuild vindex after upsert
        self.table.rebuild_index(self.index_vector)
        # Wait until the index is usable again. There is no timeout here.
        while True:
            time.sleep(2)
            index = self.table.describe_index(self.index_vector)
            if index.state == self.mochowenum.IndexState.NORMAL:
                break
        return pks

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search against the query string.

        Args:
            query: Text to embed and search for.
            k: Number of results requested.
            param: Optional search parameters; only ``"ef"`` is read.
            expr: Optional filter expression passed to the search.

        Returns:
            The matching documents, without scores.
        """
        res = self.similarity_search_with_score(
            query=query, k=k, param=param, expr=expr, **kwargs
        )
        return [doc for doc, _ in res]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a search on a query string and return results with score."""
        # Embed the query text.
        embedding = self.embedding_func.embed_query(query)
        return self._similarity_search_with_score(
            embedding=embedding, k=k, param=param, expr=expr, **kwargs
        )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search against an already-computed embedding."""
        res = self._similarity_search_with_score(
            embedding=embedding, k=k, param=param, expr=expr, **kwargs
        )
        return [doc for doc, _ in res]

    def _row_to_document(self, row_data: Dict[str, Any]) -> Document:
        """Build a Document from the ``row`` payload of one search result."""
        raw_meta = row_data.get(self.field_metadata)
        # Metadata is stored as a JSON string; a missing or empty value
        # becomes an empty dict instead of None.
        meta = json.loads(raw_meta) if raw_meta else {}
        return Document(
            page_content=row_data.get(self.field_text), metadata=meta or {}
        )

    def _similarity_search_with_score(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Search by embedding vector and return ``(document, score)`` pairs.

        Args:
            embedding: Query vector.
            k: Result limit passed to the search.
            param: Optional search parameters; ``"ef"`` defaults to 10.
            expr: Optional filter expression passed to the search.
        """
        ef = 10 if param is None else param.get("ef", 10)

        anns = self.mochowtable.AnnSearch(
            vector_field=self.field_vector,
            vector_floats=[float(num) for num in embedding],
            params=self.mochowtable.HNSWSearchParams(ef=ef, limit=k),
            filter=expr,
        )
        res = self.table.search(anns=anns)

        # Organize results.
        ret: List[Tuple[Document, float]] = []
        for result in res.rows or []:
            doc = self._row_to_document(result.get("row", {}))
            ret.append((doc, result.get("score", 0.0)))
        return ret

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a search and return results that are reordered by MMR.

        Args:
            query: Text to embed and search for.
            k: Number of documents to return.
            fetch_k: Number of candidates fetched before MMR re-ranking.
            lambda_mult: Value forwarded to ``maximal_marginal_relevance``.
            param: Optional search parameters; only ``"ef"`` is read.
            expr: Optional filter expression passed to the search.
        """
        embedding = self.embedding_func.embed_query(query)
        return self._max_marginal_relevance_search(
            embedding=embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            param=param,
            expr=expr,
            **kwargs,
        )

    def _max_marginal_relevance_search(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Fetch ``fetch_k`` candidates by vector and pick ``k`` using MMR."""
        ef = None if param is None else param.get("ef")
        if ef is None:
            # NOTE(review): default raised so it is not below the candidate
            # limit; assumes ef < limit is undesirable for HNSW — confirm
            # against the Baidu VectorDB documentation.
            ef = max(10, fetch_k)

        anns = self.mochowtable.AnnSearch(
            vector_field=self.field_vector,
            vector_floats=[float(num) for num in embedding],
            # Fetch the candidate pool (fetch_k), not just the final k,
            # otherwise MMR has nothing to choose between.
            params=self.mochowtable.HNSWSearchParams(ef=ef, limit=fetch_k),
            filter=expr,
        )
        res = self.table.search(anns=anns, retrieve_vector=True)

        # Organize results.
        documents: List[Document] = []
        ordered_result_embeddings = []
        for result in res.rows or []:
            row_data = result.get("row", {})
            documents.append(self._row_to_document(row_data))
            ordered_result_embeddings.append(row_data.get(self.field_vector))
        if not documents:
            return documents

        # Get the new order of results.
        new_ordering = maximal_marginal_relevance(
            np.array(embedding),
            ordered_result_embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )

        # Reorder the values and return.
        ret = []
        for x in new_ordering:
            # Function can return -1 index
            if x == -1:
                break
            ret.append(documents[x])
        return ret
|