mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
3331865f6b
**Description**: Toolkit and Tools for accessing data in a Cassandra Database primarily for Agent integration. Initially, this includes the following tools: - `cassandra_db_schema` Gathers all schema information for the connected database or a specific schema. Critical for the agent when determining actions. - `cassandra_db_select_table_data` Selects data from a specific keyspace and table. The agent can pass parameters for a predicate and limits on the number of returned records. - `cassandra_db_query` Experimental alternative to `cassandra_db_select_table_data` which takes a query string completely formed by the agent instead of parameters. May be removed in future versions. Includes unit test and two notebooks to demonstrate usage. **Dependencies**: cassio **Twitter handle**: @PatrickMcFadin --------- Co-authored-by: Phil Miesle <phil.miesle@datastax.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
558 lines
16 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Setup Environment"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Python Modules"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Install the following Python modules:\n",
|
|
"\n",
|
|
"```bash\n",
|
|
"pip install ipykernel python-dotenv cassio pandas langchain_openai langchain langchain-community langchainhub langchain_experimental openai-multi-tool-use-parallel-patch\n",
|
|
"```"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Load the `.env` File"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The connection is made through `cassio` using the `auto=True` parameter, and the notebook uses OpenAI. Create a `.env` file accordingly.\n",
|
|
"\n",
|
|
"For Cassandra, set:\n",
|
|
"```bash\n",
|
|
"CASSANDRA_CONTACT_POINTS\n",
|
|
"CASSANDRA_USERNAME\n",
|
|
"CASSANDRA_PASSWORD\n",
|
|
"CASSANDRA_KEYSPACE\n",
|
|
"```\n",
|
|
"\n",
|
|
"For Astra, set:\n",
|
|
"```bash\n",
|
|
"ASTRA_DB_APPLICATION_TOKEN\n",
|
|
"ASTRA_DB_DATABASE_ID\n",
|
|
"ASTRA_DB_KEYSPACE\n",
|
|
"```\n",
|
|
"\n",
|
|
"For example:\n",
|
|
"\n",
|
|
"```bash\n",
|
|
"# Connection to Astra:\n",
|
|
"ASTRA_DB_DATABASE_ID=a1b2c3d4-...\n",
|
|
"ASTRA_DB_APPLICATION_TOKEN=AstraCS:...\n",
|
|
"ASTRA_DB_KEYSPACE=notebooks\n",
|
|
"\n",
|
|
"# Also set the OpenAI API key:\n",
|
|
"OPENAI_API_KEY=sk-....\n",
|
|
"```\n",
|
|
"\n",
|
|
"(You may also modify the code below to connect directly with `cassio`.)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from dotenv import load_dotenv\n",
|
|
"\n",
|
|
"load_dotenv(override=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Connect to Cassandra"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"\n",
|
|
"import cassio\n",
|
|
"\n",
|
|
"cassio.init(auto=True)\n",
|
|
"session = cassio.config.resolve_session()\n",
|
|
"if not session:\n",
|
|
" raise Exception(\n",
|
|
" \"Check environment configuration or manually configure cassio connection parameters\"\n",
|
|
" )\n",
|
|
"\n",
|
|
"keyspace = os.environ.get(\n",
|
|
" \"ASTRA_DB_KEYSPACE\", os.environ.get(\"CASSANDRA_KEYSPACE\", None)\n",
|
|
")\n",
|
|
"if not keyspace:\n",
|
|
" raise ValueError(\"a KEYSPACE environment variable must be set\")\n",
|
|
"\n",
|
|
"session.set_keyspace(keyspace)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Setup Database"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"This only needs to be done once."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Download Data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The dataset used is the [Environmental Sensor Telemetry Data](https://www.kaggle.com/datasets/garystafford/environmental-sensor-data-132k?select=iot_telemetry_data.csv) from Kaggle. The next cell downloads and unzips the data into a Pandas dataframe; the cell after it gives instructions for downloading the data manually.\n",
"\n",
"By the end of this section, you should have a Pandas dataframe variable `df`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"#### Download Automatically"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from io import BytesIO\n",
|
|
"from zipfile import ZipFile\n",
|
|
"\n",
|
|
"import pandas as pd\n",
|
|
"import requests\n",
|
|
"\n",
|
|
"datasetURL = \"https://storage.googleapis.com/kaggle-data-sets/788816/1355729/bundle/archive.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20240404%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20240404T115828Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=2849f003b100eb9dcda8dd8535990f51244292f67e4f5fad36f14aa67f2d4297672d8fe6ff5a39f03a29cda051e33e95d36daab5892b8874dcd5a60228df0361fa26bae491dd4371f02dd20306b583a44ba85a4474376188b1f84765147d3b4f05c57345e5de883c2c29653cce1f3755cd8e645c5e952f4fb1c8a735b22f0c811f97f7bce8d0235d0d3731ca8ab4629ff381f3bae9e35fc1b181c1e69a9c7913a5e42d9d52d53e5f716467205af9c8a3cc6746fc5352e8fbc47cd7d18543626bd67996d18c2045c1e475fc136df83df352fa747f1a3bb73e6ba3985840792ec1de407c15836640ec96db111b173bf16115037d53fdfbfd8ac44145d7f9a546aa\"\n",
|
|
"\n",
|
|
"# The signed URL above carries an expiry (X-Goog-Expires); if the request fails,\n",
"# download the file manually instead.\n",
"response = requests.get(datasetURL)\n",
"response.raise_for_status()  # fail here rather than with a NameError further down\n",
"\n",
"zip_file = ZipFile(BytesIO(response.content))\n",
"csv_file_name = zip_file.namelist()[0]\n",
|
|
"\n",
|
|
"with zip_file.open(csv_file_name) as csv_file:\n",
|
|
" df = pd.read_csv(csv_file)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"#### Download Manually"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"You can download the `.zip` file and unpack the `.csv` contained within. Uncomment the next line, and adjust the path to this `.csv` file appropriately."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# df = pd.read_csv(\"/path/to/iot_telemetry_data.csv\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Load Data into Cassandra"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"This section assumes that a dataframe `df` exists; the following cell validates its structure. The Download section above creates this object."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"assert df is not None, \"Dataframe 'df' must be set\"\n",
|
|
"expected_columns = [\n",
|
|
" \"ts\",\n",
|
|
" \"device\",\n",
|
|
" \"co\",\n",
|
|
" \"humidity\",\n",
|
|
" \"light\",\n",
|
|
" \"lpg\",\n",
|
|
" \"motion\",\n",
|
|
" \"smoke\",\n",
|
|
" \"temp\",\n",
|
|
"]\n",
|
|
"assert all(\n",
|
|
" [column in df.columns for column in expected_columns]\n",
|
|
"), \"DataFrame does not have the expected columns\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Create and load tables:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
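|
"from datetime import datetime, timezone\n",
"\n",
"from cassandra.query import BatchStatement\n",
"\n",
"# datetime.UTC only exists on Python 3.11+; timezone.utc also works on older versions\n",
"UTC = timezone.utc\n",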
|
|
"\n",
|
|
"# Create sensors table\n",
|
|
"table_query = \"\"\"\n",
|
|
"CREATE TABLE IF NOT EXISTS iot_sensors (\n",
|
|
" device text,\n",
|
|
" conditions text,\n",
|
|
" room text,\n",
|
|
" PRIMARY KEY (device)\n",
|
|
")\n",
|
|
"WITH COMMENT = 'Environmental IoT room sensor metadata.';\n",
|
|
"\"\"\"\n",
|
|
"session.execute(table_query)\n",
|
|
"\n",
|
|
"pstmt = session.prepare(\n",
|
|
" \"\"\"\n",
|
|
"INSERT INTO iot_sensors (device, conditions, room)\n",
|
|
"VALUES (?, ?, ?)\n",
|
|
"\"\"\"\n",
|
|
")\n",
|
|
"\n",
|
|
"devices = [\n",
|
|
" (\"00:0f:00:70:91:0a\", \"stable conditions, cooler and more humid\", \"room 1\"),\n",
|
|
" (\"1c:bf:ce:15:ec:4d\", \"highly variable temperature and humidity\", \"room 2\"),\n",
|
|
" (\"b8:27:eb:bf:9d:51\", \"stable conditions, warmer and dryer\", \"room 3\"),\n",
|
|
"]\n",
|
|
"\n",
|
|
"for device, conditions, room in devices:\n",
|
|
" session.execute(pstmt, (device, conditions, room))\n",
|
|
"\n",
|
|
"print(\"Sensors inserted successfully.\")\n",
|
|
"\n",
|
|
"# Create data table\n",
|
|
"table_query = \"\"\"\n",
|
|
"CREATE TABLE IF NOT EXISTS iot_data (\n",
|
|
" day text,\n",
|
|
" device text,\n",
|
|
" ts timestamp,\n",
|
|
" co double,\n",
|
|
" humidity double,\n",
|
|
" light boolean,\n",
|
|
" lpg double,\n",
|
|
" motion boolean,\n",
|
|
" smoke double,\n",
|
|
" temp double,\n",
|
|
" PRIMARY KEY ((day, device), ts)\n",
|
|
")\n",
|
|
"WITH COMMENT = 'Data from environmental IoT room sensors. Columns include device identifier, timestamp (ts) of the data collection, carbon monoxide level (co), relative humidity, light presence, LPG concentration, motion detection, smoke concentration, and temperature (temp). Data is partitioned by day and device.';\n",
|
|
"\"\"\"\n",
|
|
"session.execute(table_query)\n",
|
|
"\n",
|
|
"pstmt = session.prepare(\n",
|
|
" \"\"\"\n",
|
|
"INSERT INTO iot_data (day, device, ts, co, humidity, light, lpg, motion, smoke, temp)\n",
|
|
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)\n",
|
|
"\"\"\"\n",
|
|
")\n",
|
|
"\n",
|
|
"\n",
|
|
"def insert_data_batch(name, group):\n",
"    batch = BatchStatement()\n",
"    day, device = name\n",
"    print(f\"Inserting batch for day: {day}, device: {device}\")\n",
"\n",
"    for _, row in group.iterrows():\n",
"        timestamp = datetime.fromtimestamp(row[\"ts\"], UTC)\n",
"        batch.add(\n",
"            pstmt,\n",
"            (\n",
"                day,\n",
"                row[\"device\"],\n",
"                timestamp,\n",
"                row[\"co\"],\n",
"                row[\"humidity\"],\n",
"                row[\"light\"],\n",
"                row[\"lpg\"],\n",
"                row[\"motion\"],\n",
"                row[\"smoke\"],\n",
"                row[\"temp\"],\n",
"            ),\n",
"        )\n",
"\n",
"    session.execute(batch)\n",
|
|
"\n",
|
|
"\n",
|
|
"# Convert columns to appropriate types\n",
|
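|
"# read_csv may already have parsed these columns as booleans, so normalize via str\n",
"df[\"light\"] = df[\"light\"].astype(str).str.lower() == \"true\"\n",
"df[\"motion\"] = df[\"motion\"].astype(str).str.lower() == \"true\"\n",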
|
|
"df[\"ts\"] = df[\"ts\"].astype(float)\n",
|
|
"df[\"day\"] = df[\"ts\"].apply(\n",
|
|
" lambda x: datetime.fromtimestamp(x, UTC).strftime(\"%Y-%m-%d\")\n",
|
|
")\n",
|
|
"\n",
|
|
"grouped_df = df.groupby([\"day\", \"device\"])\n",
|
|
"\n",
|
|
"for name, group in grouped_df:\n",
|
|
" insert_data_batch(name, group)\n",
|
|
"\n",
|
|
"print(\"Data load complete\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(session.keyspace)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Load the Tools"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Python `import` statements for the demo:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.agents import AgentExecutor, create_openai_tools_agent\n",
|
|
"from langchain_community.agent_toolkits.cassandra_database.toolkit import (\n",
|
|
" CassandraDatabaseToolkit,\n",
|
|
")\n",
|
|
"from langchain_community.tools.cassandra_database.prompt import QUERY_PATH_PROMPT\n",
|
|
"from langchain_community.tools.cassandra_database.tool import (\n",
|
|
" GetSchemaCassandraDatabaseTool,\n",
|
|
" GetTableDataCassandraDatabaseTool,\n",
|
|
" QueryCassandraDatabaseTool,\n",
|
|
")\n",
|
|
"from langchain_community.utilities.cassandra_database import CassandraDatabase\n",
|
|
"from langchain_openai import ChatOpenAI"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"By default, the `CassandraDatabase` object uses the session resolved through `cassio`; alternatively, it accepts a `Session`-type parameter."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Create a CassandraDatabase instance\n",
|
|
"db = CassandraDatabase(include_tables=[\"iot_sensors\", \"iot_data\"])\n",
|
|
"\n",
|
|
"# Create the Cassandra Database tools\n",
|
|
"query_tool = QueryCassandraDatabaseTool(db=db)\n",
|
|
"schema_tool = GetSchemaCassandraDatabaseTool(db=db)\n",
|
|
"select_data_tool = GetTableDataCassandraDatabaseTool(db=db)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The tools can be invoked directly:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Test the tools\n",
|
|
"print(\"Executing a CQL query:\")\n",
|
|
"query = \"SELECT * FROM iot_sensors LIMIT 5;\"\n",
|
|
"result = query_tool.run({\"query\": query})\n",
|
|
"print(result)\n",
|
|
"\n",
|
|
"print(\"\\nGetting the schema for a keyspace:\")\n",
|
|
"schema = schema_tool.run({\"keyspace\": keyspace})\n",
|
|
"print(schema)\n",
|
|
"\n",
|
|
"print(\"\\nGetting data from a table:\")\n",
|
|
"table = \"iot_data\"\n",
|
|
"predicate = \"day = '2020-07-14' and device = 'b8:27:eb:bf:9d:51'\"\n",
|
|
"data = select_data_tool.run(\n",
|
|
" {\"keyspace\": keyspace, \"table\": table, \"predicate\": predicate, \"limit\": 5}\n",
|
|
")\n",
|
|
"print(data)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Agent Configuration"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.agents import Tool\n",
|
|
"from langchain_experimental.utilities import PythonREPL\n",
|
|
"\n",
|
|
"python_repl = PythonREPL()\n",
|
|
"\n",
|
|
"repl_tool = Tool(\n",
|
|
" name=\"python_repl\",\n",
|
|
" description=\"A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.\",\n",
|
|
" func=python_repl.run,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain import hub\n",
|
|
"\n",
|
|
"# llm = ChatOpenAI(temperature=0, model=\"gpt-4-1106-preview\")  # unused: `llm` is reassigned before the agent is built\n",
|
|
"toolkit = CassandraDatabaseToolkit(db=db)\n",
|
|
"\n",
|
|
"# context = toolkit.get_context()\n",
|
|
"# tools = toolkit.get_tools()\n",
|
|
"tools = [schema_tool, select_data_tool, repl_tool]\n",
|
|
"\n",
|
|
"input = (\n",
|
|
" QUERY_PATH_PROMPT\n",
|
|
" + f\"\"\"\n",
|
|
"\n",
|
|
"Here is your task: In the {keyspace} keyspace, find the total number of times the temperature of each device has exceeded 23 degrees on July 14, 2020.\n",
|
|
" Create a summary report including the name of the room. Use Pandas if helpful.\n",
|
|
"\"\"\"\n",
|
|
")\n",
|
|
"\n",
|
|
"prompt = hub.pull(\"hwchase17/openai-tools-agent\")\n",
|
|
"\n",
|
|
"# messages = [\n",
|
|
"# HumanMessagePromptTemplate.from_template(input),\n",
|
|
"# AIMessage(content=QUERY_PATH_PROMPT),\n",
|
|
"# MessagesPlaceholder(variable_name=\"agent_scratchpad\"),\n",
|
|
"# ]\n",
|
|
"\n",
|
|
"# prompt = ChatPromptTemplate.from_messages(messages)\n",
|
|
"# print(prompt)\n",
|
|
"\n",
|
|
"# Choose the LLM that will drive the agent\n",
|
|
"# Only certain models support this\n",
|
|
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-1106\", temperature=0)\n",
|
|
"\n",
|
|
"# Construct the OpenAI Tools agent\n",
|
|
"agent = create_openai_tools_agent(llm, tools, prompt)\n",
|
|
"\n",
|
|
"print(\"Available tools:\")\n",
|
|
"for tool in tools:\n",
|
|
"    print(f\"\\t{tool.name} - {tool.description} - {tool}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)\n",
|
|
"\n",
|
|
"response = agent_executor.invoke({\"input\": input})\n",
|
|
"\n",
|
|
"print(response[\"output\"])"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.1"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
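The load step in the notebook above partitions `iot_data` by `(day, device)`, deriving `day` from the epoch timestamp in UTC. A minimal standalone sketch of that bucketing, using only the standard library, might look like the following. The helper names `day_bucket` and `group_by_partition` and the sample readings are hypothetical and do not come from the notebook or from any library.

```python
from collections import defaultdict
from datetime import datetime, timezone


def day_bucket(epoch_seconds: float) -> str:
    """Return the UTC calendar day (YYYY-MM-DD) for a Unix timestamp."""
    return datetime.fromtimestamp(epoch_seconds, timezone.utc).strftime("%Y-%m-%d")


def group_by_partition(readings):
    """Group readings by the (day, device) pair used as the partition key."""
    partitions = defaultdict(list)
    for reading in readings:
        key = (day_bucket(reading["ts"]), reading["device"])
        partitions[key].append(reading)
    return dict(partitions)


# Hypothetical sample readings; the values are made up for illustration.
sample_readings = [
    {"ts": 1594684799.0, "device": "b8:27:eb:bf:9d:51", "temp": 22.7},
    {"ts": 1594684800.0, "device": "b8:27:eb:bf:9d:51", "temp": 23.4},
    {"ts": 1594684860.0, "device": "b8:27:eb:bf:9d:51", "temp": 23.1},
    {"ts": 1594684800.0, "device": "00:0f:00:70:91:0a", "temp": 19.2},
]

partitions = group_by_partition(sample_readings)
for (day, device), rows in sorted(partitions.items()):
    print(day, device, len(rows))
```

Bucketing in UTC rather than local time keeps the partition key independent of the machine that runs the load, which matters because predicates such as `day = '2020-07-14'` must match the stored value exactly.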
|
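The agent task in the notebook above asks for the number of readings per device that exceed 23 degrees, reported together with the room name. As a rough illustration of the kind of computation the agent might run in its Python tool once it has fetched the rows, here is a plain-Python sketch. The function names and the sample readings are hypothetical; only the device-to-room mapping is taken from the notebook, and the agent's actual approach may differ.

```python
from collections import Counter

# Device-to-room mapping, as inserted into iot_sensors in the notebook.
rooms = {
    "00:0f:00:70:91:0a": "room 1",
    "1c:bf:ce:15:ec:4d": "room 2",
    "b8:27:eb:bf:9d:51": "room 3",
}

# Hypothetical (device, temp) readings standing in for rows fetched from iot_data.
readings = [
    ("b8:27:eb:bf:9d:51", 23.5),
    ("b8:27:eb:bf:9d:51", 22.9),
    ("b8:27:eb:bf:9d:51", 24.1),
    ("00:0f:00:70:91:0a", 19.8),
    ("1c:bf:ce:15:ec:4d", 26.0),
]


def count_exceedances(rows, threshold):
    """Count, per device, the readings strictly above the threshold."""
    counts = Counter()
    for device, temp in rows:
        if temp > threshold:
            counts[device] += 1
    return counts


def build_report(counts, room_by_device):
    """Return (room, device, count) tuples for every known device, sorted by device."""
    return [
        (room_by_device[device], device, counts.get(device, 0))
        for device in sorted(room_by_device)
    ]


report = build_report(count_exceedances(readings, 23.0), rooms)
for room, device, count in report:
    print(f"{room} ({device}): {count} readings above 23")
```

Iterating over the known devices rather than over the counts means a room with zero exceedances still appears in the report.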