mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
678 lines
26 KiB
Python
678 lines
26 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import logging
|
||
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
|
||
|
|
||
|
from langchain_core.documents import Document
|
||
|
from langchain_core.embeddings import Embeddings
|
||
|
from langchain_core.vectorstores import VectorStore
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from transwarp_hippo_api.hippo_client import HippoClient
|
||
|
|
||
|
# Connection settings used whenever no explicit ``connection_args`` are given.
DEFAULT_HIPPO_CONNECTION = dict(
    host="localhost",
    port="7788",
    username="admin",
    password="admin",
)

# Module-level logger for this vector store implementation.
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
class Hippo(VectorStore):
    """`Hippo` vector store.

    You need to install `hippo-api` and run Hippo.

    Please visit our official website for how to run a Hippo instance:
    https://www.transwarp.cn/starwarp

    Args:
        embedding_function (Embeddings): Function used to embed the text.
        table_name (str): Which Hippo table to use. Defaults to "test".
        database_name (str): Which Hippo database to use. Defaults to
            "default".
        number_of_shards (int): The number of shards for the Hippo table.
            Defaults to 1.
        number_of_replicas (int): The number of replicas for the Hippo table.
            Defaults to 1.
        connection_args (Optional[dict[str, Any]]): The connection args used
            for this class, in the form of a dict. ``DEFAULT_HIPPO_CONNECTION``
            is used when None.
        index_params (Optional[dict]): Which index params to use. When None,
            an IVF_FLAT index with the L2 metric and ``nlist=10`` is created.
            The dict is copied, so the caller's object is never modified.
        drop_old (Optional[bool]): Whether to drop an existing table with the
            same name. Defaults to False.

    The table schema uses fixed field names: the primary key field is "pk",
    the text field is "text" and the vector field is "vector".

    The connection args used for this class come in the form of a dict;
    the following keys are read:
        host (str): The host of the Hippo instance. Several hosts may be
            given as a comma-separated string; each is paired with ``port``.
        port (str/int): The port of the Hippo instance.
        username (str): The user used to connect to the Hippo instance.
            Defaults to "shiva" when the key is absent.
        password (str): The password corresponding to the user.
            Defaults to "shiva" when the key is absent.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Hippo
            from langchain_community.embeddings import OpenAIEmbeddings

            embedding = OpenAIEmbeddings()
            # Connect to a hippo instance on localhost
            vector_store = Hippo.from_documents(
                docs,
                embedding=embedding,
                table_name="langchain_test",
                connection_args={
                    "host": "localhost",
                    "port": "7788",
                    "username": "admin",
                    "password": "admin",
                },
            )

    Raises:
        ImportError: If the hippo-api python package is not installed.
        ValueError: If ``connection_args`` lacks a host or a port.
    """

    def __init__(
        self,
        embedding_function: Embeddings,
        table_name: str = "test",
        database_name: str = "default",
        number_of_shards: int = 1,
        number_of_replicas: int = 1,
        connection_args: Optional[Dict[str, Any]] = None,
        index_params: Optional[dict] = None,
        drop_old: Optional[bool] = False,
    ):
        self.number_of_shards = number_of_shards
        self.number_of_replicas = number_of_replicas
        self.embedding_func = embedding_function
        self.table_name = table_name
        self.database_name = database_name
        # Copy: `_create_index` rewrites "metric_type"/"index_type" in place,
        # which must not leak into (and break reuse of) the caller's dict.
        self.index_params = dict(index_params) if index_params is not None else None

        # In order for a table to be compatible, 'pk' should be an
        # auto-generated string primary key.
        self._primary_field = "pk"
        # In order for compatibility, the text field will need to be called "text"
        self._text_field = "text"
        # In order for compatibility, the vector field needs to be called "vector"
        self._vector_field = "vector"
        self.fields: List[str] = []

        # Create the connection to the server.
        if connection_args is None:
            connection_args = DEFAULT_HIPPO_CONNECTION
        self.hc = self._create_connection_alias(connection_args)
        self.col: Any = None

        # Drop the existing table first when requested.
        try:
            if drop_old and self.hc.check_table_exists(
                self.table_name, self.database_name
            ):
                self.hc.delete_table(self.table_name, self.database_name)
        except Exception as e:
            logger.error(
                f"An error occurred while deleting the table " f"{self.table_name}: {e}"
            )
            raise

        # Attach to the table if it (still) exists.
        try:
            if self.hc.check_table_exists(self.table_name, self.database_name):
                self.col = self.hc.get_table(self.table_name, self.database_name)
        except Exception as e:
            logger.error(
                f"An error occurred while getting the table " f"{self.table_name}: {e}"
            )
            raise

        # Read the schema and make sure the vector index exists; this is a
        # no-op while no table is attached.
        self._get_env()

    def _create_connection_alias(self, connection_args: dict) -> HippoClient:
        """Create the connection to the Hippo server.

        Args:
            connection_args: Dict providing "host" and "port" (both required)
                and optionally "username" / "password".

        Returns:
            A ``HippoClient`` for the resulting address.

        Raises:
            ImportError: If ``transwarp_hippo_api`` cannot be imported.
            ValueError: If "host" or "port" is missing.
        """
        try:
            from transwarp_hippo_api.hippo_client import HippoClient
        except ImportError as e:
            raise ImportError(
                "Unable to import transwarp_hippo_api, please install with "
                "`pip install hippo-api`."
            ) from e

        host: Optional[str] = connection_args.get("host", None)
        port: Any = connection_args.get("port", None)
        username: str = connection_args.get("username", "shiva")
        password: str = connection_args.get("password", "shiva")

        # Only the host/port form is supported. A comma-separated host list is
        # expanded to "h1:port,h2:port".
        if host is None or port is None:
            raise ValueError("Missing standard address type for reuse attempt")
        if "," in host:
            given_address = ",".join(f"{h}:{port}" for h in host.split(","))
        else:
            given_address = str(host) + ":" + str(port)

        try:
            logger.info(f"create HippoClient[{given_address}]")
            return HippoClient([given_address], username=username, pwd=password)
        except Exception:
            logger.error("Failed to create new connection")
            raise

    def _get_env(
        self, embeddings: Optional[list] = None, metadatas: Optional[List[dict]] = None
    ) -> None:
        """Prepare the table: optionally create it, then read fields and index.

        Args:
            embeddings: When given, a new table is created whose vector
                dimension is taken from ``embeddings[0]``.
            metadatas: Optional metadata dicts; the keys of the first one
                become extra columns of a newly created table.
        """
        logger.info("init ...")
        if embeddings is not None:
            logger.info("create collection")
            self._create_collection(embeddings, metadatas)
        self._extract_fields()
        self._create_index()

    def _create_collection(
        self, embeddings: list, metadatas: Optional[List[dict]] = None
    ) -> None:
        """Create the Hippo table and attach it to ``self.col``.

        The schema is: a STRING primary key, a STRING text field, a
        FLOAT_VECTOR field sized from ``embeddings[0]``, plus one column per
        key of ``metadatas[0]``.
        """
        from transwarp_hippo_api.hippo_client import HippoField
        from transwarp_hippo_api.hippo_type import HippoType

        # Determine embedding dim from the first vector.
        dim = len(embeddings[0])
        logger.debug(f"[_create_collection] dim: {dim}")
        fields = []

        # Primary key field.
        fields.append(HippoField(self._primary_field, True, HippoType.STRING))

        # Text field.
        fields.append(HippoField(self._text_field, False, HippoType.STRING))

        # Vector field; only float vectors are created here.
        fields.append(
            HippoField(
                self._vector_field,
                False,
                HippoType.FLOAT_VECTOR,
                type_params={"dimension": dim},
            )
        )

        # Hippo has no metadata type inference here: list values get a
        # FLOAT_VECTOR column sized by the list length, everything else STRING.
        if metadatas:
            for key, value in metadatas[0].items():
                if isinstance(value, list):
                    value_dim = len(value)
                    fields.append(
                        HippoField(
                            key,
                            False,
                            HippoType.FLOAT_VECTOR,
                            type_params={"dimension": value_dim},
                        )
                    )
                else:
                    fields.append(HippoField(key, False, HippoType.STRING))

        logger.debug(f"[_create_collection] fields: {fields}")

        # Create the table and attach to it.
        self.hc.create_table(
            name=self.table_name,
            auto_id=True,
            fields=fields,
            database_name=self.database_name,
            number_of_shards=self.number_of_shards,
            number_of_replicas=self.number_of_replicas,
        )
        self.col = self.hc.get_table(self.table_name, self.database_name)
        logger.info(
            f"[_create_collection] : "
            f"create table {self.table_name} in {self.database_name} successfully"
        )

    def _extract_fields(self) -> None:
        """Load the field names of the attached table into ``self.fields``.

        The list is rebuilt rather than appended to, so calling this more than
        once never produces duplicate columns.
        """
        from transwarp_hippo_api.hippo_client import HippoTable

        if isinstance(self.col, HippoTable):
            schema = self.col.schema
            logger.debug(f"[_extract_fields] schema:{schema}")
            self.fields = [x.name for x in schema]
            logger.debug(f"04 [_extract_fields] fields:{self.fields}")

    def _get_index(self) -> Optional[Dict[str, Any]]:
        """Return the index description for ``self._vector_field``, if any.

        Only the automatically created vector field is checked; indexes on
        other vector columns must be managed manually.
        """
        from transwarp_hippo_api.hippo_client import HippoTable

        if not isinstance(self.col, HippoTable):
            return None

        table_info = self.hc.get_table_info(self.table_name, self.database_name).get(
            self.table_name, {}
        )
        embedding_indexes = table_info.get("embedding_indexes", None)
        if embedding_indexes is None:
            return None

        logger.debug(f"[_get_index] embedding_indexes {embedding_indexes}")
        for index in embedding_indexes:
            if index["column"] == self._vector_field:
                return index
        return None

    def _create_index(self) -> None:
        """Create and activate an index on ``self._vector_field`` if missing.

        With ``self.index_params`` None, an IVF_FLAT / L2 index named
        "langchain_auto_create" (``nlist=10``) is created. Otherwise the dict
        must provide "index_name", "index_type" (FLAT, IVF_FLAT, IVF_SQ, IVF_PQ
        or HNSW) and "metric_type" (ip, IP, l2 or L2). Type-specific options
        (nlist, nprobe, nbits, m, M, ef_construction, ef_search) are read from
        the same dict. The string names are replaced by Hippo enums in place.

        Raises:
            KeyError: If a required key is missing or "metric_type" is unknown.
            ValueError: If "index_type" is not a supported index name.
        """
        from transwarp_hippo_api.hippo_client import HippoTable
        from transwarp_hippo_api.hippo_type import IndexType, MetricType

        # Nothing to do without a table, or when the vector index already
        # exists (checked once: `_get_index` is a server round-trip).
        if not isinstance(self.col, HippoTable) or self._get_index() is not None:
            return

        index_kwargs: Dict[str, Any]
        if self.index_params is None:
            self.index_params = {
                "index_name": "langchain_auto_create",
                "metric_type": MetricType.L2,
                "index_type": IndexType.IVF_FLAT,
                "nlist": 10,
            }
            index_kwargs = {"nlist": self.index_params["nlist"]}
        else:
            index_dict = {
                "IVF_FLAT": IndexType.IVF_FLAT,
                "FLAT": IndexType.FLAT,
                "IVF_SQ": IndexType.IVF_SQ,
                "IVF_PQ": IndexType.IVF_PQ,
                "HNSW": IndexType.HNSW,
            }
            metric_dict = {
                "ip": MetricType.IP,
                "IP": MetricType.IP,
                "l2": MetricType.L2,
                "L2": MetricType.L2,
            }
            params = self.index_params
            # The metric is converted before the index type is validated.
            params["metric_type"] = metric_dict[params["metric_type"]]

            index_type = params["index_type"]
            if index_type == "FLAT":
                index_kwargs = {}
            elif index_type in ("IVF_FLAT", "IVF_SQ"):
                index_kwargs = {
                    "nlist": params.get("nlist", 10),
                    "nprobe": params.get("nprobe", 10),
                }
            elif index_type == "IVF_PQ":
                index_kwargs = {
                    "nlist": params.get("nlist", 10),
                    "nprobe": params.get("nprobe", 10),
                    "nbits": params.get("nbits", 8),
                    "m": params.get("m"),
                }
            elif index_type == "HNSW":
                index_kwargs = {
                    "M": params.get("M"),
                    "ef_construction": params.get("ef_construction"),
                    "ef_search": params.get("ef_search"),
                }
            else:
                raise ValueError(
                    "Index name does not match, "
                    "please enter the correct index name. "
                    "(FLAT, IVF_FLAT, IVF_PQ,IVF_SQ, HNSW)"
                )
            params["index_type"] = index_dict[index_type]

        self.col.create_index(
            self._vector_field,
            self.index_params["index_name"],
            self.index_params["index_type"],
            self.index_params["metric_type"],
            **index_kwargs,
        )
        logger.debug(self.col.activate_index(self.index_params["index_name"]))
        logger.info("create index successfully")

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        timeout: Optional[int] = None,
        batch_size: int = 1000,
        **kwargs: Any,
    ) -> List[str]:
        """
        Add text to the collection.

        Args:
            texts: An iterable that contains the text to be added. Any
                iterable (including one-shot generators) is accepted.
            metadatas: An optional list of dictionaries,
                each dictionary contains the metadata associated with a text.
                Only keys that are columns of the table are inserted.
            timeout: Accepted for interface compatibility; currently unused.
            batch_size: The number of texts inserted in each batch, defaults to 1000.
            **kwargs: Other optional parameters (currently unused).

        Returns:
            ``[]`` when there is nothing to insert. Otherwise ``[""]``: the
            primary keys are auto-generated by Hippo and are not retrieved.

        Note:
            If the collection has not yet been created,
            this method will create a new collection.
        """
        from transwarp_hippo_api.hippo_client import HippoTable

        # Materialize first: checking a one-shot iterator for emptiness below
        # would otherwise consume it and drop the texts.
        texts = list(texts)
        if not texts or all(t == "" for t in texts):
            logger.debug("Nothing to insert, skipping.")
            return []

        logger.debug(f"[add_texts] texts: {texts}")

        try:
            embeddings = self.embedding_func.embed_documents(texts)
        except NotImplementedError:
            embeddings = [self.embedding_func.embed_query(x) for x in texts]

        if len(embeddings) == 0:
            logger.debug("Nothing to insert, skipping.")
            return []

        logger.debug(f"[add_texts] len_embeddings:{len(embeddings)}")

        # Create the table first if it does not exist yet.
        if not isinstance(self.col, HippoTable):
            self._get_env(embeddings, metadatas)

        # Dict to hold all insert columns.
        insert_dict: Dict[str, list] = {
            self._text_field: texts,
            self._vector_field: embeddings,
        }
        logger.debug(f"[add_texts] metadatas:{metadatas}")
        logger.debug(f"[add_texts] fields:{self.fields}")
        if metadatas is not None:
            for d in metadatas:
                for key, value in d.items():
                    if key in self.fields:
                        insert_dict.setdefault(key, []).append(value)

        logger.debug(insert_dict[self._text_field])

        # Total insert count.
        vectors: list = insert_dict[self._vector_field]
        total_count = len(vectors)

        # The primary key is auto-generated (auto_id=True), so it is not an
        # insert column.
        if self._primary_field in self.fields:
            self.fields.remove(self._primary_field)

        logger.debug(f"[add_texts] total_count:{total_count}")
        for i in range(0, total_count, batch_size):
            end = min(i + batch_size, total_count)
            # Column-major batch in the table's field order.
            insert_list = [insert_dict[x][i:end] for x in self.fields]
            try:
                res = self.col.insert_rows(insert_list)
                logger.info(f"05 [add_texts] insert {res}")
            except Exception:
                logger.error(
                    "Failed to insert batch starting at entity: %s/%s", i, total_count
                )
                raise
        return [""]

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """
        Perform a similarity search on the query string.

        Args:
            query (str): The text to search for.
            k (int, optional): The number of results to return. Default is 4.
            param (dict, optional): Accepted for interface compatibility;
                currently not forwarded to Hippo. Defaults to None.
            expr (str, optional): Filtering expression passed to Hippo as
                ``dsl``. Defaults to None.
            timeout (int, optional): Accepted for interface compatibility;
                currently unused. Defaults to None.
            kwargs: Forwarded to ``similarity_search_with_score``.

        Returns:
            List[Document]: The document results of the search.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []
        res = self.similarity_search_with_score(
            query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )
        return [doc for doc, _ in res]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """
        Performs a search on the query string and returns results with scores.

        Args:
            query (str): The text being searched.
            k (int, optional): The number of results to return.
                Default is 4.
            param (dict): Accepted for interface compatibility; currently
                not forwarded to Hippo. Default is None.
            expr (str, optional): Filtering expression. Default is None.
            timeout (int, optional): Accepted for interface compatibility;
                currently unused. Default is None.
            kwargs: Forwarded to ``similarity_search_with_score_by_vector``.

        Returns:
            List[Tuple[Document, float]]: Resulting documents and scores.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        # Embed the query text.
        embedding = self.embedding_func.embed_query(query)

        return self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """
        Performs a search with an embedding vector and returns results with scores.

        Args:
            embedding (List[float]): The embedding vector being searched.
            k (int, optional): The number of results to return.
                Default is 4.
            param (dict): Accepted for interface compatibility; currently
                unused. Default is None.
            expr (str, optional): Filtering expression, passed as ``dsl``.
                Default is None.
            timeout (int, optional): Accepted for interface compatibility;
                currently unused. Default is None.
            kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            List[Tuple[Document, float]]: Resulting documents and scores. The
            text field becomes ``page_content``; every other returned field
            (except the vector) goes into ``metadata``.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        # Return every stored field except the raw vector.
        output_fields = self.fields[:]
        output_fields.remove(self._vector_field)

        # Perform the search.
        logger.debug(f"search_field:{self._vector_field}")
        logger.debug(f"vectors:{[embedding]}")
        logger.debug(f"output_fields:{output_fields}")
        logger.debug(f"topk:{k}")
        logger.debug(f"dsl:{expr}")

        res = self.col.query(
            search_field=self._vector_field,
            vectors=[embedding],
            output_fields=output_fields,
            topk=k,
            dsl=expr,
        )
        logger.debug(f"[similarity_search_with_score_by_vector] res:{res}")

        # Results come back column-wise for the single query vector; the score
        # column is named "<text_field>%scores".
        hits = res[0]
        score_col = self._text_field + "%scores"
        rows = zip(*[hits[field] for field in output_fields])
        ret: List[Tuple[Document, float]] = []
        for idx, items in enumerate(rows):
            meta = dict(zip(output_fields, items))
            doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
            ret.append((doc, hits[score_col][idx]))

        logger.debug(
            f"[similarity_search_with_score_by_vector] returned {len(ret)} results"
        )
        return ret

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        table_name: str = "test",
        database_name: str = "default",
        connection_args: Dict[str, Any] = DEFAULT_HIPPO_CONNECTION,
        index_params: Optional[Dict[Any, Any]] = None,
        search_params: Optional[Dict[str, Any]] = None,
        drop_old: bool = False,
        **kwargs: Any,
    ) -> "Hippo":
        """
        Creates a Hippo vector store from the given texts.

        Args:
            texts (List[str]): List of texts to be added.
            embedding (Embeddings): Embedding model for the texts.
            metadatas (List[dict], optional):
                List of metadata dictionaries for each text. Defaults to None.
            table_name (str): Name of the table. Defaults to "test".
            database_name (str): Name of the database. Defaults to "default".
            connection_args (dict[str, Any]): Connection parameters.
                Defaults to DEFAULT_HIPPO_CONNECTION.
            index_params (dict): Indexing parameters. Defaults to None.
            search_params (dict): Accepted for interface compatibility;
                currently unused.
            drop_old (bool): Whether to drop the old table. Defaults to False.
            kwargs: Extra keyword arguments for the constructor
                (e.g. ``number_of_shards``, ``number_of_replicas``).

        Returns:
            Hippo: The populated vector store.
        """
        if search_params is None:
            search_params = {}
        logger.info("00 [from_texts] init the class of Hippo")
        vector_db = cls(
            embedding_function=embedding,
            table_name=table_name,
            database_name=database_name,
            connection_args=connection_args,
            index_params=index_params,
            drop_old=drop_old,
            **kwargs,
        )
        logger.debug(f"[from_texts] texts:{texts}")
        logger.debug(f"[from_texts] metadatas:{metadatas}")
        vector_db.add_texts(texts=texts, metadatas=metadatas)
        return vector_db
|