Add a new vector store hippo for langchain #11763 (#12412)

#11763

---------

Co-authored-by: TranswarpHippo <hippo.0.assistant@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/12550/head
chocolate4 10 months ago committed by GitHub
parent 342d6c7ab6
commit 92bf40a921
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,431 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"## Hippo\n",
"\n",
">[Hippo](https://www.transwarp.cn/starwarp) Please visit our official website for how to run a Hippo instance and\n",
"how to use functionality related to the Hippo vector database\n",
"\n",
"## Getting Started\n",
"\n",
"The only prerequisite here is an API key from the OpenAI website. Make sure you have already started a Hippo instance."
],
"metadata": {
"collapsed": false
},
"id": "357f24224a8e818f"
},
{
"cell_type": "markdown",
"source": [
"## Installing Dependencies\n",
"\n",
"Initially, we require the installation of certain dependencies, such as OpenAI, Langchain, and Hippo-API. Please note, you should install the appropriate versions tailored to your environment."
],
"metadata": {
"collapsed": false
},
"id": "a92d2ce26df7ac4c"
},
{
"cell_type": "code",
"execution_count": 15,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: hippo-api==1.1.0.rc3 in /Users/daochengzhang/miniforge3/envs/py310/lib/python3.10/site-packages (1.1.0rc3)\r\n",
"Requirement already satisfied: pyyaml>=6.0 in /Users/daochengzhang/miniforge3/envs/py310/lib/python3.10/site-packages (from hippo-api==1.1.0.rc3) (6.0.1)\r\n"
]
}
],
"source": [
"!pip install langchain tiktoken openai\n",
"!pip install hippo-api==1.1.0.rc3"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-30T06:47:54.718488Z",
"start_time": "2023-10-30T06:47:53.563129Z"
}
},
"id": "13b1d1ae153ff434"
},
{
"cell_type": "markdown",
"source": [
"Note: Python version needs to be >=3.8.\n",
"\n",
"## Best Practice\n",
"### Importing Dependency Packages"
],
"metadata": {
"collapsed": false
},
"id": "554081137df2c252"
},
{
"cell_type": "code",
"execution_count": 16,
"outputs": [],
"source": [
"from langchain.chat_models import AzureChatOpenAI, ChatOpenAI\n",
"from langchain.document_loaders import TextLoader\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores.hippo import Hippo\n",
"import os"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-30T06:47:56.003409Z",
"start_time": "2023-10-30T06:47:55.998839Z"
}
},
"id": "5ff3296ce812aeb8"
},
{
"cell_type": "markdown",
"source": [
"### Loading Knowledge Documents"
],
"metadata": {
"collapsed": false
},
"id": "dad255dae8aea755"
},
{
"cell_type": "code",
"execution_count": 17,
"outputs": [],
"source": [
"os.environ[\"OPENAI_API_KEY\"] = \"YOUR OPENAI KEY\"\n",
"loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n",
"documents = loader.load()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-30T06:47:59.027869Z",
"start_time": "2023-10-30T06:47:59.023934Z"
}
},
"id": "f02d66a7fd653dc1"
},
{
"cell_type": "markdown",
"source": [
"### Segmenting the Knowledge Document\n",
"\n",
"Here, we use Langchain's CharacterTextSplitter for segmentation. The delimiter is a period. After segmentation, the text segment does not exceed 1000 characters, and the number of repeated characters is 0."
],
"metadata": {
"collapsed": false
},
"id": "e9b93c330f1c6160"
},
{
"cell_type": "code",
"execution_count": 18,
"outputs": [],
"source": [
"text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-30T06:48:00.279351Z",
"start_time": "2023-10-30T06:48:00.275763Z"
}
},
"id": "fe6b43175318331f"
},
{
"cell_type": "markdown",
"source": [
"### Declaring the Embedding Model\n",
"Below, we create the OpenAI or Azure embedding model using the OpenAIEmbeddings method from Langchain."
],
"metadata": {
"collapsed": false
},
"id": "eefe28c7c993ffdf"
},
{
"cell_type": "code",
"execution_count": 19,
"outputs": [],
"source": [
"# openai\n",
"embeddings = OpenAIEmbeddings()\n",
"# azure\n",
"# embeddings = OpenAIEmbeddings(\n",
"# openai_api_type=\"azure\",\n",
"# openai_api_base=\"x x x\",\n",
"# openai_api_version=\"x x x\",\n",
"# model=\"x x x\",\n",
"# deployment=\"x x x\",\n",
"# openai_api_key=\"x x x\"\n",
"# )"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-30T06:48:11.686166Z",
"start_time": "2023-10-30T06:48:11.664355Z"
}
},
"id": "8619f16b9f7355ea"
},
{
"cell_type": "markdown",
"source": [
"### Declaring Hippo Client"
],
"metadata": {
"collapsed": false
},
"id": "e60235602ed91d3c"
},
{
"cell_type": "code",
"execution_count": 20,
"outputs": [],
"source": [
"HIPPO_CONNECTION = {\"host\": \"IP\", \"port\": \"PORT\"}"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-30T06:48:48.594298Z",
"start_time": "2023-10-30T06:48:48.585267Z"
}
},
"id": "c666b70dcab78129"
},
{
"cell_type": "markdown",
"source": [
"### Storing the Document"
],
"metadata": {
"collapsed": false
},
"id": "43ee6dbd765c3172"
},
{
"cell_type": "code",
"execution_count": 23,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input...\n",
"success\n"
]
}
],
"source": [
"print(\"input...\")\n",
"# insert docs\n",
"vector_store = Hippo.from_documents(\n",
" docs,\n",
" embedding=embeddings,\n",
" table_name=\"langchain_test\",\n",
" connection_args=HIPPO_CONNECTION,\n",
")\n",
"print(\"success\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-30T06:51:12.661741Z",
"start_time": "2023-10-30T06:51:06.257156Z"
}
},
"id": "79372c869844bdc9"
},
{
"cell_type": "markdown",
"source": [
"### Conducting Knowledge-based Question and Answer\n",
"#### Creating a Large Language Question-Answering Model\n",
"Below, we create the OpenAI or Azure large language question-answering model respectively using the AzureChatOpenAI and ChatOpenAI methods from Langchain."
],
"metadata": {
"collapsed": false
},
"id": "89077cc9763d5dd0"
},
{
"cell_type": "code",
"execution_count": 24,
"outputs": [],
"source": [
"# llm = AzureChatOpenAI(\n",
"# openai_api_base=\"x x x\",\n",
"# openai_api_version=\"xxx\",\n",
"# deployment_name=\"xxx\",\n",
"# openai_api_key=\"xxx\",\n",
"# openai_api_type=\"azure\"\n",
"# )\n",
"\n",
"llm = ChatOpenAI(openai_api_key=\"YOUR OPENAI KEY\", model_name=\"gpt-3.5-turbo-16k\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-30T06:51:28.329351Z",
"start_time": "2023-10-30T06:51:28.318713Z"
}
},
"id": "c9f2c42e9884f628"
},
{
"cell_type": "markdown",
"source": [
"### Acquiring Related Knowledge Based on the Question"
],
"metadata": {
"collapsed": false
},
"id": "a4c5d73016a9db0c"
},
{
"cell_type": "code",
"execution_count": 25,
"outputs": [],
"source": [
"query = \"Please introduce COVID-19\"\n",
"# query = \"Please introduce Hippo Core Architecture\"\n",
"# query = \"What operations does the Hippo Vector Database support for vector data?\"\n",
"# query = \"Does Hippo use hardware acceleration technology? Briefly introduce hardware acceleration technology.\"\n",
"\n",
"\n",
"# Retrieve similar content from the knowledge base,fetch the top two most similar texts.\n",
"res = vector_store.similarity_search(query, 2)\n",
"content_list = [item.page_content for item in res]\n",
"text = \"\".join(content_list)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-30T06:51:33.195634Z",
"start_time": "2023-10-30T06:51:32.196493Z"
}
},
"id": "8656e80519da1f97"
},
{
"cell_type": "markdown",
"source": [
"### Constructing a Prompt Template"
],
"metadata": {
"collapsed": false
},
"id": "e5adbaaa7086d1ae"
},
{
"cell_type": "code",
"execution_count": 26,
"outputs": [],
"source": [
"prompt = f\"\"\"\n",
"Please use the content of the following [Article] to answer my question. If you don't know, please say you don't know, and the answer should be concise.\"\n",
"[Article]:{text}\n",
"Please answer this question in conjunction with the above article:{query}\n",
"\"\"\""
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-30T06:51:35.649376Z",
"start_time": "2023-10-30T06:51:35.645763Z"
}
},
"id": "b915d3001a2741c1"
},
{
"cell_type": "markdown",
"source": [
"### Waiting for the Large Language Model to Generate an Answer"
],
"metadata": {
"collapsed": false
},
"id": "b36b6a9adbec8a82"
},
{
"cell_type": "code",
"execution_count": 27,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"response_with_hippo:COVID-19 is a virus that has impacted every aspect of our lives for over two years. It is a highly contagious and mutates easily, requiring us to remain vigilant in combating its spread. However, due to progress made and the resilience of individuals, we are now able to move forward safely and return to more normal routines.\n",
"==========================================\n",
"response_without_hippo:COVID-19 is a contagious respiratory illness caused by the novel coronavirus SARS-CoV-2. It was first identified in December 2019 in Wuhan, China and has since spread globally, leading to a pandemic. The virus primarily spreads through respiratory droplets when an infected person coughs, sneezes, talks, or breathes, and can also spread by touching contaminated surfaces and then touching the face. COVID-19 symptoms include fever, cough, shortness of breath, fatigue, muscle or body aches, sore throat, loss of taste or smell, headache, and in severe cases, pneumonia and organ failure. While most people experience mild to moderate symptoms, it can lead to severe illness and even death, particularly among older adults and those with underlying health conditions. To combat the spread of the virus, various preventive measures have been implemented globally, including social distancing, wearing face masks, practicing good hand hygiene, and vaccination efforts.\n"
]
}
],
"source": [
"response_with_hippo = llm.predict(prompt)\n",
"print(f\"response_with_hippo:{response_with_hippo}\")\n",
"response = llm.predict(query)\n",
"print(\"==========================================\")\n",
"print(f\"response_without_hippo:{response}\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-30T06:52:17.967885Z",
"start_time": "2023-10-30T06:51:37.692819Z"
}
},
"id": "58eb5d2396321001"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-10-30T06:42:42.172639Z"
}
},
"id": "b2b7ce4e1850ecf1"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -0,0 +1,677 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from langchain.docstore.document import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
if TYPE_CHECKING:
from transwarp_hippo_api.hippo_client import HippoClient
# Default connection
DEFAULT_HIPPO_CONNECTION = {
"host": "localhost",
"port": "7788",
"username": "admin",
"password": "admin",
}
logger = logging.getLogger(__name__)
class Hippo(VectorStore):
"""`Hippo` vector store.
You need to install `hippo-api` and run Hippo.
Please visit our official website for how to run a Hippo instance:
https://www.transwarp.cn/starwarp
Args:
embedding_function (Embeddings): Function used to embed the text.
table_name (str): Which Hippo table to use. Defaults to
"test".
database_name (str): Which Hippo database to use. Defaults to
"default".
number_of_shards (int): The number of shards for the Hippo table.Defaults to
1.
number_of_replicas (int): The number of replicas for the Hippo table.Defaults to
1.
connection_args (Optional[dict[str, any]]): The connection args used for
this class comes in the form of a dict.
index_params (Optional[dict]): Which index params to use. Defaults to
IVF_FLAT.
drop_old (Optional[bool]): Whether to drop the current collection. Defaults
to False.
primary_field (str): Name of the primary key field. Defaults to "pk".
text_field (str): Name of the text field. Defaults to "text".
vector_field (str): Name of the vector field. Defaults to "vector".
The connection args used for this class comes in the form of a dict,
here are a few of the options:
host (str): The host of Hippo instance. Default at "localhost".
port (str/int): The port of Hippo instance. Default at 7788.
user (str): Use which user to connect to Hippo instance. If user and
password are provided, we will add related header in every RPC call.
password (str): Required when user is provided. The password
corresponding to the user.
Example:
.. code-block:: python
from langchain.vectorstores import Hippo
from langchain.embeddings import OpenAIEmbeddings
embedding = OpenAIEmbeddings()
# Connect to a hippo instance on localhost
vector_store = Hippo.from_documents(
docs,
embedding=embeddings,
table_name="langchain_test",
connection_args=HIPPO_CONNECTION
)
Raises:
ValueError: If the hippo-api python package is not installed.
"""
def __init__(
self,
embedding_function: Embeddings,
table_name: str = "test",
database_name: str = "default",
number_of_shards: int = 1,
number_of_replicas: int = 1,
connection_args: Optional[Dict[str, Any]] = None,
index_params: Optional[dict] = None,
drop_old: Optional[bool] = False,
):
self.number_of_shards = number_of_shards
self.number_of_replicas = number_of_replicas
self.embedding_func = embedding_function
self.table_name = table_name
self.database_name = database_name
self.index_params = index_params
# In order for a collection to be compatible,
# 'pk' should be an auto-increment primary key and string
self._primary_field = "pk"
# In order for compatibility, the text field will need to be called "text"
self._text_field = "text"
# In order for compatibility, the vector field needs to be called "vector"
self._vector_field = "vector"
self.fields: List[str] = []
# Create the connection to the server
if connection_args is None:
connection_args = DEFAULT_HIPPO_CONNECTION
self.hc = self._create_connection_alias(connection_args)
self.col: Any = None
# If the collection exists, delete it
try:
if (
self.hc.check_table_exists(self.table_name, self.database_name)
and drop_old
):
self.hc.delete_table(self.table_name, self.database_name)
except Exception as e:
logging.error(
f"An error occurred while deleting the table " f"{self.table_name}: {e}"
)
raise
try:
if self.hc.check_table_exists(self.table_name, self.database_name):
self.col = self.hc.get_table(self.table_name, self.database_name)
except Exception as e:
logging.error(
f"An error occurred while getting the table " f"{self.table_name}: {e}"
)
raise
# Initialize the vector database
self._get_env()
def _create_connection_alias(self, connection_args: dict) -> HippoClient:
"""Create the connection to the Hippo server."""
# Grab the connection arguments that are used for checking existing connection
try:
from transwarp_hippo_api.hippo_client import HippoClient
except ImportError as e:
raise ImportError(
"Unable to import transwarp_hipp_api, please install with "
"`pip install hippo-api`."
) from e
host: str = connection_args.get("host", None)
port: int = connection_args.get("port", None)
username: str = connection_args.get("username", "shiva")
password: str = connection_args.get("password", "shiva")
# Order of use is host/port, uri, address
if host is not None and port is not None:
if "," in host:
hosts = host.split(",")
given_address = ",".join([f"{h}:{port}" for h in hosts])
else:
given_address = str(host) + ":" + str(port)
else:
raise ValueError("Missing standard address type for reuse attempt")
try:
logger.info(f"create HippoClient[{given_address}]")
return HippoClient([given_address], username=username, pwd=password)
except Exception as e:
logger.error("Failed to create new connection")
raise e
def _get_env(
self, embeddings: Optional[list] = None, metadatas: Optional[List[dict]] = None
) -> None:
logger.info("init ...")
if embeddings is not None:
logger.info("create collection")
self._create_collection(embeddings, metadatas)
self._extract_fields()
self._create_index()
def _create_collection(
self, embeddings: list, metadatas: Optional[List[dict]] = None
) -> None:
from transwarp_hippo_api.hippo_client import HippoField
from transwarp_hippo_api.hippo_type import HippoType
# Determine embedding dim
dim = len(embeddings[0])
logger.debug(f"[_create_collection] dim: {dim}")
fields = []
# Create the primary key field
fields.append(HippoField(self._primary_field, True, HippoType.STRING))
# Create the text field
fields.append(HippoField(self._text_field, False, HippoType.STRING))
# Create the vector field, supports binary or float vectors
# to The binary vector type is to be developed.
fields.append(
HippoField(
self._vector_field,
False,
HippoType.FLOAT_VECTOR,
type_params={"dimension": dim},
)
)
# to In Hippo,there is no method similar to the infer_type_data
# types, so currently all non-vector data is converted to string type.
if metadatas:
# # Create FieldSchema for each entry in metadata.
for key, value in metadatas[0].items():
# # Infer the corresponding datatype of the metadata
if isinstance(value, list):
value_dim = len(value)
fields.append(
HippoField(
key,
False,
HippoType.FLOAT_VECTOR,
type_params={"dimension": value_dim},
)
)
else:
fields.append(HippoField(key, False, HippoType.STRING))
logger.debug(f"[_create_collection] fields: {fields}")
# Create the collection
self.hc.create_table(
name=self.table_name,
auto_id=True,
fields=fields,
database_name=self.database_name,
number_of_shards=self.number_of_shards,
number_of_replicas=self.number_of_replicas,
)
self.col = self.hc.get_table(self.table_name, self.database_name)
logger.info(
f"[_create_collection] : "
f"create table {self.table_name} in {self.database_name} successfully"
)
def _extract_fields(self) -> None:
"""Grab the existing fields from the Collection"""
from transwarp_hippo_api.hippo_client import HippoTable
if isinstance(self.col, HippoTable):
schema = self.col.schema
logger.debug(f"[_extract_fields] schema:{schema}")
for x in schema:
self.fields.append(x.name)
logger.debug(f"04 [_extract_fields] fields:{self.fields}")
# TO CAN: Translated into English, your statement would be: "Currently,
# only the field named 'vector' (the automatically created vector field)
# is checked for indexing. Indexes need to be created manually for other
# vector type columns.
def _get_index(self) -> Optional[Dict[str, Any]]:
"""Return the vector index information if it exists"""
from transwarp_hippo_api.hippo_client import HippoTable
if isinstance(self.col, HippoTable):
table_info = self.hc.get_table_info(
self.table_name, self.database_name
).get(self.table_name, {})
embedding_indexes = table_info.get("embedding_indexes", None)
if embedding_indexes is None:
return None
else:
for x in self.hc.get_table_info(self.table_name, self.database_name)[
self.table_name
]["embedding_indexes"]:
logger.debug(f"[_get_index] embedding_indexes {embedding_indexes}")
if x["column"] == self._vector_field:
return x
return None
# TO Indexes can only be created for the self._vector_field field.
def _create_index(self) -> None:
"""Create a index on the collection"""
from transwarp_hippo_api.hippo_client import HippoTable
from transwarp_hippo_api.hippo_type import IndexType, MetricType
if isinstance(self.col, HippoTable) and self._get_index() is None:
if self._get_index() is None:
if self.index_params is None:
self.index_params = {
"index_name": "langchain_auto_create",
"metric_type": MetricType.L2,
"index_type": IndexType.IVF_FLAT,
"nlist": 10,
}
self.col.create_index(
self._vector_field,
self.index_params["index_name"],
self.index_params["index_type"],
self.index_params["metric_type"],
nlist=self.index_params["nlist"],
)
logger.debug(
self.col.activate_index(self.index_params["index_name"])
)
logger.info("create index successfully")
else:
index_dict = {
"IVF_FLAT": IndexType.IVF_FLAT,
"FLAT": IndexType.FLAT,
"IVF_SQ": IndexType.IVF_SQ,
"IVF_PQ": IndexType.IVF_PQ,
"HNSW": IndexType.HNSW,
}
metric_dict = {
"ip": MetricType.IP,
"IP": MetricType.IP,
"l2": MetricType.L2,
"L2": MetricType.L2,
}
self.index_params["metric_type"] = metric_dict[
self.index_params["metric_type"]
]
if self.index_params["index_type"] == "FLAT":
self.index_params["index_type"] = index_dict[
self.index_params["index_type"]
]
self.col.create_index(
self._vector_field,
self.index_params["index_name"],
self.index_params["index_type"],
self.index_params["metric_type"],
)
logger.debug(
self.col.activate_index(self.index_params["index_name"])
)
elif (
self.index_params["index_type"] == "IVF_FLAT"
or self.index_params["index_type"] == "IVF_SQ"
):
self.index_params["index_type"] = index_dict[
self.index_params["index_type"]
]
self.col.create_index(
self._vector_field,
self.index_params["index_name"],
self.index_params["index_type"],
self.index_params["metric_type"],
nlist=self.index_params.get("nlist", 10),
nprobe=self.index_params.get("nprobe", 10),
)
logger.debug(
self.col.activate_index(self.index_params["index_name"])
)
elif self.index_params["index_type"] == "IVF_PQ":
self.index_params["index_type"] = index_dict[
self.index_params["index_type"]
]
self.col.create_index(
self._vector_field,
self.index_params["index_name"],
self.index_params["index_type"],
self.index_params["metric_type"],
nlist=self.index_params.get("nlist", 10),
nprobe=self.index_params.get("nprobe", 10),
nbits=self.index_params.get("nbits", 8),
m=self.index_params.get("m"),
)
logger.debug(
self.col.activate_index(self.index_params["index_name"])
)
elif self.index_params["index_type"] == "HNSW":
self.index_params["index_type"] = index_dict[
self.index_params["index_type"]
]
self.col.create_index(
self._vector_field,
self.index_params["index_name"],
self.index_params["index_type"],
self.index_params["metric_type"],
M=self.index_params.get("M"),
ef_construction=self.index_params.get("ef_construction"),
ef_search=self.index_params.get("ef_search"),
)
logger.debug(
self.col.activate_index(self.index_params["index_name"])
)
else:
raise ValueError(
"Index name does not match, "
"please enter the correct index name. "
"(FLAT, IVF_FLAT, IVF_PQ,IVF_SQ, HNSW)"
)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
timeout: Optional[int] = None,
batch_size: int = 1000,
**kwargs: Any,
) -> List[str]:
"""
Add text to the collection.
Args:
texts: An iterable that contains the text to be added.
metadatas: An optional list of dictionaries,
each dictionary contains the metadata associated with a text.
timeout: Optional timeout, in seconds.
batch_size: The number of texts inserted in each batch, defaults to 1000.
**kwargs: Other optional parameters.
Returns:
A list of strings, containing the unique identifiers of the inserted texts.
Note:
If the collection has not yet been created,
this method will create a new collection.
"""
from transwarp_hippo_api.hippo_client import HippoTable
if not texts or all(t == "" for t in texts):
logger.debug("Nothing to insert, skipping.")
return []
texts = list(texts)
logger.debug(f"[add_texts] texts: {texts}")
try:
embeddings = self.embedding_func.embed_documents(texts)
except NotImplementedError:
embeddings = [self.embedding_func.embed_query(x) for x in texts]
if len(embeddings) == 0:
logger.debug("Nothing to insert, skipping.")
return []
logger.debug(f"[add_texts] len_embeddings:{len(embeddings)}")
# 如果还没有创建collection则创建collection
if not isinstance(self.col, HippoTable):
self._get_env(embeddings, metadatas)
# Dict to hold all insert columns
insert_dict: Dict[str, list] = {
self._text_field: texts,
self._vector_field: embeddings,
}
logger.debug(f"[add_texts] metadatas:{metadatas}")
logger.debug(f"[add_texts] fields:{self.fields}")
if metadatas is not None:
for d in metadatas:
for key, value in d.items():
if key in self.fields:
insert_dict.setdefault(key, []).append(value)
logger.debug(insert_dict[self._text_field])
# Total insert count
vectors: list = insert_dict[self._vector_field]
total_count = len(vectors)
if "pk" in self.fields:
self.fields.remove("pk")
logger.debug(f"[add_texts] total_count:{total_count}")
for i in range(0, total_count, batch_size):
# Grab end index
end = min(i + batch_size, total_count)
# Convert dict to list of lists batch for insertion
insert_list = [insert_dict[x][i:end] for x in self.fields]
try:
res = self.col.insert_rows(insert_list)
logger.info(f"05 [add_texts] insert {res}")
except Exception as e:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
raise e
return [""]
def similarity_search(
self,
query: str,
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Document]:
"""
Perform a similarity search on the query string.
Args:
query (str): The text to search for.
k (int, optional): The number of results to return. Default is 4.
param (dict, optional): Specifies the search parameters for the index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): Time to wait before a timeout error.
Defaults to None.
kwargs: Keyword arguments for Collection.search().
Returns:
List[Document]: The document results of the search.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
res = self.similarity_search_with_score(
query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
return [doc for doc, _ in res]
def similarity_search_with_score(
self,
query: str,
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""
Performs a search on the query string and returns results with scores.
Args:
query (str): The text being searched.
k (int, optional): The number of results to return.
Default is 4.
param (dict): Specifies the search parameters for the index.
Default is None.
expr (str, optional): Filtering expression. Default is None.
timeout (int, optional): The waiting time before a timeout error.
Default is None.
kwargs: Keyword arguments for Collection.search().
Returns:
List[float], List[Tuple[Document, any, any]]:
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
# Embed the query text.
embedding = self.embedding_func.embed_query(query)
ret = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
return ret
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""
Performs a search on the query string and returns results with scores.
Args:
embedding (List[float]): The embedding vector being searched.
k (int, optional): The number of results to return.
Default is 4.
param (dict): Specifies the search parameters for the index.
Default is None.
expr (str, optional): Filtering expression. Default is None.
timeout (int, optional): The waiting time before a timeout error.
Default is None.
kwargs: Keyword arguments for Collection.search().
Returns:
List[Tuple[Document, float]]: Resulting documents and scores.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
# if param is None:
# param = self.search_params
# Determine result metadata fields.
output_fields = self.fields[:]
output_fields.remove(self._vector_field)
# Perform the search.
logger.debug(f"search_field:{self._vector_field}")
logger.debug(f"vectors:{[embedding]}")
logger.debug(f"output_fields:{output_fields}")
logger.debug(f"topk:{k}")
logger.debug(f"dsl:{expr}")
res = self.col.query(
search_field=self._vector_field,
vectors=[embedding],
output_fields=output_fields,
topk=k,
dsl=expr,
)
# Organize results.
logger.debug(f"[similarity_search_with_score_by_vector] res:{res}")
score_col = self._text_field + "%scores"
ret = []
count = 0
for items in zip(*[res[0][field] for field in output_fields]):
meta = {field: value for field, value in zip(output_fields, items)}
doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
logger.debug(
f"[similarity_search_with_score_by_vector] "
f"res[0][score_col]:{res[0][score_col]}"
)
score = res[0][score_col][count]
count += 1
ret.append((doc, score))
return ret
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
table_name: str = "test",
database_name: str = "default",
connection_args: Dict[str, Any] = DEFAULT_HIPPO_CONNECTION,
index_params: Optional[Dict[Any, Any]] = None,
search_params: Optional[Dict[str, Any]] = None,
drop_old: bool = False,
**kwargs: Any,
) -> "Hippo":
"""
Creates an instance of the VST class from the given texts.
Args:
texts (List[str]): List of texts to be added.
embedding (Embeddings): Embedding model for the texts.
metadatas (List[dict], optional):
List of metadata dictionaries for each text.Defaults to None.
table_name (str): Name of the table. Defaults to "test".
database_name (str): Name of the database. Defaults to "default".
connection_args (dict[str, Any]): Connection parameters.
Defaults to DEFAULT_HIPPO_CONNECTION.
index_params (dict): Indexing parameters. Defaults to None.
search_params (dict): Search parameters. Defaults to an empty dictionary.
drop_old (bool): Whether to drop the old collection. Defaults to False.
kwargs: Other arguments.
Returns:
Hippo: An instance of the VST class.
"""
if search_params is None:
search_params = {}
logger.info("00 [from_texts] init the class of Hippo")
vector_db = cls(
embedding_function=embedding,
table_name=table_name,
database_name=database_name,
connection_args=connection_args,
index_params=index_params,
drop_old=drop_old,
**kwargs,
)
logger.debug(f"[from_texts] texts:{texts}")
logger.debug(f"[from_texts] metadatas:{metadatas}")
vector_db.add_texts(texts=texts, metadatas=metadatas)
return vector_db

@ -0,0 +1,63 @@
"""Test Hippo functionality."""
from typing import List, Optional
from langchain.docstore.document import Document
from langchain.vectorstores.hippo import Hippo
from tests.integration_tests.vectorstores.fake_embeddings import (
FakeEmbeddings,
fake_texts,
)
def _hippo_from_texts(
metadatas: Optional[List[dict]] = None, drop: bool = True
) -> Hippo:
return Hippo.from_texts(
fake_texts,
FakeEmbeddings(),
metadatas=metadatas,
table_name="langchain_test",
connection_args={"host": "127.0.0.1", "port": 7788},
drop_old=drop,
)
def test_hippo_add_extra() -> None:
"""Test end to end construction and MRR search."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _hippo_from_texts(metadatas=metadatas)
docsearch.add_texts(texts, metadatas)
output = docsearch.similarity_search("foo", k=1)
print(output)
assert len(output) == 1
def test_hippo() -> None:
docsearch = _hippo_from_texts()
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
def test_hippo_with_score() -> None:
"""Test end to end construction and search with scores and IDs."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _hippo_from_texts(metadatas=metadatas)
output = docsearch.similarity_search_with_score("foo", k=3)
docs = [o[0] for o in output]
scores = [o[1] for o in output]
assert docs == [
Document(page_content="foo", metadata={"page": "0"}),
Document(page_content="bar", metadata={"page": "1"}),
Document(page_content="baz", metadata={"page": "2"}),
]
assert scores[0] < scores[1] < scores[2]
# if __name__ == "__main__":
# test_hippo()
# test_hippo_with_score()
# test_hippo_with_score()
Loading…
Cancel
Save