mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
201 lines
6.6 KiB
Plaintext
201 lines
6.6 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "245065c6",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Vector SQL Retriever with MyScale\n",
|
|
"\n",
|
|
">[MyScale](https://docs.myscale.com/en/) is an integrated vector database. You can access your database in SQL and also from here, LangChain. MyScale can make a use of [various data types and functions for filters](https://blog.myscale.com/2023/06/06/why-integrated-database-solution-can-boost-your-llm-apps/#filter-on-anything-without-constraints). It will boost up your LLM app no matter if you are scaling up your data or expand your system to broader application."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0246c5bf",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip3 install clickhouse-sqlalchemy InstructorEmbedding sentence_transformers openai langchain-experimental"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "7585d2c3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"from os import environ\n",
|
|
"import getpass\n",
|
|
"from typing import Dict, Any\n",
|
|
"from langchain.llms import OpenAI\nfrom langchain.utilities import SQLDatabase\nfrom langchain.chains import LLMChain\n",
|
|
"from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain\n",
|
|
"from sqlalchemy import create_engine, Column, MetaData\n",
|
|
"from langchain.prompts import PromptTemplate\n",
|
|
"\n",
|
|
"\n",
|
|
"from sqlalchemy import create_engine\n",
|
|
"\n",
|
|
"MYSCALE_HOST = \"msc-1decbcc9.us-east-1.aws.staging.myscale.cloud\"\n",
|
|
"MYSCALE_PORT = 443\n",
|
|
"MYSCALE_USER = \"chatdata\"\n",
|
|
"MYSCALE_PASSWORD = \"myscale_rocks\"\n",
|
|
"OPENAI_API_KEY = getpass.getpass(\"OpenAI API Key:\")\n",
|
|
"\n",
|
|
"engine = create_engine(\n",
|
|
" f\"clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/default?protocol=https\"\n",
|
|
")\n",
|
|
"metadata = MetaData(bind=engine)\n",
|
|
"environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e08d9ddc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.embeddings import HuggingFaceInstructEmbeddings\n",
|
|
"from langchain_experimental.sql.vector_sql import VectorSQLOutputParser\n",
|
|
"\n",
|
|
"output_parser = VectorSQLOutputParser.from_embeddings(\n",
|
|
" model=HuggingFaceInstructEmbeddings(\n",
|
|
" model_name=\"hkunlp/instructor-xl\", model_kwargs={\"device\": \"cpu\"}\n",
|
|
" )\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "84b705b2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"from langchain.llms import OpenAI\n",
|
|
"from langchain.callbacks import StdOutCallbackHandler\n",
|
|
"\n",
|
|
"from langchain.utilities.sql_database import SQLDatabase\n",
|
|
"from langchain_experimental.sql.prompt import MYSCALE_PROMPT\n",
|
|
"from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain\n",
|
|
"\n",
|
|
"chain = VectorSQLDatabaseChain(\n",
|
|
" llm_chain=LLMChain(\n",
|
|
" llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),\n",
|
|
" prompt=MYSCALE_PROMPT,\n",
|
|
" ),\n",
|
|
" top_k=10,\n",
|
|
" return_direct=True,\n",
|
|
" sql_cmd_parser=output_parser,\n",
|
|
" database=SQLDatabase(engine, None, metadata),\n",
|
|
")\n",
|
|
"\n",
|
|
"import pandas as pd\n",
|
|
"\n",
|
|
"pd.DataFrame(\n",
|
|
" chain.run(\n",
|
|
" \"Please give me 10 papers to ask what is PageRank?\",\n",
|
|
" callbacks=[StdOutCallbackHandler()],\n",
|
|
" )\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6c09cda0",
|
|
"metadata": {},
|
|
"source": [
|
|
"## SQL Database as Retriever"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "734d7ff5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.chat_models import ChatOpenAI\n",
|
|
"from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain\n",
|
|
"\n",
|
|
"from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain\n",
|
|
"from langchain_experimental.retrievers.vector_sql_database \\\n",
|
|
" import VectorSQLDatabaseChainRetriever\n",
|
|
"from langchain_experimental.sql.prompt import MYSCALE_PROMPT\n",
|
|
"from langchain_experimental.sql.vector_sql import VectorSQLRetrieveAllOutputParser\n",
|
|
"\n",
|
|
"output_parser_retrieve_all = VectorSQLRetrieveAllOutputParser.from_embeddings(\n",
|
|
" output_parser.model\n",
|
|
")\n",
|
|
"\n",
|
|
"chain = VectorSQLDatabaseChain.from_llm(\n",
|
|
" llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),\n",
|
|
" prompt=MYSCALE_PROMPT,\n",
|
|
" top_k=10,\n",
|
|
" return_direct=True,\n",
|
|
" db=SQLDatabase(engine, None, metadata),\n",
|
|
" sql_cmd_parser=output_parser_retrieve_all,\n",
|
|
" native_format=True,\n",
|
|
")\n",
|
|
"\n",
|
|
"# You need all those keys to get docs\n",
|
|
"retriever = VectorSQLDatabaseChainRetriever(sql_db_chain=chain, page_content_key=\"abstract\")\n",
|
|
"\n",
|
|
"document_with_metadata_prompt = PromptTemplate(\n",
|
|
" input_variables=[\"page_content\", \"id\", \"title\", \"authors\", \"pubdate\", \"categories\"],\n",
|
|
" template=\"Content:\\n\\tTitle: {title}\\n\\tAbstract: {page_content}\\n\\tAuthors: {authors}\\n\\tDate of Publication: {pubdate}\\n\\tCategories: {categories}\\nSOURCE: {id}\",\n",
|
|
")\n",
|
|
"\n",
|
|
"chain = RetrievalQAWithSourcesChain.from_chain_type(\n",
|
|
" ChatOpenAI(\n",
|
|
" model_name=\"gpt-3.5-turbo-16k\", openai_api_key=OPENAI_API_KEY, temperature=0.6\n",
|
|
" ),\n",
|
|
" retriever=retriever,\n",
|
|
" chain_type=\"stuff\",\n",
|
|
" chain_type_kwargs={\n",
|
|
" \"document_prompt\": document_with_metadata_prompt,\n",
|
|
" },\n",
|
|
" return_source_documents=True,\n",
|
|
")\n",
|
|
"ans = chain(\"Please give me 10 papers to ask what is PageRank?\",\n",
|
|
" callbacks=[StdOutCallbackHandler()])\n",
|
|
"print(ans[\"answer\"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4948ff25",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|