{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "245065c6",
   "metadata": {},
   "source": [
    "# Vector SQL Retriever with MyScale\n",
    "\n",
    ">[MyScale](https://docs.myscale.com/en/) is an integrated vector database. You can access your database in SQL and also from here, LangChain. MyScale can make use of [various data types and functions for filters](https://blog.myscale.com/2023/06/06/why-integrated-database-solution-can-boost-your-llm-apps/#filter-on-anything-without-constraints). It will boost your LLM app whether you are scaling up your data or expanding your system to broader applications."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1c4e7b2",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Install the dependencies, then configure the MyScale connection and the OpenAI API key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0246c5bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install clickhouse-sqlalchemy InstructorEmbedding sentence_transformers openai pandas langchain langchain-community langchain-experimental langchain-openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7585d2c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "from os import environ\n",
    "from urllib.parse import quote_plus\n",
    "\n",
    "import pandas as pd\n",
    "from langchain.callbacks import StdOutCallbackHandler\n",
    "from langchain.chains import LLMChain\n",
    "from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain_community.embeddings import HuggingFaceInstructEmbeddings\n",
    "from langchain_community.utilities import SQLDatabase\n",
    "from langchain_experimental.retrievers.vector_sql_database import (\n",
    "    VectorSQLDatabaseChainRetriever,\n",
    ")\n",
    "from langchain_experimental.sql.prompt import MYSCALE_PROMPT\n",
    "from langchain_experimental.sql.vector_sql import (\n",
    "    VectorSQLDatabaseChain,\n",
    "    VectorSQLOutputParser,\n",
    "    VectorSQLRetrieveAllOutputParser,\n",
    ")\n",
    "from langchain_openai import ChatOpenAI, OpenAI\n",
    "from sqlalchemy import MetaData, create_engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1a9c2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Connection settings default to the values originally hardcoded in this\n",
    "# notebook; set the MYSCALE_* environment variables to use your own cluster.\n",
    "MYSCALE_HOST = environ.get(\n",
    "    \"MYSCALE_HOST\", \"msc-4a9e710a.us-east-1.aws.staging.myscale.cloud\"\n",
    ")\n",
    "MYSCALE_PORT = int(environ.get(\"MYSCALE_PORT\", 443))\n",
    "MYSCALE_USER = environ.get(\"MYSCALE_USER\", \"chatdata\")\n",
    "MYSCALE_PASSWORD = environ.get(\"MYSCALE_PASSWORD\", \"myscale_rocks\")\n",
    "\n",
    "# Reuse an already exported key so Restart & Run All needs no typing;\n",
    "# otherwise prompt for it without echoing.\n",
    "OPENAI_API_KEY = environ.get(\"OPENAI_API_KEY\") or getpass.getpass(\n",
    "    \"OpenAI API Key:\"\n",
    ")\n",
    "environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY\n",
    "\n",
    "# Percent-encode the credentials so characters such as '@' or ':' cannot\n",
    "# corrupt the connection URL.\n",
    "engine = create_engine(\n",
    "    f\"clickhouse://{quote_plus(MYSCALE_USER)}:{quote_plus(MYSCALE_PASSWORD)}\"\n",
    "    f\"@{MYSCALE_HOST}:{MYSCALE_PORT}/default?protocol=https\"\n",
    ")\n",
    "# MetaData(bind=...) is no longer accepted by SQLAlchemy 2.0; the engine is\n",
    "# handed to SQLDatabase explicitly instead.\n",
    "metadata = MetaData()\n",
    "# One SQLDatabase wrapper shared by both chains below.\n",
    "db = SQLDatabase(engine, None, metadata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e08d9ddc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The embedding model is created once and shared by both output parsers.\n",
    "embeddings = HuggingFaceInstructEmbeddings(\n",
    "    model_name=\"hkunlp/instructor-xl\", model_kwargs={\"device\": \"cpu\"}\n",
    ")\n",
    "\n",
    "output_parser = VectorSQLOutputParser.from_embeddings(model=embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7d20f54",
   "metadata": {},
   "source": [
    "## Query the database with natural language"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84b705b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "sql_chain = VectorSQLDatabaseChain(\n",
    "    llm_chain=LLMChain(\n",
    "        llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),\n",
    "        prompt=MYSCALE_PROMPT,\n",
    "    ),\n",
    "    top_k=10,\n",
    "    return_direct=True,\n",
    "    sql_cmd_parser=output_parser,\n",
    "    database=db,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e8c1a96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `invoke` returns a dict; the query output is stored under the 'result' key.\n",
    "sql_result = sql_chain.invoke(\n",
    "    \"Please give me 10 papers to ask what is PageRank?\",\n",
    "    config={\"callbacks\": [StdOutCallbackHandler()]},\n",
    ")\n",
    "pd.DataFrame(sql_result[\"result\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c09cda0",
   "metadata": {},
   "source": [
    "## SQL Database as Retriever"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "734d7ff5",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_parser_retrieve_all = VectorSQLRetrieveAllOutputParser.from_embeddings(\n",
    "    embeddings\n",
    ")\n",
    "\n",
    "retriever_sql_chain = VectorSQLDatabaseChain.from_llm(\n",
    "    llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),\n",
    "    prompt=MYSCALE_PROMPT,\n",
    "    top_k=10,\n",
    "    return_direct=True,\n",
    "    db=db,\n",
    "    sql_cmd_parser=output_parser_retrieve_all,\n",
    "    native_format=True,\n",
    ")\n",
    "\n",
    "# You need all of the keys referenced by the document prompt below to get\n",
    "# docs, which is presumably why this chain uses the retrieve-all parser.\n",
    "retriever = VectorSQLDatabaseChainRetriever(\n",
    "    sql_db_chain=retriever_sql_chain, page_content_key=\"abstract\"\n",
    ")\n",
    "\n",
    "document_with_metadata_prompt = PromptTemplate(\n",
    "    input_variables=[\"page_content\", \"id\", \"title\", \"authors\", \"pubdate\", \"categories\"],\n",
    "    template=\"Content:\\n\\tTitle: {title}\\n\\tAbstract: {page_content}\\n\\tAuthors: {authors}\\n\\tDate of Publication: {pubdate}\\n\\tCategories: {categories}\\nSOURCE: {id}\",\n",
    ")\n",
    "\n",
    "qa_chain = RetrievalQAWithSourcesChain.from_chain_type(\n",
    "    ChatOpenAI(\n",
    "        model_name=\"gpt-3.5-turbo-16k\", openai_api_key=OPENAI_API_KEY, temperature=0.6\n",
    "    ),\n",
    "    retriever=retriever,\n",
    "    chain_type=\"stuff\",\n",
    "    chain_type_kwargs={\n",
    "        \"document_prompt\": document_with_metadata_prompt,\n",
    "    },\n",
    "    return_source_documents=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c93f6d18",
   "metadata": {},
   "outputs": [],
   "source": [
    "qa_result = qa_chain.invoke(\n",
    "    \"Please give me 10 papers to ask what is PageRank?\",\n",
    "    config={\"callbacks\": [StdOutCallbackHandler()]},\n",
    ")\n",
    "print(qa_result[\"answer\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}