mirror of https://github.com/hwchase17/langchain
community[minor]: VectorStore integration for SAP HANA Cloud Vector Engine (#16514)
- **Description:** This PR adds a VectorStore integration for SAP HANA Cloud Vector Engine, which is an upcoming feature in the SAP HANA Cloud database (https://blogs.sap.com/2023/11/02/sap-hana-clouds-vector-engine-announcement/). - **Issue:** N/A - **Dependencies:** [SAP HANA Python Client](https://pypi.org/project/hdbcli/) - **Twitter handle:** @sapopensource Implementation of the integration: `libs/community/langchain_community/vectorstores/hanavector.py` Unit tests: `libs/community/tests/unit_tests/vectorstores/test_hanavector.py` Integration tests: `libs/community/tests/integration_tests/vectorstores/test_hanavector.py` Example notebook: `docs/docs/integrations/vectorstores/hanavector.ipynb` Access credentials for execution of the integration tests can be provided to the maintainers. --------- Co-authored-by: sascha <sascha.stoll@sap.com> Co-authored-by: Bagatur <baskaryan@gmail.com>pull/16530/head
parent
1113700b09
commit
04651f0248
@ -0,0 +1,703 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SAP HANA Cloud Vector Engine\n",
|
||||
"\n",
|
||||
">SAP HANA Cloud Vector Engine is a vector store fully integrated into the SAP HANA Cloud database."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Installation of the HANA database driver."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Pip install necessary package\n",
|
||||
"%pip install --upgrade --quiet hdbcli"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
    "We use `OpenAIEmbeddings` in this notebook, so the OpenAI API key must be available."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2023-09-09T08:02:16.802456Z",
|
||||
"start_time": "2023-09-09T08:02:07.065604Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"# Use OPENAI_API_KEY env variable\n",
|
||||
"# os.environ[\"OPENAI_API_KEY\"] = \"Your OpenAI API key\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Load the sample document \"state_of_the_union.txt\" and create chunks from it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2023-09-09T08:02:25.452472Z",
|
||||
"start_time": "2023-09-09T08:02:25.441563Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.docstore.document import Document\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain_community.document_loaders import TextLoader\n",
|
||||
"from langchain_community.vectorstores.hanavector import HanaDB\n",
|
||||
"from langchain_openai import OpenAIEmbeddings\n",
|
||||
"\n",
|
||||
"text_documents = TextLoader(\"../../modules/state_of_the_union.txt\").load()\n",
|
||||
"text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=0)\n",
|
||||
"text_chunks = text_splitter.split_documents(text_documents)\n",
|
||||
"print(f\"Number of document chunks: {len(text_chunks)}\")\n",
|
||||
"\n",
|
||||
"embeddings = OpenAIEmbeddings()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Create a database connection to a HANA Cloud instance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2023-09-09T08:02:28.174088Z",
|
||||
"start_time": "2023-09-09T08:02:28.162698Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from hdbcli import dbapi\n",
|
||||
"\n",
|
||||
"# Use connection settings from the environment\n",
|
||||
"connection = dbapi.connect(\n",
|
||||
" address=os.environ.get(\"HANA_DB_ADDRESS\"),\n",
|
||||
" port=os.environ.get(\"HANA_DB_PORT\"),\n",
|
||||
" user=os.environ.get(\"HANA_DB_USER\"),\n",
|
||||
" password=os.environ.get(\"HANA_DB_PASSWORD\"),\n",
|
||||
" autocommit=True,\n",
|
||||
" sslValidateCertificate=False,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Create a LangChain VectorStore interface for the HANA database and specify the table (collection) to use for accessing the vector embeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2023-09-09T08:04:16.696625Z",
|
||||
"start_time": "2023-09-09T08:02:31.817790Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db = HanaDB(\n",
|
||||
" embedding=embeddings, connection=connection, table_name=\"STATE_OF_THE_UNION\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Add the loaded document chunks to the table. For this example, we delete any previous content from the table which might exist from previous runs."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Delete already existing documents from the table\n",
|
||||
"db.delete(filter={})\n",
|
||||
"\n",
|
||||
"# add the loaded document chunks\n",
|
||||
"db.add_documents(text_chunks)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Perform a query to get the two best matching document chunks from the ones that we added in the previous step.\n",
|
||||
"By default \"Cosine Similarity\" is used for the search."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"docs = db.similarity_search(query, k=2)\n",
|
||||
"\n",
|
||||
"for doc in docs:\n",
|
||||
" print(\"-\" * 80)\n",
|
||||
" print(doc.page_content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
    "Query the same content with \"Euclidean Distance\". The results should be the same as with \"Cosine Similarity\"."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.vectorstores.utils import DistanceStrategy\n",
|
||||
"\n",
|
||||
"db = HanaDB(\n",
|
||||
" embedding=embeddings,\n",
|
||||
" connection=connection,\n",
|
||||
" distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE,\n",
|
||||
" table_name=\"STATE_OF_THE_UNION\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"docs = db.similarity_search(query, k=2)\n",
|
||||
"for doc in docs:\n",
|
||||
" print(\"-\" * 80)\n",
|
||||
" print(doc.page_content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"Maximal Marginal Relevance Search (MMR)\n",
|
||||
"\n",
|
||||
"Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents. First 20 (fetch_k) items will be retrieved from the DB. The MMR algorithm will then find the best 2 (k) matches."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2023-09-09T08:05:23.276819Z",
|
||||
"start_time": "2023-09-09T08:05:21.972256Z"
|
||||
},
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs = db.max_marginal_relevance_search(query, k=2, fetch_k=20)\n",
|
||||
"for doc in docs:\n",
|
||||
" print(\"-\" * 80)\n",
|
||||
" print(doc.page_content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Basic Vectorstore Operations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db = HanaDB(\n",
|
||||
" connection=connection, embedding=embeddings, table_name=\"LANGCHAIN_DEMO_BASIC\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Delete already existing documents from the table\n",
|
||||
"db.delete(filter={})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can add simple text documents to the existing table."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs = [Document(page_content=\"Some text\"), Document(page_content=\"Other docs\")]\n",
|
||||
"db.add_documents(docs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Add documents with metadata."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs = [\n",
|
||||
" Document(\n",
|
||||
" page_content=\"foo\",\n",
|
||||
" metadata={\"start\": 100, \"end\": 150, \"doc_name\": \"foo.txt\", \"quality\": \"bad\"},\n",
|
||||
" ),\n",
|
||||
" Document(\n",
|
||||
" page_content=\"bar\",\n",
|
||||
" metadata={\"start\": 200, \"end\": 250, \"doc_name\": \"bar.txt\", \"quality\": \"good\"},\n",
|
||||
" ),\n",
|
||||
"]\n",
|
||||
"db.add_documents(docs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Query documents with specific metadata."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs = db.similarity_search(\"foobar\", k=2, filter={\"quality\": \"bad\"})\n",
|
||||
"# With filtering on \"quality\"==\"bad\", only one document should be returned\n",
|
||||
"for doc in docs:\n",
|
||||
" print(\"-\" * 80)\n",
|
||||
" print(doc.page_content)\n",
|
||||
" print(doc.metadata)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Delete documents with specific metadata."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db.delete(filter={\"quality\": \"bad\"})\n",
|
||||
"\n",
|
||||
"# Now the similarity search with the same filter will return no results\n",
|
||||
"docs = db.similarity_search(\"foobar\", k=2, filter={\"quality\": \"bad\"})\n",
|
||||
"print(len(docs))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Using a VectorStore as a retriever in chains for retrieval augmented generation (RAG)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.memory import ConversationBufferMemory\n",
|
||||
"from langchain_openai import ChatOpenAI\n",
|
||||
"\n",
|
||||
"# Access the vector DB with a new table\n",
|
||||
"db = HanaDB(\n",
|
||||
" connection=connection,\n",
|
||||
" embedding=embeddings,\n",
|
||||
" table_name=\"LANGCHAIN_DEMO_RETRIEVAL_CHAIN\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Delete already existing entries from the table\n",
|
||||
"db.delete(filter={})\n",
|
||||
"\n",
|
||||
"# add the loaded document chunks from the \"State Of The Union\" file\n",
|
||||
"db.add_documents(text_chunks)\n",
|
||||
"\n",
|
||||
"# Create a retriever instance of the vector store\n",
|
||||
"retriever = db.as_retriever()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Define the prompt."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"\n",
|
||||
"prompt_template = \"\"\"\n",
|
||||
"You are an expert in state of the union topics. You are provided multiple context items that are related to the prompt you have to answer.\n",
|
||||
"Use the following pieces of context to answer the question at the end.\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"{context}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Question: {question}\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"PROMPT = PromptTemplate(\n",
|
||||
" template=prompt_template, input_variables=[\"context\", \"question\"]\n",
|
||||
")\n",
|
||||
"chain_type_kwargs = {\"prompt\": PROMPT}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Create the ConversationalRetrievalChain, which handles the chat history and the retrieval of similar document chunks to be added to the prompt."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains import ConversationalRetrievalChain\n",
|
||||
"\n",
|
||||
"llm = ChatOpenAI(model_name=\"gpt-3.5-turbo\")\n",
|
||||
"memory = ConversationBufferMemory(\n",
|
||||
" memory_key=\"chat_history\", output_key=\"answer\", return_messages=True\n",
|
||||
")\n",
|
||||
"qa_chain = ConversationalRetrievalChain.from_llm(\n",
|
||||
" llm,\n",
|
||||
" db.as_retriever(search_kwargs={\"k\": 5}),\n",
|
||||
" return_source_documents=True,\n",
|
||||
" memory=memory,\n",
|
||||
" verbose=False,\n",
|
||||
" combine_docs_chain_kwargs={\"prompt\": PROMPT},\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Ask the first question (and verify how many text chunks have been used)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"question = \"What about Mexico and Guatemala?\"\n",
|
||||
"\n",
|
||||
"result = qa_chain.invoke({\"question\": question})\n",
|
||||
"print(\"Answer from LLM:\")\n",
|
||||
"print(\"================\")\n",
|
||||
"print(result[\"answer\"])\n",
|
||||
"\n",
|
||||
"source_docs = result[\"source_documents\"]\n",
|
||||
"print(\"================\")\n",
|
||||
"print(f\"Number of used source document chunks: {len(source_docs)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Examine the used chunks of the chain in detail. Check if the best ranked chunk contains info about \"Mexico and Guatemala\" as mentioned in the question."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for doc in source_docs:\n",
|
||||
" print(\"-\" * 80)\n",
|
||||
" print(doc.page_content)\n",
|
||||
" print(doc.metadata)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Ask another question on the same conversational chain. The answer should relate to the previous answer given."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"question = \"What about other countries?\"\n",
|
||||
"\n",
|
||||
"result = qa_chain.invoke({\"question\": question})\n",
|
||||
"print(\"Answer from LLM:\")\n",
|
||||
"print(\"================\")\n",
|
||||
"print(result[\"answer\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Standard tables vs. \"custom\" tables with vector data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"As default behaviour, the table for the embeddings is created with 3 columns\n",
|
||||
"* A column \"VEC_TEXT\", which contains the text of the Document\n",
|
||||
    "* A column \"VEC_META\", which contains the metadata of the Document\n",
|
||||
"* A column \"VEC_VECTOR\", which contains the embeddings-vector of the document's text"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Access the vector DB with a new table\n",
|
||||
"db = HanaDB(\n",
|
||||
" connection=connection, embedding=embeddings, table_name=\"LANGCHAIN_DEMO_NEW_TABLE\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Delete already existing entries from the table\n",
|
||||
"db.delete(filter={})\n",
|
||||
"\n",
|
||||
"# Add a simple document with some metadata\n",
|
||||
"docs = [\n",
|
||||
" Document(\n",
|
||||
" page_content=\"A simple document\",\n",
|
||||
" metadata={\"start\": 100, \"end\": 150, \"doc_name\": \"simple.txt\"},\n",
|
||||
" )\n",
|
||||
"]\n",
|
||||
"db.add_documents(docs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Show the columns in table \"LANGCHAIN_DEMO_NEW_TABLE\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cur = connection.cursor()\n",
|
||||
"cur.execute(\n",
|
||||
" \"SELECT COLUMN_NAME, DATA_TYPE_NAME FROM SYS.TABLE_COLUMNS WHERE SCHEMA_NAME = CURRENT_SCHEMA AND TABLE_NAME = 'LANGCHAIN_DEMO_NEW_TABLE'\"\n",
|
||||
")\n",
|
||||
"rows = cur.fetchall()\n",
|
||||
"for row in rows:\n",
|
||||
" print(row)\n",
|
||||
"cur.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Show the value of the inserted document in the three columns "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cur = connection.cursor()\n",
|
||||
"cur.execute(\n",
|
||||
" \"SELECT VEC_TEXT, VEC_META, TO_NVARCHAR(VEC_VECTOR) FROM LANGCHAIN_DEMO_NEW_TABLE LIMIT 1\"\n",
|
||||
")\n",
|
||||
"rows = cur.fetchall()\n",
|
||||
"print(rows[0][0]) # The text\n",
|
||||
"print(rows[0][1]) # The metadata\n",
|
||||
"print(rows[0][2]) # The vector\n",
|
||||
"cur.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Custom tables must have at least three columns that match the semantics of a standard table\n",
|
||||
"* A column with type \"NCLOB\" or \"NVARCHAR\" for the text/context of the embeddings\n",
|
||||
"* A column with type \"NCLOB\" or \"NVARCHAR\" for the metadata \n",
|
||||
"* A column with type REAL_VECTOR for the embedding vector\n",
|
||||
"\n",
|
||||
    "The table can contain additional columns. When new Documents are inserted into the table, these additional columns must allow NULL values."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create a new table \"MY_OWN_TABLE\" with three \"standard\" columns and one additional column\n",
|
||||
"my_own_table_name = \"MY_OWN_TABLE\"\n",
|
||||
"cur = connection.cursor()\n",
|
||||
"cur.execute(\n",
|
||||
" (\n",
|
||||
" f\"CREATE TABLE {my_own_table_name} (\"\n",
|
||||
" \"SOME_OTHER_COLUMN NVARCHAR(42), \"\n",
|
||||
" \"MY_TEXT NVARCHAR(2048), \"\n",
|
||||
" \"MY_METADATA NVARCHAR(1024), \"\n",
|
||||
" \"MY_VECTOR REAL_VECTOR )\"\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Create a HanaDB instance with the own table\n",
|
||||
"db = HanaDB(\n",
|
||||
" connection=connection,\n",
|
||||
" embedding=embeddings,\n",
|
||||
" table_name=my_own_table_name,\n",
|
||||
" content_column=\"MY_TEXT\",\n",
|
||||
" metadata_column=\"MY_METADATA\",\n",
|
||||
" vector_column=\"MY_VECTOR\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Add a simple document with some metadata\n",
|
||||
"docs = [\n",
|
||||
" Document(\n",
|
||||
" page_content=\"Some other text\",\n",
|
||||
" metadata={\"start\": 400, \"end\": 450, \"doc_name\": \"other.txt\"},\n",
|
||||
" )\n",
|
||||
"]\n",
|
||||
"db.add_documents(docs)\n",
|
||||
"\n",
|
||||
"# Check if data has been inserted into our own table\n",
|
||||
"cur.execute(f\"SELECT * FROM {my_own_table_name} LIMIT 1\")\n",
|
||||
"rows = cur.fetchall()\n",
|
||||
    "print(rows[0][0]) # Value of column \"SOME_OTHER_COLUMN\". Should be NULL/None\n",
|
||||
"print(rows[0][1]) # The text\n",
|
||||
"print(rows[0][2]) # The metadata\n",
|
||||
"print(rows[0][3]) # The vector\n",
|
||||
"\n",
|
||||
"cur.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Add another document and perform a similarity search on the custom table"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs = [\n",
|
||||
" Document(\n",
|
||||
" page_content=\"Some more text\",\n",
|
||||
" metadata={\"start\": 800, \"end\": 950, \"doc_name\": \"more.txt\"},\n",
|
||||
" )\n",
|
||||
"]\n",
|
||||
"db.add_documents(docs)\n",
|
||||
"\n",
|
||||
"query = \"What's up?\"\n",
|
||||
"docs = db.similarity_search(query, k=2)\n",
|
||||
"for doc in docs:\n",
|
||||
" print(\"-\" * 80)\n",
|
||||
" print(doc.page_content)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
@ -0,0 +1,575 @@
|
||||
"""SAP HANA Cloud Vector Engine"""
|
||||
from __future__ import annotations

import importlib.util
import json
import re
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
)

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import (
    DistanceStrategy,
    maximal_marginal_relevance,
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from hdbcli import dbapi
|
||||
|
||||
# Maps each supported DistanceStrategy to the HANA SQL function that computes
# it, plus the ORDER BY direction that puts the best matches first
# (similarity is maximized, distance is minimized).
HANA_DISTANCE_FUNCTION: dict = {
    DistanceStrategy.COSINE: ("COSINE_SIMILARITY", "DESC"),
    DistanceStrategy.EUCLIDEAN_DISTANCE: ("L2DISTANCE", "ASC"),
}

# Defaults used by HanaDB when the caller does not override them.
default_distance_strategy = DistanceStrategy.COSINE
default_table_name: str = "EMBEDDINGS"
default_content_column: str = "VEC_TEXT"
default_metadata_column: str = "VEC_META"
default_vector_column: str = "VEC_VECTOR"
default_vector_column_length: int = -1  # -1 means dynamic length
|
||||
|
||||
|
||||
class HanaDB(VectorStore):
    """SAP HANA Cloud Vector Engine

    The prerequisite for using this class is the installation of the ``hdbcli``
    Python package.

    The HanaDB vectorstore can be created by providing an embedding function and
    an existing database connection. Optionally, the names of the table and the
    columns to use.

    Document text, JSON-serialized metadata, and the embedding vector are kept
    in three columns of a single table; the table is created on first use if it
    does not already exist.
    """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: dbapi.Connection,
|
||||
embedding: Embeddings,
|
||||
distance_strategy: DistanceStrategy = default_distance_strategy,
|
||||
table_name: str = default_table_name,
|
||||
content_column: str = default_content_column,
|
||||
metadata_column: str = default_metadata_column,
|
||||
vector_column: str = default_vector_column,
|
||||
vector_column_length: int = default_vector_column_length,
|
||||
):
|
||||
# Check if the hdbcli package is installed
|
||||
if importlib.util.find_spec("hdbcli") is None:
|
||||
raise ImportError(
|
||||
"Could not import hdbcli python package. "
|
||||
"Please install it with `pip install hdbcli`."
|
||||
)
|
||||
|
||||
valid_distance = False
|
||||
for key in HANA_DISTANCE_FUNCTION.keys():
|
||||
if key is distance_strategy:
|
||||
valid_distance = True
|
||||
if not valid_distance:
|
||||
raise ValueError(
|
||||
"Unsupported distance_strategy: {}".format(distance_strategy)
|
||||
)
|
||||
|
||||
self.connection = connection
|
||||
self.embedding = embedding
|
||||
self.distance_strategy = distance_strategy
|
||||
self.table_name = HanaDB._sanitize_name(table_name)
|
||||
self.content_column = HanaDB._sanitize_name(content_column)
|
||||
self.metadata_column = HanaDB._sanitize_name(metadata_column)
|
||||
self.vector_column = HanaDB._sanitize_name(vector_column)
|
||||
self.vector_column_length = HanaDB._sanitize_int(vector_column_length)
|
||||
|
||||
# Check if the table exists, and eventually create it
|
||||
if not self._table_exists(self.table_name):
|
||||
sql_str = (
|
||||
f"CREATE TABLE {self.table_name}("
|
||||
f"{self.content_column} NCLOB, "
|
||||
f"{self.metadata_column} NCLOB, "
|
||||
f"{self.vector_column} REAL_VECTOR "
|
||||
)
|
||||
if self.vector_column_length == -1:
|
||||
sql_str += ");"
|
||||
else:
|
||||
sql_str += f"({self.vector_column_length}));"
|
||||
|
||||
try:
|
||||
cur = self.connection.cursor()
|
||||
cur.execute(sql_str)
|
||||
finally:
|
||||
cur.close()
|
||||
|
||||
# Check if the needed columns exist and have the correct type
|
||||
self._check_column(self.table_name, self.content_column, ["NCLOB", "NVARCHAR"])
|
||||
self._check_column(self.table_name, self.metadata_column, ["NCLOB", "NVARCHAR"])
|
||||
self._check_column(
|
||||
self.table_name,
|
||||
self.vector_column,
|
||||
["REAL_VECTOR"],
|
||||
self.vector_column_length,
|
||||
)
|
||||
|
||||
def _table_exists(self, table_name) -> bool:
|
||||
sql_str = (
|
||||
"SELECT COUNT(*) FROM SYS.TABLES WHERE SCHEMA_NAME = CURRENT_SCHEMA"
|
||||
" AND TABLE_NAME = ?"
|
||||
)
|
||||
try:
|
||||
cur = self.connection.cursor()
|
||||
cur.execute(sql_str, (table_name))
|
||||
if cur.has_result_set():
|
||||
rows = cur.fetchall()
|
||||
if rows[0][0] == 1:
|
||||
return True
|
||||
finally:
|
||||
cur.close()
|
||||
return False
|
||||
|
||||
def _check_column(self, table_name, column_name, column_type, column_length=None):
|
||||
sql_str = (
|
||||
"SELECT DATA_TYPE_NAME, LENGTH FROM SYS.TABLE_COLUMNS WHERE "
|
||||
"SCHEMA_NAME = CURRENT_SCHEMA "
|
||||
"AND TABLE_NAME = ? AND COLUMN_NAME = ?"
|
||||
)
|
||||
try:
|
||||
cur = self.connection.cursor()
|
||||
cur.execute(sql_str, (table_name, column_name))
|
||||
if cur.has_result_set():
|
||||
rows = cur.fetchall()
|
||||
if len(rows) == 0:
|
||||
raise AttributeError(f"Column {column_name} does not exist")
|
||||
# Check data type
|
||||
if rows[0][0] not in column_type:
|
||||
raise AttributeError(
|
||||
f"Column {column_name} has the wrong type: {rows[0][0]}"
|
||||
)
|
||||
# Check length, if parameter was provided
|
||||
if column_length is not None:
|
||||
if rows[0][1] != column_length:
|
||||
raise AttributeError(
|
||||
f"Column {column_name} has the wrong length: {rows[0][1]}"
|
||||
)
|
||||
else:
|
||||
raise AttributeError(f"Column {column_name} does not exist")
|
||||
finally:
|
||||
cur.close()
|
||||
|
||||
    @property
    def embeddings(self) -> Embeddings:
        """The embedding function used for documents and queries."""
        return self.embedding
|
||||
|
||||
def _sanitize_name(input_str: str) -> str:
|
||||
# Remove characters that are not alphanumeric or underscores
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "", input_str)
|
||||
|
||||
def _sanitize_int(input_int: any) -> int:
|
||||
value = int(str(input_int))
|
||||
if value < -1:
|
||||
raise ValueError(f"Value ({value}) must not be smaller than -1")
|
||||
return int(str(input_int))
|
||||
|
||||
def _sanitize_list_float(embedding: List[float]) -> List[float]:
|
||||
for value in embedding:
|
||||
if not isinstance(value, float):
|
||||
raise ValueError(f"Value ({value}) does not have type float")
|
||||
return embedding
|
||||
|
||||
# Compile pattern only once, for better performance
|
||||
_compiled_pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")
|
||||
|
||||
def _sanitize_metadata_keys(metadata: dict) -> dict:
|
||||
for key in metadata.keys():
|
||||
if not HanaDB._compiled_pattern.match(key):
|
||||
raise ValueError(f"Invalid metadata key {key}")
|
||||
|
||||
return metadata
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
embeddings: Optional[List[List[float]]] = None,
|
||||
) -> List[str]:
|
||||
"""Add more texts to the vectorstore.
|
||||
|
||||
Args:
|
||||
texts (Iterable[str]): Iterable of strings/text to add to the vectorstore.
|
||||
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
|
||||
Defaults to None.
|
||||
embeddings (Optional[List[List[float]]], optional): Optional pre-generated
|
||||
embeddings. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[str]: empty list
|
||||
"""
|
||||
# Create all embeddings of the texts beforehand to improve performance
|
||||
if embeddings is None:
|
||||
embeddings = self.embedding.embed_documents(list(texts))
|
||||
|
||||
cur = self.connection.cursor()
|
||||
try:
|
||||
# Insert data into the table
|
||||
for i, text in enumerate(texts):
|
||||
# Use provided values by default or fallback
|
||||
metadata = metadatas[i] if metadatas else {}
|
||||
embedding = (
|
||||
embeddings[i]
|
||||
if embeddings
|
||||
else self.embedding.embed_documents([text])[0]
|
||||
)
|
||||
sql_str = (
|
||||
f"INSERT INTO {self.table_name} ({self.content_column}, "
|
||||
f"{self.metadata_column}, {self.vector_column}) "
|
||||
f"VALUES (?, ?, TO_REAL_VECTOR (?));"
|
||||
)
|
||||
cur.execute(
|
||||
sql_str,
|
||||
(
|
||||
text,
|
||||
json.dumps(HanaDB._sanitize_metadata_keys(metadata)),
|
||||
f"[{','.join(map(str, embedding))}]",
|
||||
),
|
||||
)
|
||||
finally:
|
||||
cur.close()
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls: Type[HanaDB],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
connection: dbapi.Connection = None,
|
||||
distance_strategy: DistanceStrategy = default_distance_strategy,
|
||||
table_name: str = default_table_name,
|
||||
content_column: str = default_content_column,
|
||||
metadata_column: str = default_metadata_column,
|
||||
vector_column: str = default_vector_column,
|
||||
vector_column_length: int = default_vector_column_length,
|
||||
):
|
||||
"""Create a HanaDB instance from raw documents.
|
||||
This is a user-friendly interface that:
|
||||
1. Embeds documents.
|
||||
2. Creates a table if it does not yet exist.
|
||||
3. Adds the documents to the table.
|
||||
This is intended to be a quick way to get started.
|
||||
"""
|
||||
|
||||
instance = cls(
|
||||
connection=connection,
|
||||
embedding=embedding,
|
||||
distance_strategy=distance_strategy,
|
||||
table_name=table_name,
|
||||
content_column=content_column,
|
||||
metadata_column=metadata_column,
|
||||
vector_column=vector_column,
|
||||
vector_column_length=vector_column_length, # -1 means dynamic length
|
||||
)
|
||||
instance.add_texts(texts, metadatas)
|
||||
return instance
|
||||
|
||||
def similarity_search(
    self, query: str, k: int = 4, filter: Optional[dict] = None
) -> List[Document]:
    """Return docs most similar to query.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: A dictionary of metadata fields and values to filter by.
            Defaults to None.

    Returns:
        List of Documents most similar to the query.
    """
    # Delegate to the scored search and discard the scores.
    scored = self.similarity_search_with_score(query=query, k=k, filter=filter)
    return [document for document, _score in scored]
|
||||
|
||||
def similarity_search_with_score(
    self, query: str, k: int = 4, filter: Optional[dict] = None
) -> List[Tuple[Document, float]]:
    """Return documents and score values most similar to query.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: A dictionary of metadata fields and values to filter by.
            Defaults to None.

    Returns:
        List of tuples (containing a Document and a score) that are
        most similar to the query.
    """
    # Embed the query text, then defer to the vector-based variant.
    return self.similarity_search_with_score_by_vector(
        embedding=self.embedding.embed_query(query), k=k, filter=filter
    )
|
||||
|
||||
def similarity_search_with_score_and_vector_by_vector(
    self, embedding: List[float], k: int = 4, filter: Optional[dict] = None
) -> List[Tuple[Document, float, List[float]]]:
    """Return docs most similar to the given embedding.

    Args:
        embedding: Embedding vector to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: A dictionary of metadata fields and values to filter by.
            Defaults to None.

    Returns:
        List of tuples of (Document, score, document embedding vector),
        most similar first.
    """
    result = []
    k = HanaDB._sanitize_int(k)
    embedding = HanaDB._sanitize_list_float(embedding)
    distance_func_name = HANA_DISTANCE_FUNCTION[self.distance_strategy][0]
    embedding_as_str = ",".join(map(str, embedding))
    sql_str = (
        f"SELECT TOP {k}"
        f" {self.content_column}, "  # row[0]
        f" {self.metadata_column}, "  # row[1]
        f" TO_NVARCHAR({self.vector_column}), "  # row[2]
        f" {distance_func_name}({self.vector_column}, TO_REAL_VECTOR "
        f" (ARRAY({embedding_as_str}))) AS CS "  # row[3]
        f"FROM {self.table_name}"
    )
    order_str = f" order by CS {HANA_DISTANCE_FUNCTION[self.distance_strategy][1]}"
    # Filter values are bound as parameters; only column/table identifiers
    # (sanitized at construction time) are interpolated into the SQL string.
    where_str, query_tuple = self._create_where_by_filter(filter)
    sql_str = sql_str + where_str + order_str
    # Acquire the cursor BEFORE the try block: if connection.cursor() itself
    # raised inside the try, the finally clause would hit an unbound `cur`.
    cur = self.connection.cursor()
    try:
        cur.execute(sql_str, query_tuple)
        if cur.has_result_set():
            rows = cur.fetchall()
            for row in rows:
                js = json.loads(row[1])
                doc = Document(page_content=row[0], metadata=js)
                result_vector = HanaDB._parse_float_array_from_string(row[2])
                result.append((doc, row[3], result_vector))
    finally:
        cur.close()
    return result
|
||||
|
||||
def similarity_search_with_score_by_vector(
    self, embedding: List[float], k: int = 4, filter: Optional[dict] = None
) -> List[Tuple[Document, float]]:
    """Return docs most similar to the given embedding.

    Args:
        embedding: Embedding vector to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: A dictionary of metadata fields and values to filter by.
            Defaults to None.

    Returns:
        List of tuples of (Document, score), most similar first.
    """
    # Delegate to the variant that also returns vectors, then drop them.
    whole_result = self.similarity_search_with_score_and_vector_by_vector(
        embedding=embedding, k=k, filter=filter
    )
    return [(result_item[0], result_item[1]) for result_item in whole_result]
|
||||
|
||||
def similarity_search_by_vector(
    self, embedding: List[float], k: int = 4, filter: Optional[dict] = None
) -> List[Document]:
    """Return docs most similar to embedding vector.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: A dictionary of metadata fields and values to filter by.
            Defaults to None.

    Returns:
        List of Documents most similar to the query vector.
    """
    # Run the scored search and strip off the scores.
    scored_docs = self.similarity_search_with_score_by_vector(
        embedding=embedding, k=k, filter=filter
    )
    return [document for document, _score in scored_docs]
|
||||
|
||||
def _create_where_by_filter(self, filter):
|
||||
query_tuple = []
|
||||
where_str = ""
|
||||
if filter:
|
||||
for i, key in enumerate(filter.keys()):
|
||||
if i == 0:
|
||||
where_str += " WHERE "
|
||||
else:
|
||||
where_str += " AND "
|
||||
|
||||
where_str += f" JSON_VALUE({self.metadata_column}, '$.{key}') = ?"
|
||||
|
||||
if isinstance(filter[key], bool):
|
||||
if filter[key]:
|
||||
query_tuple.append("true")
|
||||
else:
|
||||
query_tuple.append("false")
|
||||
elif isinstance(filter[key], int) or isinstance(filter[key], str):
|
||||
query_tuple.append(filter[key])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported filter data-type: {type(filter[key])}"
|
||||
)
|
||||
|
||||
return where_str, query_tuple
|
||||
|
||||
def delete(
|
||||
self, ids: Optional[List[str]] = None, filter: Optional[dict] = None
|
||||
) -> Optional[bool]:
|
||||
"""Delete entries by filter with metadata values
|
||||
|
||||
Args:
|
||||
ids: Deletion with ids is not supported! A ValueError will be raised.
|
||||
filter: A dictionary of metadata fields and values to filter by.
|
||||
An empty filter ({}) will delete all entries in the table.
|
||||
|
||||
Returns:
|
||||
Optional[bool]: True, if deletion is technically successful.
|
||||
Deletion of zero entries, due to non-matching filters is a success.
|
||||
"""
|
||||
|
||||
if ids is not None:
|
||||
raise ValueError("Deletion via ids is not supported")
|
||||
|
||||
if filter is None:
|
||||
raise ValueError("Parameter 'filter' is required when calling 'delete'")
|
||||
|
||||
where_str, query_tuple = self._create_where_by_filter(filter)
|
||||
sql_str = f"DELETE FROM {self.table_name} {where_str}"
|
||||
|
||||
try:
|
||||
cur = self.connection.cursor()
|
||||
cur.execute(sql_str, query_tuple)
|
||||
finally:
|
||||
cur.close()
|
||||
|
||||
return True
|
||||
|
||||
async def adelete(
    self, ids: Optional[List[str]] = None, filter: Optional[dict] = None
) -> Optional[bool]:
    """Asynchronously delete entries by metadata filter.

    Runs the synchronous ``delete`` in an executor; it has the same
    contract: ``ids`` is not supported (raises ValueError) and ``filter``
    is required (an empty filter deletes all entries).

    Args:
        ids: Deletion with ids is not supported! A ValueError will be raised.
        filter: A dictionary of metadata fields and values to filter by.

    Returns:
        Optional[bool]: True if deletion is successful,
        False otherwise, None if not implemented.
    """
    return await run_in_executor(None, self.delete, ids=ids, filter=filter)
|
||||
|
||||
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[dict] = None,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND
    diversity among selected documents.

    Args:
        query: search query text.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        filter: Filter on metadata properties, e.g.
            {
                "str_property": "foo",
                "int_property": 123
            }

    Returns:
        List of Documents selected by maximal marginal relevance.
    """
    # Embed the query once, then defer to the vector-based implementation.
    query_embedding = self.embedding.embed_query(query)
    return self.max_marginal_relevance_search_by_vector(
        embedding=query_embedding,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
    )
|
||||
|
||||
def _parse_float_array_from_string(array_as_string: str) -> List[float]:
|
||||
array_wo_brackets = array_as_string[1:-1]
|
||||
return [float(x) for x in array_wo_brackets.split(",")]
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[dict] = None,
) -> List[Document]:
    """Return docs selected by maximal marginal relevance for an embedding.

    Fetches the ``fetch_k`` nearest documents (with their stored vectors),
    then re-ranks them with the MMR algorithm and returns the top ``k``.

    Args:
        embedding: Embedding vector to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of candidate Documents fetched for MMR. Defaults to 20.
        lambda_mult: 0..1 diversity factor (0 = max diversity). Defaults to 0.5.
        filter: Optional metadata filter applied to the candidate fetch.

    Returns:
        List of Documents selected by maximal marginal relevance.
    """
    # Candidates come back as (doc, score, vector); MMR needs the vectors.
    whole_result = self.similarity_search_with_score_and_vector_by_vector(
        embedding=embedding, k=fetch_k, filter=filter
    )
    embeddings = [result_item[2] for result_item in whole_result]
    mmr_doc_indexes = maximal_marginal_relevance(
        np.array(embedding), embeddings, lambda_mult=lambda_mult, k=k
    )

    return [whole_result[i][0] for i in mmr_doc_indexes]
|
||||
|
||||
async def amax_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[dict] = None,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Async wrapper: runs the synchronous implementation in an executor.
    The ``filter`` parameter (new, defaults to None) brings this signature
    in line with the synchronous ``max_marginal_relevance_search_by_vector``.

    Args:
        embedding: Embedding vector to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of candidate Documents fetched for MMR. Defaults to 20.
        lambda_mult: 0..1 diversity factor (0 = max diversity). Defaults to 0.5.
        filter: Optional metadata filter applied to the candidate fetch.

    Returns:
        List of Documents selected by maximal marginal relevance.
    """
    return await run_in_executor(
        None,
        self.max_marginal_relevance_search_by_vector,
        embedding=embedding,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
    )
|
||||
|
||||
@staticmethod
def _cosine_relevance_score_fn(distance: float) -> float:
    """Identity mapping from cosine 'distance' to relevance score.

    NOTE(review): presumably the cosine value produced by the HANA query
    is already a similarity score (higher = more similar), so no
    rescaling is applied -- confirm against HANA's COSINE_SIMILARITY.
    """
    return distance
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
    """
    The 'correct' relevance function
    may differ depending on a few things, including:
    - the distance / similarity metric used by the VectorStore
    - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
    - embedding dimensionality
    - etc.

    Vectorstores should define their own selection based method of relevance.
    """
    # Early returns per supported strategy; anything else is rejected.
    if self.distance_strategy == DistanceStrategy.COSINE:
        return HanaDB._cosine_relevance_score_fn
    if self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
        return HanaDB._euclidean_relevance_score_fn
    raise ValueError(
        "Unsupported distance_strategy: {}".format(self.distance_strategy)
    )
|
# --- integration tests: test_hanavector.py ---
|
||||
"""Test HANA vectorstore functionality."""
|
||||
import os
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from langchain_community.vectorstores import HanaDB
|
||||
from langchain_community.vectorstores.utils import DistanceStrategy
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
)
|
||||
|
||||
# Optional dependency guard: the integration tests below are skipped when
# the SAP HANA Python client (hdbcli) is not installed.
try:
    from hdbcli import dbapi

    hanadb_installed = True
except ImportError:
    hanadb_installed = False
|
||||
|
||||
|
||||
class NormalizedFakeEmbeddings(ConsistentFakeEmbeddings):
    """Fake embeddings with normalization. For testing purposes."""

    def normalize(self, vector: List[float]) -> List[float]:
        """Scale ``vector`` to unit (L2) length."""
        # Compute the norm once instead of per component (the original
        # recomputed np.linalg.norm inside the comprehension: O(n^2)).
        norm = np.linalg.norm(vector)
        return [float(v / norm) for v in vector]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text, then normalize every resulting vector."""
        return [self.normalize(v) for v in super().embed_documents(texts)]

    def embed_query(self, text: str) -> List[float]:
        """Embed the query text and normalize the resulting vector."""
        return self.normalize(super().embed_query(text))
|
||||
|
||||
|
||||
# Shared embeddings instance used by every test in this module.
embedding = NormalizedFakeEmbeddings()
|
||||
|
||||
|
||||
class ConfigData:
    """Holds the shared DB connection and schema name for the test session."""

    def __init__(self):
        self.conn = None  # dbapi.Connection; set in setup_module()
        self.schema_name = ""  # unique schema created per test run


# Module-level holder populated by setup_module / torn down in teardown_module.
test_setup = ConfigData()
|
||||
|
||||
|
||||
def generateSchemaName(cursor):
    """Return a unique schema name for this test run.

    Uses the database's UTC date plus a SYSUUID when the query returns a
    result set; otherwise falls back to a random integer suffix.
    """
    cursor.execute(
        "SELECT REPLACE(CURRENT_UTCDATE, '-', '') || '_' || BINTOHEX(SYSUUID) FROM "
        "DUMMY;"
    )
    if cursor.has_result_set():
        uid = cursor.fetchall()[0][0]
    else:
        uid = random.randint(1, 100000000)
    return f"VEC_{uid}"
|
||||
|
||||
|
||||
def setup_module(module):
    """Open the shared HANA connection and create/select a unique test schema.

    Connection parameters come from the HANA_DB_* environment variables.
    Schema creation errors are swallowed (best effort): tests then run
    against whatever schema is active.
    """
    test_setup.conn = dbapi.connect(
        address=os.environ.get("HANA_DB_ADDRESS"),
        port=os.environ.get("HANA_DB_PORT"),
        user=os.environ.get("HANA_DB_USER"),
        password=os.environ.get("HANA_DB_PASSWORD"),
        autocommit=True,
        sslValidateCertificate=False,
    )
    # Acquire the cursor BEFORE the try block: if cursor() failed inside it,
    # the finally clause would reference an unbound `cur`.
    cur = test_setup.conn.cursor()
    try:
        test_setup.schema_name = generateSchemaName(cur)
        cur.execute(f"CREATE SCHEMA {test_setup.schema_name}")
        cur.execute(f"SET SCHEMA {test_setup.schema_name}")
    except dbapi.ProgrammingError:
        pass
    finally:
        cur.close()
|
||||
|
||||
|
||||
def teardown_module(module):
    """Drop the per-run test schema (best effort) after all tests finish."""
    # Acquire the cursor BEFORE the try block: if cursor() failed inside it,
    # the finally clause would reference an unbound `cur`.
    cur = test_setup.conn.cursor()
    try:
        cur.execute(f"DROP SCHEMA {test_setup.schema_name} CASCADE")
    except dbapi.ProgrammingError:
        # Schema may not exist (e.g. setup failed) -- nothing to clean up.
        pass
    finally:
        cur.close()
|
||||
|
||||
|
||||
@pytest.fixture
def texts() -> List[str]:
    """Three short documents shared across the search tests."""
    return ["foo", "bar", "baz"]
|
||||
|
||||
|
||||
@pytest.fixture
def metadatas() -> List[dict]:
    """Per-document metadata aligned by index with the ``texts`` fixture.

    (Return annotation corrected: the fixture yields dicts, not strings.)
    """
    return [
        {"start": 0, "end": 100, "quality": "good", "ready": True},
        {"start": 100, "end": 200, "quality": "bad", "ready": False},
        {"start": 200, "end": 300, "quality": "ugly", "ready": True},
    ]
|
||||
|
||||
|
||||
def drop_table(connection, table_name):
    """Drop ``table_name`` if it exists; ignore the error when it does not."""
    # Acquire the cursor BEFORE the try block: if cursor() failed inside it,
    # the finally clause would reference an unbound `cur`.
    cur = connection.cursor()
    try:
        cur.execute(f"DROP TABLE {table_name}")
    except dbapi.ProgrammingError:
        # Table did not exist -- nothing to drop.
        pass
    finally:
        cur.close()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_non_existing_table() -> None:
    """HanaDB creates its backing table when it does not exist yet."""
    table_name = "NON_EXISTING"
    # Make sure the table is absent before constructing the store.
    drop_table(test_setup.conn, table_name)

    # Constructing the store should create the table.
    vectordb = HanaDB(
        connection=test_setup.conn,
        embedding=embedding,
        distance_strategy=DistanceStrategy.COSINE,
        table_name=table_name,
    )

    assert vectordb._table_exists(table_name)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_table_with_missing_columns() -> None:
    """Constructing on an existing table lacking the expected columns fails."""
    table_name = "EXISTING_MISSING_COLS"
    drop_table(test_setup.conn, table_name)
    # Acquire the cursor BEFORE the try block so the finally clause cannot
    # reference an unbound `cur` if cursor() fails.
    cur = test_setup.conn.cursor()
    try:
        cur.execute(f"CREATE TABLE {table_name}(WRONG_COL NVARCHAR(500));")
    finally:
        cur.close()

    # The required content/metadata/vector columns are absent, so the
    # column validation in the constructor must raise.
    with pytest.raises(AttributeError):
        HanaDB(
            connection=test_setup.conn,
            embedding=embedding,
            distance_strategy=DistanceStrategy.COSINE,
            table_name=table_name,
        )
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_table_with_nvarchar_content(texts: List[str]) -> None:
    """add_texts works against a pre-created table with NVARCHAR columns."""
    table_name = "EXISTING_NVARCHAR"
    content_column = "TEST_TEXT"
    metadata_column = "TEST_META"
    vector_column = "TEST_VECTOR"
    drop_table(test_setup.conn, table_name)
    # Acquire the cursor BEFORE the try block so the finally clause cannot
    # reference an unbound `cur` if cursor() fails.
    cur = test_setup.conn.cursor()
    try:
        cur.execute(
            f"CREATE TABLE {table_name}({content_column} NVARCHAR(2048), "
            f"{metadata_column} NVARCHAR(2048), {vector_column} REAL_VECTOR);"
        )
    finally:
        cur.close()

    vectordb = HanaDB(
        connection=test_setup.conn,
        embedding=embedding,
        distance_strategy=DistanceStrategy.COSINE,
        table_name=table_name,
        content_column=content_column,
        metadata_column=metadata_column,
        vector_column=vector_column,
    )

    vectordb.add_texts(texts=texts)

    # Verify one row per input text was inserted; close the count cursor
    # even if the query fails (the original never closed it).
    number_of_rows = -1
    cur = test_setup.conn.cursor()
    try:
        cur.execute(f"SELECT COUNT(*) FROM {table_name}")
        if cur.has_result_set():
            number_of_rows = cur.fetchall()[0][0]
    finally:
        cur.close()
    assert number_of_rows == len(texts)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_table_with_wrong_typed_columns() -> None:
    """Constructing on a table whose columns have wrong types must fail."""
    table_name = "EXISTING_WRONG_TYPES"
    content_column = "DOC_TEXT"
    metadata_column = "DOC_META"
    vector_column = "DOC_VECTOR"
    drop_table(test_setup.conn, table_name)
    # Acquire the cursor BEFORE the try block so the finally clause cannot
    # reference an unbound `cur` if cursor() fails.
    cur = test_setup.conn.cursor()
    try:
        cur.execute(
            f"CREATE TABLE {table_name}({content_column} INTEGER, "
            f"{metadata_column} INTEGER, {vector_column} INTEGER);"
        )
    finally:
        cur.close()

    # All three columns exist but carry the wrong types, so the column
    # validation in the constructor must raise.
    with pytest.raises(AttributeError):
        HanaDB(
            connection=test_setup.conn,
            embedding=embedding,
            distance_strategy=DistanceStrategy.COSINE,
            table_name=table_name,
        )
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_non_existing_table_fixed_vector_length() -> None:
    """Table creation honors an explicit vector column name and fixed length."""
    table_name = "NON_EXISTING"
    vector_column = "MY_VECTOR"
    vector_column_length = 42
    # Make sure the table is absent before constructing the store.
    drop_table(test_setup.conn, table_name)

    # Constructing the store should create the table with the custom,
    # fixed-length vector column.
    vectordb = HanaDB(
        connection=test_setup.conn,
        embedding=embedding,
        distance_strategy=DistanceStrategy.COSINE,
        table_name=table_name,
        vector_column=vector_column,
        vector_column_length=vector_column_length,
    )

    assert vectordb._table_exists(table_name)
    vectordb._check_column(
        table_name, vector_column, "REAL_VECTOR", vector_column_length
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_add_texts(texts: List[str]) -> None:
    """add_texts inserts one row per text into a freshly created table."""
    table_name = "TEST_TABLE_ADD_TEXTS"
    drop_table(test_setup.conn, table_name)

    vectordb = HanaDB(
        connection=test_setup.conn, embedding=embedding, table_name=table_name
    )

    vectordb.add_texts(texts=texts)

    # Count rows; close the cursor even if the query fails (the original
    # never closed this cursor).
    number_of_rows = -1
    cur = test_setup.conn.cursor()
    try:
        cur.execute(f"SELECT COUNT(*) FROM {table_name}")
        if cur.has_result_set():
            number_of_rows = cur.fetchall()[0][0]
    finally:
        cur.close()
    assert number_of_rows == len(texts)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_from_texts(texts: List[str]) -> None:
    """from_texts returns a HanaDB instance and inserts one row per text."""
    table_name = "TEST_TABLE_FROM_TEXTS"
    drop_table(test_setup.conn, table_name)

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        embedding=embedding,
        table_name=table_name,
    )
    # test if vectorDB is instance of HanaDB
    assert isinstance(vectorDB, HanaDB)

    # Count rows; close the cursor even if the query fails (the original
    # never closed this cursor).
    number_of_rows = -1
    cur = test_setup.conn.cursor()
    try:
        cur.execute(f"SELECT COUNT(*) FROM {table_name}")
        if cur.has_result_set():
            number_of_rows = cur.fetchall()[0][0]
    finally:
        cur.close()
    assert number_of_rows == len(texts)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_similarity_search_simple(texts: List[str]) -> None:
    """The nearest neighbor of a stored text is the text itself."""
    table_name = "TEST_TABLE_SEARCH_SIMPLE"
    # Start from a clean table.
    drop_table(test_setup.conn, table_name)

    vectordb = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        embedding=embedding,
        table_name=table_name,
    )

    top_hit = vectordb.similarity_search(texts[0], 1)[0].page_content
    assert top_hit == texts[0]
    top_hit = vectordb.similarity_search(texts[0], 1)[0].page_content
    assert top_hit != texts[1]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_similarity_search_by_vector_simple(texts: List[str]) -> None:
    """Searching by a text's own embedding returns that text first."""
    table_name = "TEST_TABLE_SEARCH_SIMPLE_VECTOR"
    # Start from a clean table.
    drop_table(test_setup.conn, table_name)

    vectordb = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        embedding=embedding,
        table_name=table_name,
    )

    query_vector = embedding.embed_query(texts[0])
    top_hit = vectordb.similarity_search_by_vector(query_vector, 1)[0].page_content
    assert top_hit == texts[0]
    top_hit = vectordb.similarity_search_by_vector(query_vector, 1)[0].page_content
    assert top_hit != texts[1]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_similarity_search_simple_euclidean_distance(
    texts: List[str],
) -> None:
    """Nearest-neighbor search also works with the Euclidean strategy."""
    table_name = "TEST_TABLE_SEARCH_EUCLIDIAN"
    # Start from a clean table.
    drop_table(test_setup.conn, table_name)

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        embedding=embedding,
        table_name=table_name,
        distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE,
    )

    assert texts[0] == vectorDB.similarity_search(texts[0], 1)[0].page_content
    assert texts[1] != vectorDB.similarity_search(texts[0], 1)[0].page_content
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_similarity_search_with_metadata(
    texts: List[str], metadatas: List[dict]
) -> None:
    """Metadata stored with each text comes back on its search result."""
    table_name = "TEST_TABLE_METADATA"
    # Start from a clean table.
    drop_table(test_setup.conn, table_name)

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
    )

    search_result = vectorDB.similarity_search(texts[0], 3)

    # Top hit must be texts[0] and carry exactly its own metadata.
    assert texts[0] == search_result[0].page_content
    assert metadatas[0]["start"] == search_result[0].metadata["start"]
    assert metadatas[0]["end"] == search_result[0].metadata["end"]
    assert texts[1] != search_result[0].page_content
    assert metadatas[1]["start"] != search_result[0].metadata["start"]
    assert metadatas[1]["end"] != search_result[0].metadata["end"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_similarity_search_with_metadata_filter(
    texts: List[str], metadatas: List[dict]
) -> None:
    """Integer metadata filters restrict results; multiple keys are ANDed."""
    table_name = "TEST_TABLE_FILTER"
    # Start from a clean table.
    drop_table(test_setup.conn, table_name)

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
    )

    # Single-key filter matches only the second document.
    search_result = vectorDB.similarity_search(texts[0], 3, filter={"start": 100})

    assert len(search_result) == 1
    assert texts[1] == search_result[0].page_content
    assert metadatas[1]["start"] == search_result[0].metadata["start"]
    assert metadatas[1]["end"] == search_result[0].metadata["end"]

    # Keys are ANDed: a non-matching "end" value yields no results.
    search_result = vectorDB.similarity_search(
        texts[0], 3, filter={"start": 100, "end": 150}
    )
    assert len(search_result) == 0

    # Both keys matching yields the single matching document again.
    search_result = vectorDB.similarity_search(
        texts[0], 3, filter={"start": 100, "end": 200}
    )
    assert len(search_result) == 1
    assert texts[1] == search_result[0].page_content
    assert metadatas[1]["start"] == search_result[0].metadata["start"]
    assert metadatas[1]["end"] == search_result[0].metadata["end"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_similarity_search_with_metadata_filter_string(
    texts: List[str], metadatas: List[dict]
) -> None:
    """String-valued metadata filters restrict results correctly."""
    table_name = "TEST_TABLE_FILTER_STRING"
    # Start from a clean table.
    drop_table(test_setup.conn, table_name)

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
    )

    # Only the second document has quality == "bad".
    search_result = vectorDB.similarity_search(texts[0], 3, filter={"quality": "bad"})

    assert len(search_result) == 1
    assert texts[1] == search_result[0].page_content
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_similarity_search_with_metadata_filter_bool(
    texts: List[str], metadatas: List[dict]
) -> None:
    """Boolean-valued metadata filters restrict results correctly."""
    table_name = "TEST_TABLE_FILTER_BOOL"
    # Start from a clean table.
    drop_table(test_setup.conn, table_name)

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
    )

    # Only the second document has ready == False.
    search_result = vectorDB.similarity_search(texts[0], 3, filter={"ready": False})

    assert len(search_result) == 1
    assert texts[1] == search_result[0].page_content
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_similarity_search_with_metadata_filter_invalid_type(
    texts: List[str], metadatas: List[dict]
) -> None:
    """A float filter value is rejected with a ValueError."""
    table_name = "TEST_TABLE_FILTER_INVALID_TYPE"
    drop_table(test_setup.conn, table_name)

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
    )

    # Only bool/int/str filter values are supported; a float must raise.
    # (pytest.raises replaces the original manual exception-flag pattern.)
    with pytest.raises(ValueError):
        vectorDB.similarity_search(texts[0], 3, filter={"wrong_type": 0.1})
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_similarity_search_with_score(
    texts: List[str], metadatas: List[dict]
) -> None:
    """Cosine scores are non-increasing; the exact match scores 1.0."""
    table_name = "TEST_TABLE_SCORE"
    # Start from a clean table.
    drop_table(test_setup.conn, table_name)

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        embedding=embedding,
        table_name=table_name,
    )

    search_result = vectorDB.similarity_search_with_score(texts[0], 3)

    # Exact match first with score 1.0 (embeddings are unit-normalized),
    # then non-increasing, non-negative scores.
    assert search_result[0][0].page_content == texts[0]
    assert search_result[0][1] == 1.0
    assert search_result[1][1] <= search_result[0][1]
    assert search_result[2][1] <= search_result[1][1]
    assert search_result[2][1] >= 0.0
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_similarity_search_with_relevance_score(
    texts: List[str], metadatas: List[dict]
) -> None:
    """Relevance scores under cosine: exact match 1.0, then non-increasing."""
    table_name = "TEST_TABLE_REL_SCORE"
    # Start from a clean table.
    drop_table(test_setup.conn, table_name)

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        embedding=embedding,
        table_name=table_name,
    )

    search_result = vectorDB.similarity_search_with_relevance_scores(texts[0], 3)

    assert search_result[0][0].page_content == texts[0]
    assert search_result[0][1] == 1.0
    assert search_result[1][1] <= search_result[0][1]
    assert search_result[2][1] <= search_result[1][1]
    assert search_result[2][1] >= 0.0
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_similarity_search_with_relevance_score_with_euclidian_distance(
    texts: List[str], metadatas: List[dict]
) -> None:
    """Relevance scores are normalized to 1.0-best even under Euclidean."""
    table_name = "TEST_TABLE_REL_SCORE_EUCLIDIAN"
    # Start from a clean table.
    drop_table(test_setup.conn, table_name)

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        embedding=embedding,
        table_name=table_name,
        distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE,
    )

    search_result = vectorDB.similarity_search_with_relevance_scores(texts[0], 3)

    # Relevance scores keep the "higher is better" convention regardless
    # of the underlying distance strategy.
    assert search_result[0][0].page_content == texts[0]
    assert search_result[0][1] == 1.0
    assert search_result[1][1] <= search_result[0][1]
    assert search_result[2][1] <= search_result[1][1]
    assert search_result[2][1] >= 0.0
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_similarity_search_with_score_with_euclidian_distance(
    texts: List[str], metadatas: List[dict]
) -> None:
    """Raw Euclidean scores ascend: the exact match has distance 0.0."""
    table_name = "TEST_TABLE_SCORE_DISTANCE"
    # Start from a clean table.
    drop_table(test_setup.conn, table_name)

    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        embedding=embedding,
        table_name=table_name,
        distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE,
    )

    search_result = vectorDB.similarity_search_with_score(texts[0], 3)

    # Unlike cosine, raw distances are "lower is better": 0.0 first.
    assert search_result[0][0].page_content == texts[0]
    assert search_result[0][1] == 0.0
    assert search_result[1][1] >= search_result[0][1]
    assert search_result[2][1] >= search_result[1][1]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_delete_with_filter(texts: List[str], metadatas: List[dict]) -> None:
    """delete(filter=...) removes exactly the matching entries."""
    table_name = "TEST_TABLE_DELETE_FILTER"
    # Start from a clean table.
    drop_table(test_setup.conn, table_name)

    # Fill table
    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
    )

    search_result = vectorDB.similarity_search(texts[0], 3)
    assert len(search_result) == 3

    # Delete one of the three entries (the one matching both keys).
    assert vectorDB.delete(filter={"start": 100, "end": 200})

    search_result = vectorDB.similarity_search(texts[0], 3)
    assert len(search_result) == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
async def test_hanavector_delete_with_filter_async(
    texts: List[str], metadatas: List[dict]
) -> None:
    """Async variant: adelete() with a filter removes the matching row."""
    table_name = "TEST_TABLE_DELETE_FILTER_ASYNC"
    # Drop any leftover table from a previous run.
    drop_table(test_setup.conn, table_name)

    # Fill a fresh table with the three fixture documents.
    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
    )

    # All three documents are retrievable before the delete.
    assert len(vectorDB.similarity_search(texts[0], 3)) == 3

    # Remove the single entry whose metadata matches the filter, via the
    # async delete API.
    assert await vectorDB.adelete(filter={"start": 100, "end": 200})

    # Exactly two documents remain afterwards.
    assert len(vectorDB.similarity_search(texts[0], 3)) == 2
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_delete_all_with_empty_filter(
    texts: List[str], metadatas: List[dict]
) -> None:
    """An empty filter dict deletes every row in the table."""
    table_name = "TEST_TABLE_DELETE_ALL"
    # Drop any leftover table from a previous run.
    drop_table(test_setup.conn, table_name)

    # Fill a fresh table with the three fixture documents.
    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
    )

    # All three documents are retrievable before the delete.
    assert len(vectorDB.similarity_search(texts[0], 3)) == 3

    # An empty filter matches everything: the table is wiped.
    assert vectorDB.delete(filter={})

    # No documents remain afterwards.
    assert len(vectorDB.similarity_search(texts[0], 3)) == 0
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_delete_called_wrong(
    texts: List[str], metadatas: List[dict]
) -> None:
    """Invalid delete() calls must raise ValueError.

    Two misuses are covered:

    * calling ``delete()`` without any filter, and
    * calling ``delete()`` with the ``ids`` argument.

    Uses ``pytest.raises`` instead of the manual exception-flag bookkeeping:
    it fails the test automatically when no ValueError is raised and also
    reports unexpected exception types clearly.
    """
    table_name = "TEST_TABLE_DELETE_FILTER_WRONG"
    # Drop any leftover table from a previous run.
    drop_table(test_setup.conn, table_name)

    # Fill a fresh table with the fixture documents.
    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
    )

    # Delete without a filter parameter is rejected.
    with pytest.raises(ValueError):
        vectorDB.delete()

    # Deletion by ids is not supported and must be rejected, even when a
    # valid filter is supplied alongside.
    with pytest.raises(ValueError):
        vectorDB.delete(ids=["id1", "id"], filter={"start": 100, "end": 200})
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_max_marginal_relevance_search(texts: List[str]) -> None:
    """MMR search returns the query text first, then a different document."""
    table_name = "TEST_TABLE_MAX_RELEVANCE"
    # Start from a clean slate so earlier runs cannot interfere.
    drop_table(test_setup.conn, table_name)

    # Populate a fresh table with the fixture texts.
    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        embedding=embedding,
        table_name=table_name,
    )

    search_result = vectorDB.max_marginal_relevance_search(texts[0], k=2, fetch_k=20)

    # Exactly k documents come back; the best match is the query itself and
    # the second one must be a different document (diversity re-ranking).
    assert len(search_result) == 2
    assert search_result[0].page_content == texts[0]
    assert search_result[1].page_content != texts[0]
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_max_marginal_relevance_search_vector(texts: List[str]) -> None:
    """MMR by raw query vector behaves like the text-based MMR search."""
    table_name = "TEST_TABLE_MAX_RELEVANCE_VECTOR"
    # Start from a clean slate so earlier runs cannot interfere.
    drop_table(test_setup.conn, table_name)

    # Populate a fresh table with the fixture texts.
    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        embedding=embedding,
        table_name=table_name,
    )

    # Embed the first text ourselves and query by vector directly.
    query_vector = embedding.embed_query(texts[0])
    search_result = vectorDB.max_marginal_relevance_search_by_vector(
        query_vector, k=2, fetch_k=20
    )

    # Exactly k documents come back; the best match is the embedded text and
    # the second one must be a different document.
    assert len(search_result) == 2
    assert search_result[0].page_content == texts[0]
    assert search_result[1].page_content != texts[0]
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
async def test_hanavector_max_marginal_relevance_search_async(texts: List[str]) -> None:
    """Async MMR search mirrors the synchronous MMR behavior."""
    table_name = "TEST_TABLE_MAX_RELEVANCE_ASYNC"
    # Start from a clean slate so earlier runs cannot interfere.
    drop_table(test_setup.conn, table_name)

    # Populate a fresh table with the fixture texts.
    vectorDB = HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        embedding=embedding,
        table_name=table_name,
    )

    search_result = await vectorDB.amax_marginal_relevance_search(
        texts[0], k=2, fetch_k=20
    )

    # Exactly k documents come back; the best match is the query itself and
    # the second one must be a different document.
    assert len(search_result) == 2
    assert search_result[0].page_content == texts[0]
    assert search_result[1].page_content != texts[0]
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_filter_prepared_statement_params(
    texts: List[str], metadatas: List[dict]
) -> None:
    """Metadata stays queryable via JSON_VALUE, with literals and with
    prepared-statement parameters.

    Fixes: the original passed ``(query_value)`` — a parenthesized scalar,
    not a one-element tuple — as the DB-API parameter sequence; parameters
    are now proper tuples ``(value,)``.  The cursor is also closed
    deterministically via ``try/finally``.
    """
    table_name = "TEST_TABLE_FILTER_PARAM"
    # Drop any leftover table from a previous run.
    drop_table(test_setup.conn, table_name)

    # Populate a fresh table; the store itself is not queried here — we go
    # straight to SQL to verify how metadata is persisted.
    HanaDB.from_texts(
        connection=test_setup.conn,
        texts=texts,
        metadatas=metadatas,
        embedding=embedding,
        table_name=table_name,
    )

    cur = test_setup.conn.cursor()
    try:

        def assert_row_count(sql: str, expected: int, params: tuple = ()) -> None:
            # Run one (optionally parameterized) query and check the hit count.
            if params:
                cur.execute(sql, params)
            else:
                cur.execute(sql)
            assert len(cur.fetchall()) == expected

        # Integer metadata: literal comparison and prepared parameter.
        assert_row_count(
            f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.start') = '100'",
            1,
        )
        assert_row_count(
            f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.start') = ?",
            1,
            (100,),
        )

        # String metadata: literal comparison and prepared parameter.
        assert_row_count(
            f"SELECT * FROM {table_name} "
            "WHERE JSON_VALUE(VEC_META, '$.quality') = 'good'",
            1,
        )
        assert_row_count(
            f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.quality') = ?",
            1,
            ("good",),
        )

        # Boolean metadata is stored as JSON true/false; as a prepared
        # parameter it is matched via the strings "true"/"false".
        assert_row_count(
            f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = false",
            1,
        )
        assert_row_count(
            f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = ?",
            2,
            ("true",),
        )
        assert_row_count(
            f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = ?",
            1,
            ("false",),
        )
    finally:
        cur.close()
||||
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_invalid_metadata_keys(texts: List[str], metadatas: List[dict]) -> None:
    """Metadata keys containing illegal characters must be rejected.

    Fixes: the original test was missing the ``skipif`` marker carried by
    every other test in this module, so it would fail (instead of skip) when
    the ``hdbcli`` driver is not installed.  Also replaces the manual
    exception-flag bookkeeping with ``pytest.raises``.
    """
    table_name = "TEST_TABLE_INVALID_METADATA"
    # Drop any leftover table from a previous run.
    drop_table(test_setup.conn, table_name)

    # A key containing a space is invalid.
    with pytest.raises(ValueError):
        HanaDB.from_texts(
            connection=test_setup.conn,
            texts=texts,
            metadatas=[{"sta rt": 0, "end": 100, "quality": "good", "ready": True}],
            embedding=embedding,
            table_name=table_name,
        )

    # A key containing a slash is invalid as well.
    with pytest.raises(ValueError):
        HanaDB.from_texts(
            connection=test_setup.conn,
            texts=texts,
            metadatas=[{"sta/nrt": 0, "end": 100, "quality": "good", "ready": True}],
            embedding=embedding,
            table_name=table_name,
        )
@ -0,0 +1,46 @@
|
||||
"""Test HanaVector functionality."""
|
||||
|
||||
from langchain_community.vectorstores import HanaDB
|
||||
|
||||
|
||||
def test_int_sanitation_with_illegal_value() -> None:
    """Test sanitization of int with illegal value"""
    # A non-numeric string must be rejected with a ValueError.
    raised = False
    try:
        HanaDB._sanitize_int("HUGO")
    except ValueError:
        raised = True

    assert raised
def test_int_sanitation_with_legal_values() -> None:
    """Test sanitization of int with legal values"""
    # Both genuine ints and numeric strings are accepted and normalized.
    for raw_value, expected in ((42, 42), ("21", 21)):
        assert HanaDB._sanitize_int(raw_value) == expected
def test_int_sanitation_with_negative_values() -> None:
    """Test sanitization of int with legal values"""
    # -1 is accepted, whether given as an int or as a string.
    for raw_value in (-1, "-1"):
        assert HanaDB._sanitize_int(raw_value) == -1
def test_int_sanitation_with_illegal_negative_value() -> None:
    """Test sanitization of int with illegal value"""
    # Values below -1 must be rejected with a ValueError.
    raised = False
    try:
        HanaDB._sanitize_int(-2)
    except ValueError:
        raised = True

    assert raised
def test_parse_float_array_from_string() -> None:
    """A bracketed, comma-separated string parses into a list of floats."""
    parsed = HanaDB._parse_float_array_from_string("[0.1, 0.2, 0.3]")
    assert parsed == [0.1, 0.2, 0.3]
Loading…
Reference in New Issue