You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/langchain_community/vectorstores/hippo.py

678 lines
26 KiB
Python

from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
if TYPE_CHECKING:
from transwarp_hippo_api.hippo_client import HippoClient
# Connection settings used when a Hippo store is created without explicit
# ``connection_args``: a local server on port 7788 with the admin account.
DEFAULT_HIPPO_CONNECTION = dict(
    host="localhost",
    port="7788",
    username="admin",
    password="admin",
)

# Module-level logger shared by every method of the vector store.
logger = logging.getLogger(__name__)
class Hippo(VectorStore):
    """`Hippo` vector store.

    You need to install `hippo-api` and run Hippo.

    Please visit our official website for how to run a Hippo instance:
    https://www.transwarp.cn/starwarp

    Args:
        embedding_function (Embeddings): Function used to embed the text.
        table_name (str): Which Hippo table to use. Defaults to "test".
        database_name (str): Which Hippo database to use. Defaults to "default".
        number_of_shards (int): The number of shards for the Hippo table.
            Defaults to 1.
        number_of_replicas (int): The number of replicas for the Hippo table.
            Defaults to 1.
        connection_args (Optional[dict[str, any]]): The connection args used for
            this class, in the form of a dict. Defaults to
            ``DEFAULT_HIPPO_CONNECTION`` (localhost:7788, admin/admin).
        index_params (Optional[dict]): Which index params to use. Defaults to
            an IVF_FLAT index with the L2 metric.
        drop_old (Optional[bool]): Whether to drop the current table.
            Defaults to False.

    The field names are fixed: the primary key field is "pk", the text field
    is "text" and the vector field is "vector".

    The connection args dict supports these keys:
        host (str): The host of the Hippo instance. Several hosts may be given
            as a comma-separated string; each uses the same port. Required.
        port (str/int): The port of the Hippo instance. Required.
        username (str): User used to connect to the Hippo instance (``user``
            is accepted as an alias). Defaults to "shiva" when omitted.
        password (str): The password corresponding to the user. Defaults to
            "shiva" when omitted.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Hippo
            from langchain_community.embeddings import OpenAIEmbeddings

            embedding = OpenAIEmbeddings()
            # Connect to a hippo instance on localhost
            vector_store = Hippo.from_documents(
                docs,
                embedding=embedding,
                table_name="langchain_test",
                connection_args={"host": "localhost", "port": 7788},
            )

    Raises:
        ImportError: If the hippo-api python package is not installed.
    """

    def __init__(
        self,
        embedding_function: Embeddings,
        table_name: str = "test",
        database_name: str = "default",
        number_of_shards: int = 1,
        number_of_replicas: int = 1,
        connection_args: Optional[Dict[str, Any]] = None,
        index_params: Optional[dict] = None,
        drop_old: Optional[bool] = False,
    ):
        """Connect to Hippo and bind to (or prepare) the configured table.

        If the table exists it is loaded (after being dropped first when
        ``drop_old`` is set); otherwise it is created lazily by the first
        ``add_texts`` call, once the embedding dimension is known.
        """
        self.number_of_shards = number_of_shards
        self.number_of_replicas = number_of_replicas
        self.embedding_func = embedding_function
        self.table_name = table_name
        self.database_name = database_name
        self.index_params = index_params

        # For a table to be compatible, 'pk' must be an auto-increment
        # string primary key.
        self._primary_field = "pk"
        # For compatibility, the text field must be called "text".
        self._text_field = "text"
        # For compatibility, the vector field must be called "vector".
        self._vector_field = "vector"

        # Column names of the bound table, filled in by _extract_fields().
        self.fields: List[str] = []

        # Create the connection to the server.
        if connection_args is None:
            connection_args = DEFAULT_HIPPO_CONNECTION
        self.hc = self._create_connection_alias(connection_args)
        self.col: Any = None

        # Drop the existing table when requested.
        try:
            if drop_old and self.hc.check_table_exists(
                self.table_name, self.database_name
            ):
                self.hc.delete_table(self.table_name, self.database_name)
        except Exception as e:
            logger.error(
                f"An error occurred while deleting the table " f"{self.table_name}: {e}"
            )
            raise

        # Bind to the table if it (still) exists.
        try:
            if self.hc.check_table_exists(self.table_name, self.database_name):
                self.col = self.hc.get_table(self.table_name, self.database_name)
        except Exception as e:
            logger.error(
                f"An error occurred while getting the table " f"{self.table_name}: {e}"
            )
            raise

        # Initialize fields / index for an existing table.
        self._get_env()

    def _create_connection_alias(self, connection_args: dict) -> HippoClient:
        """Create the connection to the Hippo server.

        Args:
            connection_args: Dict with "host" and "port" (both required) and
                optionally "username" (or its alias "user") and "password".

        Returns:
            A ``HippoClient`` for the given address(es).

        Raises:
            ImportError: If ``hippo-api`` is not installed.
            ValueError: If "host" or "port" is missing.
        """
        try:
            from transwarp_hippo_api.hippo_client import HippoClient
        except ImportError as e:
            raise ImportError(
                "Unable to import transwarp_hippo_api, please install with "
                "`pip install hippo-api`."
            ) from e

        host = connection_args.get("host", None)
        port = connection_args.get("port", None)
        # "user" is accepted as an alias for "username"; "username" wins when
        # both are present.
        username: str = connection_args.get(
            "username", connection_args.get("user", "shiva")
        )
        password: str = connection_args.get("password", "shiva")

        if host is None or port is None:
            raise ValueError(
                "Missing 'host' or 'port' in connection_args for Hippo connection"
            )

        # "h1,h2" expands to "h1:port,h2:port"; a single host gives "h:port".
        given_address = ",".join(f"{h}:{port}" for h in str(host).split(","))

        try:
            logger.info(f"create HippoClient[{given_address}]")
            return HippoClient([given_address], username=username, pwd=password)
        except Exception:
            logger.error("Failed to create new connection")
            raise

    def _get_env(
        self, embeddings: Optional[list] = None, metadatas: Optional[List[dict]] = None
    ) -> None:
        """Create the table (when embeddings are given), then load its
        fields and ensure the vector index exists."""
        logger.info("init ...")
        if embeddings is not None:
            logger.info("create collection")
            self._create_collection(embeddings, metadatas)
        self._extract_fields()
        self._create_index()

    def _create_collection(
        self, embeddings: list, metadatas: Optional[List[dict]] = None
    ) -> None:
        """Create the Hippo table and bind ``self.col`` to it.

        The vector dimension is taken from ``embeddings[0]``. Each key of
        ``metadatas[0]`` becomes an extra column: list values become
        FLOAT_VECTOR columns of that length, anything else a STRING column
        (Hippo has no type-inference helper, so non-vector metadata columns
        are declared as strings).
        """
        from transwarp_hippo_api.hippo_client import HippoField
        from transwarp_hippo_api.hippo_type import HippoType

        dim = len(embeddings[0])
        logger.debug(f"[_create_collection] dim: {dim}")

        fields = [
            # Primary key field.
            HippoField(self._primary_field, True, HippoType.STRING),
            # Text field.
            HippoField(self._text_field, False, HippoType.STRING),
            # Vector field; only float vectors are supported (binary vectors
            # are not implemented yet).
            HippoField(
                self._vector_field,
                False,
                HippoType.FLOAT_VECTOR,
                type_params={"dimension": dim},
            ),
        ]

        if metadatas:
            for key, value in metadatas[0].items():
                if isinstance(value, list):
                    fields.append(
                        HippoField(
                            key,
                            False,
                            HippoType.FLOAT_VECTOR,
                            type_params={"dimension": len(value)},
                        )
                    )
                else:
                    fields.append(HippoField(key, False, HippoType.STRING))

        logger.debug(f"[_create_collection] fields: {fields}")

        self.hc.create_table(
            name=self.table_name,
            auto_id=True,
            fields=fields,
            database_name=self.database_name,
            number_of_shards=self.number_of_shards,
            number_of_replicas=self.number_of_replicas,
        )
        self.col = self.hc.get_table(self.table_name, self.database_name)
        logger.info(
            f"[_create_collection] : "
            f"create table {self.table_name} in {self.database_name} successfully"
        )

    def _extract_fields(self) -> None:
        """Load the column names of the bound table into ``self.fields``."""
        from transwarp_hippo_api.hippo_client import HippoTable

        if isinstance(self.col, HippoTable):
            schema = self.col.schema
            logger.debug(f"[_extract_fields] schema:{schema}")
            # Rebuild instead of appending so repeated calls don't duplicate.
            self.fields = [x.name for x in schema]
            logger.debug(f"04 [_extract_fields] fields:{self.fields}")

    def _get_index(self) -> Optional[Dict[str, Any]]:
        """Return the index info for the vector field, or None if absent.

        Only the auto-created "vector" field is checked; indexes on other
        vector columns must be created manually.
        """
        from transwarp_hippo_api.hippo_client import HippoTable

        if not isinstance(self.col, HippoTable):
            return None

        table_info = self.hc.get_table_info(self.table_name, self.database_name).get(
            self.table_name, {}
        )
        embedding_indexes = table_info.get("embedding_indexes", None)
        if embedding_indexes is None:
            return None

        logger.debug(f"[_get_index] embedding_indexes {embedding_indexes}")
        for x in embedding_indexes:
            if x["column"] == self._vector_field:
                return x
        return None

    def _create_index(self) -> None:
        """Create and activate an index on the vector field if none exists.

        Without ``index_params`` an IVF_FLAT / L2 index named
        "langchain_auto_create" with ``nlist=10`` is created. Otherwise
        ``index_params`` must provide "metric_type" ("L2"/"l2"/"IP"/"ip" or a
        ``MetricType`` member) and "index_type" ("FLAT", "IVF_FLAT", "IVF_SQ",
        "IVF_PQ", "HNSW" or an ``IndexType`` member). It may also provide
        "index_name" (defaults to "langchain_auto_create") and type-specific
        options: nlist, nprobe (IVF_*), nbits, m (IVF_PQ), and M,
        ef_construction, ef_search (HNSW).

        Only indexes on ``self._vector_field`` are created here.

        Raises:
            KeyError: If "metric_type" is missing or not a known metric.
            ValueError: If "index_type" is not a supported index type.
        """
        from transwarp_hippo_api.hippo_client import HippoTable
        from transwarp_hippo_api.hippo_type import IndexType, MetricType

        # Nothing to do without a table, or when the index already exists.
        if not isinstance(self.col, HippoTable) or self._get_index() is not None:
            return

        if self.index_params is None:
            self.index_params = {
                "index_name": "langchain_auto_create",
                "metric_type": MetricType.L2,
                "index_type": IndexType.IVF_FLAT,
                "nlist": 10,
            }
            self.col.create_index(
                self._vector_field,
                self.index_params["index_name"],
                self.index_params["index_type"],
                self.index_params["metric_type"],
                nlist=self.index_params["nlist"],
            )
            logger.debug(self.col.activate_index(self.index_params["index_name"]))
            logger.info("create index successfully")
            return

        index_dict = {
            "IVF_FLAT": IndexType.IVF_FLAT,
            "FLAT": IndexType.FLAT,
            "IVF_SQ": IndexType.IVF_SQ,
            "IVF_PQ": IndexType.IVF_PQ,
            "HNSW": IndexType.HNSW,
        }
        metric_dict = {
            "ip": MetricType.IP,
            "IP": MetricType.IP,
            "l2": MetricType.L2,
            "L2": MetricType.L2,
        }

        # Work on a copy: converting the caller's dict in place would make it
        # unusable for another store (its values would already be enums).
        params = dict(self.index_params)

        metric = params["metric_type"]
        if metric not in metric_dict.values():
            # Unknown metric names raise KeyError, as before.
            metric = metric_dict[metric]
        params["metric_type"] = metric

        index_type = params["index_type"]
        type_name = next(
            (
                name
                for name, member in index_dict.items()
                if index_type in (name, member)
            ),
            None,
        )
        if type_name is None:
            raise ValueError(
                "Index type does not match, "
                "please enter the correct index type. "
                "(FLAT, IVF_FLAT, IVF_PQ, IVF_SQ, HNSW)"
            )
        params["index_type"] = index_dict[type_name]
        params.setdefault("index_name", "langchain_auto_create")

        # Type-specific keyword arguments forwarded to create_index().
        extra: Dict[str, Any]
        if type_name == "FLAT":
            extra = {}
        elif type_name in ("IVF_FLAT", "IVF_SQ"):
            extra = {
                "nlist": params.get("nlist", 10),
                "nprobe": params.get("nprobe", 10),
            }
        elif type_name == "IVF_PQ":
            extra = {
                "nlist": params.get("nlist", 10),
                "nprobe": params.get("nprobe", 10),
                "nbits": params.get("nbits", 8),
                "m": params.get("m"),
            }
        else:  # HNSW
            extra = {
                "M": params.get("M"),
                "ef_construction": params.get("ef_construction"),
                "ef_search": params.get("ef_search"),
            }

        self.col.create_index(
            self._vector_field,
            params["index_name"],
            params["index_type"],
            params["metric_type"],
            **extra,
        )
        logger.debug(self.col.activate_index(params["index_name"]))
        # Keep the resolved (enum-valued) params on the instance only.
        self.index_params = params
        logger.info("create index successfully")

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        timeout: Optional[int] = None,
        batch_size: int = 1000,
        **kwargs: Any,
    ) -> List[str]:
        """
        Add text to the collection.

        Args:
            texts: An iterable that contains the text to be added.
            metadatas: An optional list of dictionaries,
                each dictionary contains the metadata associated with a text.
                Only keys that are columns of the table are inserted.
            timeout: Optional timeout, in seconds (currently not forwarded).
            batch_size: The number of texts inserted in each batch, defaults to 1000.
            **kwargs: Other optional parameters (currently unused).

        Returns:
            ``[]`` when there is nothing to insert. Otherwise ``[""]``: Hippo
            auto-generates the primary keys and they are not surfaced here.

        Note:
            If the collection has not yet been created,
            this method will create a new collection.
        """
        from transwarp_hippo_api.hippo_client import HippoTable

        # Materialize first: ``texts`` may be a one-shot iterator, and the
        # emptiness check below would otherwise consume it.
        texts = list(texts)
        if not texts or all(t == "" for t in texts):
            logger.debug("Nothing to insert, skipping.")
            return []
        logger.debug(f"[add_texts] texts: {texts}")

        try:
            embeddings = self.embedding_func.embed_documents(texts)
        except NotImplementedError:
            embeddings = [self.embedding_func.embed_query(x) for x in texts]

        if len(embeddings) == 0:
            logger.debug("Nothing to insert, skipping.")
            return []

        logger.debug(f"[add_texts] len_embeddings:{len(embeddings)}")

        # Create the collection if it has not been created yet.
        if not isinstance(self.col, HippoTable):
            self._get_env(embeddings, metadatas)

        # Column name -> values to insert.
        insert_dict: Dict[str, list] = {
            self._text_field: texts,
            self._vector_field: embeddings,
        }

        logger.debug(f"[add_texts] metadatas:{metadatas}")
        logger.debug(f"[add_texts] fields:{self.fields}")
        # Metadata keys that collide with the text/vector columns are skipped;
        # appending them would corrupt those columns.
        reserved = {self._text_field, self._vector_field}
        if metadatas is not None:
            for d in metadatas:
                for key, value in d.items():
                    if key in self.fields and key not in reserved:
                        insert_dict.setdefault(key, []).append(value)

        logger.debug(insert_dict[self._text_field])

        vectors: list = insert_dict[self._vector_field]
        total_count = len(vectors)

        # The primary key is auto-generated (auto_id=True), so it is not sent.
        if self._primary_field in self.fields:
            self.fields.remove(self._primary_field)

        logger.debug(f"[add_texts] total_count:{total_count}")

        for start in range(0, total_count, batch_size):
            end = min(start + batch_size, total_count)
            # Column-major batch in the table's field order.
            insert_list = [insert_dict[x][start:end] for x in self.fields]
            try:
                res = self.col.insert_rows(insert_list)
                logger.info(f"05 [add_texts] insert {res}")
            except Exception:
                logger.error(
                    "Failed to insert batch starting at entity: %s/%s",
                    start,
                    total_count,
                )
                raise
        return [""]

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """
        Perform a similarity search on the query string.

        Args:
            query (str): The text to search for.
            k (int, optional): The number of results to return. Default is 4.
            param (dict, optional): Specifies the search parameters for the index.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (int, optional): Time to wait before a timeout error.
                Defaults to None.
            kwargs: Keyword arguments for Collection.search().

        Returns:
            List[Document]: The document results of the search.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []
        res = self.similarity_search_with_score(
            query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )
        return [doc for doc, _ in res]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """
        Performs a search on the query string and returns results with scores.

        Args:
            query (str): The text being searched.
            k (int, optional): The number of results to return.
                Default is 4.
            param (dict): Specifies the search parameters for the index.
                Default is None.
            expr (str, optional): Filtering expression. Default is None.
            timeout (int, optional): The waiting time before a timeout error.
                Default is None.
            kwargs: Keyword arguments for Collection.search().

        Returns:
            List[Tuple[Document, float]]: Resulting documents and scores.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []
        embedding = self.embedding_func.embed_query(query)
        return self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """
        Performs a search with an embedding vector and returns results with scores.

        Args:
            embedding (List[float]): The embedding vector being searched.
            k (int, optional): The number of results to return.
                Default is 4.
            param (dict): Specifies the search parameters for the index.
                Currently not forwarded to Hippo.
            expr (str, optional): Filtering expression, passed as ``dsl``.
                Default is None.
            timeout (int, optional): Currently not forwarded to Hippo.
            kwargs: Currently not forwarded to Hippo.

        Returns:
            List[Tuple[Document, float]]: Resulting documents and scores. The
            text column becomes ``page_content``; every other non-vector
            column goes into ``metadata``.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        # Return every column except the vector itself.
        output_fields = [f for f in self.fields if f != self._vector_field]

        logger.debug(f"search_field:{self._vector_field}")
        logger.debug(f"vectors:{[embedding]}")
        logger.debug(f"output_fields:{output_fields}")
        logger.debug(f"topk:{k}")
        logger.debug(f"dsl:{expr}")

        res = self.col.query(
            search_field=self._vector_field,
            vectors=[embedding],
            output_fields=output_fields,
            topk=k,
            dsl=expr,
        )
        logger.debug(f"[similarity_search_with_score_by_vector] res:{res}")
        if not res:
            return []

        # Results for the single query vector, keyed by column name; scores
        # are returned under "<text_field>%scores".
        hits = res[0]
        score_col = self._text_field + "%scores"

        ret: List[Tuple[Document, float]] = []
        rows = zip(*[hits[field] for field in output_fields])
        for position, items in enumerate(rows):
            meta = dict(zip(output_fields, items))
            doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
            logger.debug(
                f"[similarity_search_with_score_by_vector] "
                f"res[0][score_col]:{hits[score_col]}"
            )
            ret.append((doc, hits[score_col][position]))
        return ret

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        table_name: str = "test",
        database_name: str = "default",
        connection_args: Dict[str, Any] = DEFAULT_HIPPO_CONNECTION,
        index_params: Optional[Dict[Any, Any]] = None,
        search_params: Optional[Dict[str, Any]] = None,
        drop_old: bool = False,
        **kwargs: Any,
    ) -> "Hippo":
        """
        Creates an instance of the VST class from the given texts.

        Args:
            texts (List[str]): List of texts to be added.
            embedding (Embeddings): Embedding model for the texts.
            metadatas (List[dict], optional):
                List of metadata dictionaries for each text. Defaults to None.
            table_name (str): Name of the table. Defaults to "test".
            database_name (str): Name of the database. Defaults to "default".
            connection_args (dict[str, Any]): Connection parameters.
                Defaults to DEFAULT_HIPPO_CONNECTION.
            index_params (dict): Indexing parameters. Defaults to None.
            search_params (dict): Accepted for API compatibility; currently
                unused.
            drop_old (bool): Whether to drop the old collection. Defaults to False.
            kwargs: Other arguments forwarded to the constructor
                (e.g. number_of_shards, number_of_replicas).

        Returns:
            Hippo: An instance of the VST class.
        """
        logger.info("00 [from_texts] init the class of Hippo")
        vector_db = cls(
            embedding_function=embedding,
            table_name=table_name,
            database_name=database_name,
            connection_args=connection_args,
            index_params=index_params,
            drop_old=drop_old,
            **kwargs,
        )
        logger.debug(f"[from_texts] texts:{texts}")
        logger.debug(f"[from_texts] metadatas:{metadatas}")
        vector_db.add_texts(texts=texts, metadatas=metadatas)
        return vector_db