mirror of https://github.com/hwchase17/langchain
tools refactor (#2961)
Co-authored-by: vowelparrot <130414180+vowelparrot@users.noreply.github.com>pull/3060/head
parent
7a8c935b90
commit
db968284f8
@ -0,0 +1,190 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Tool Input Schema\n",
|
||||||
|
"\n",
|
||||||
|
"By default, tools infer the argument schema by inspecting the function signature. For more strict requirements, custom input schema can be specified, along with custom validation logic."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.agents import initialize_agent, Tool\n",
|
||||||
|
"from langchain.agents import AgentType\n",
|
||||||
|
"from langchain.llms import OpenAI\n",
|
||||||
|
"from langchain.tools.python.tool import PythonREPLTool\n",
|
||||||
|
"from pydantic import BaseModel\n",
|
||||||
|
"from pydantic import Field, root_validator\n",
|
||||||
|
"from typing import Dict, Any"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"llm = OpenAI(temperature=0)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class ToolInputModel(BaseModel):\n",
|
||||||
|
" query: str = Field(...)\n",
|
||||||
|
" \n",
|
||||||
|
" @root_validator\n",
|
||||||
|
" def validate_query(cls, values: Dict[str, Any]) -> Dict:\n",
|
||||||
|
" # Note: this is NOT a safe REPL! This is used for instructive purposes only\n",
|
||||||
|
" if \"import os\" in values[\"query\"]:\n",
|
||||||
|
" raise ValueError(\"'import os' not permitted in this python REPL.\")\n",
|
||||||
|
" return values\n",
|
||||||
|
" \n",
|
||||||
|
"tool = PythonREPLTool(args_schema=ToolInputModel)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent = initialize_agent([tool], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3m I need to define a function that adds two numbers\n",
|
||||||
|
"Action: Python REPL\n",
|
||||||
|
"Action Input: def add_two_numbers(a, b):\n",
|
||||||
|
" return a + b\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3m I need to call the function\n",
|
||||||
|
"Action: Python REPL\n",
|
||||||
|
"Action Input: print(add_two_numbers(2, 2))\u001b[0m\u001b[36;1m\u001b[1;3m4\n",
|
||||||
|
"\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||||||
|
"Final Answer: 4\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'4'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# This will succeed, since there aren't any arguments that will be triggered during validation\n",
|
||||||
|
"agent.run(\"Run a python function that adds 2 and 2\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3m I need to import os and then list the dir\n",
|
||||||
|
"Action: Python REPL\n",
|
||||||
|
"Action Input: import os; print(os.listdir())\u001b[0m"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ename": "ValidationError",
|
||||||
|
"evalue": "1 validation error for ToolInputModel\n__root__\n 'import os' not (type=value_error)",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
|
"\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
|
||||||
|
"Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43magent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mRun a python function that imports os and lists the dir\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
||||||
|
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:213\u001b[0m, in \u001b[0;36mChain.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(args) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 212\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`run` supports only one positional argument.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 213\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_keys[\u001b[38;5;241m0\u001b[39m]]\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m kwargs \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m args:\n\u001b[1;32m 216\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m(kwargs)[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_keys[\u001b[38;5;241m0\u001b[39m]]\n",
|
||||||
|
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:116\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_end(outputs, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprep_outputs(inputs, outputs, return_only_outputs)\n",
|
||||||
|
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:113\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_start(\n\u001b[1;32m 108\u001b[0m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m},\n\u001b[1;32m 109\u001b[0m inputs,\n\u001b[1;32m 110\u001b[0m verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose,\n\u001b[1;32m 111\u001b[0m )\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 113\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n",
|
||||||
|
"File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:792\u001b[0m, in \u001b[0;36mAgentExecutor._call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# We now enter the agent loop (until it returns something).\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_continue(iterations, time_elapsed):\n\u001b[0;32m--> 792\u001b[0m next_step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_take_next_step\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 793\u001b[0m \u001b[43m \u001b[49m\u001b[43mname_to_tool_map\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcolor_mapping\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mintermediate_steps\u001b[49m\n\u001b[1;32m 794\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 795\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(next_step_output, AgentFinish):\n\u001b[1;32m 796\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_return(next_step_output, intermediate_steps)\n",
|
||||||
|
"File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:695\u001b[0m, in \u001b[0;36mAgentExecutor._take_next_step\u001b[0;34m(self, name_to_tool_map, color_mapping, inputs, intermediate_steps)\u001b[0m\n\u001b[1;32m 693\u001b[0m tool_run_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mllm_prefix\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 694\u001b[0m \u001b[38;5;66;03m# We then call the tool on the tool input to get an observation\u001b[39;00m\n\u001b[0;32m--> 695\u001b[0m observation \u001b[38;5;241m=\u001b[39m \u001b[43mtool\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 696\u001b[0m \u001b[43m \u001b[49m\u001b[43magent_action\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtool_input\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 697\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 698\u001b[0m \u001b[43m \u001b[49m\u001b[43mcolor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcolor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 699\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtool_run_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 700\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 701\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 702\u001b[0m tool_run_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39magent\u001b[38;5;241m.\u001b[39mtool_run_logging_kwargs()\n",
|
||||||
|
"File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:146\u001b[0m, in \u001b[0;36mBaseTool.run\u001b[0;34m(self, tool_input, verbose, start_color, color, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrun\u001b[39m(\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 139\u001b[0m tool_input: Union[\u001b[38;5;28mstr\u001b[39m, Dict],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any,\n\u001b[1;32m 144\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[1;32m 145\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Run the tool.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 146\u001b[0m run_input \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_parse_input\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;129;01mand\u001b[39;00m verbose \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 148\u001b[0m verbose_ \u001b[38;5;241m=\u001b[39m verbose\n",
|
||||||
|
"File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:112\u001b[0m, in \u001b[0;36mBaseTool._parse_input\u001b[0;34m(self, tool_input)\u001b[0m\n\u001b[1;32m 110\u001b[0m tool_input \u001b[38;5;241m=\u001b[39m {field_name: tool_input}\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pydantic_input_type \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 112\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mpydantic_input_type\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse_obj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 115\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs_schema required for tool \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m in order to\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m accept input of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(tool_input)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 117\u001b[0m )\n",
|
||||||
|
"File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:526\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.parse_obj\u001b[0;34m()\u001b[0m\n",
|
||||||
|
"File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:341\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[0;34m()\u001b[0m\n",
|
||||||
|
"\u001b[0;31mValidationError\u001b[0m: 1 validation error for ToolInputModel\n__root__\n 'import os' not (type=value_error)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# This will fail, because the attempt to import os will trigger a validation error\n",
|
||||||
|
"agent.run(\"Run a python function that imports os and lists the dir\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.2"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
@ -0,0 +1,29 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from langchain.tools.base import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class ReadFileInput(BaseModel):
|
||||||
|
"""Input for ReadFileTool."""
|
||||||
|
|
||||||
|
file_path: str = Field(..., description="name of file")
|
||||||
|
|
||||||
|
|
||||||
|
class ReadFileTool(BaseTool):
|
||||||
|
name: str = "read_file"
|
||||||
|
tool_args: Type[BaseModel] = ReadFileInput
|
||||||
|
description: str = "Read file from disk"
|
||||||
|
|
||||||
|
def _run(self, file_path: str) -> str:
|
||||||
|
try:
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
content = f.read()
|
||||||
|
return content
|
||||||
|
except Exception as e:
|
||||||
|
return "Error: " + str(e)
|
||||||
|
|
||||||
|
async def _arun(self, tool_input: str) -> str:
|
||||||
|
# TODO: Add aiofiles method
|
||||||
|
raise NotImplementedError
|
@ -0,0 +1,34 @@
|
|||||||
|
import os
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from langchain.tools.base import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class WriteFileInput(BaseModel):
|
||||||
|
"""Input for WriteFileTool."""
|
||||||
|
|
||||||
|
file_path: str = Field(..., description="name of file")
|
||||||
|
text: str = Field(..., description="text to write to file")
|
||||||
|
|
||||||
|
|
||||||
|
class WriteFileTool(BaseTool):
|
||||||
|
name: str = "write_file"
|
||||||
|
tool_args: Type[BaseModel] = WriteFileInput
|
||||||
|
description: str = "Write file to disk"
|
||||||
|
|
||||||
|
def _run(self, file_path: str, text: str) -> str:
|
||||||
|
try:
|
||||||
|
directory = os.path.dirname(file_path)
|
||||||
|
if not os.path.exists(directory) and directory:
|
||||||
|
os.makedirs(directory)
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(text)
|
||||||
|
return "File written to successfully."
|
||||||
|
except Exception as e:
|
||||||
|
return "Error: " + str(e)
|
||||||
|
|
||||||
|
async def _arun(self, tool_input: str) -> str:
|
||||||
|
# TODO: Add aiofiles method
|
||||||
|
raise NotImplementedError
|
@ -0,0 +1,100 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.requests import TextRequestsWrapper
|
||||||
|
from langchain.tools.requests.tool import (
|
||||||
|
RequestsDeleteTool,
|
||||||
|
RequestsGetTool,
|
||||||
|
RequestsPatchTool,
|
||||||
|
RequestsPostTool,
|
||||||
|
RequestsPutTool,
|
||||||
|
_parse_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _MockTextRequestsWrapper(TextRequestsWrapper):
|
||||||
|
@staticmethod
|
||||||
|
def get(url: str, **kwargs: Any) -> str:
|
||||||
|
return "get_response"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def aget(url: str, **kwargs: Any) -> str:
|
||||||
|
return "aget_response"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def post(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
|
||||||
|
return f"post {str(data)}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def apost(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
|
||||||
|
return f"apost {str(data)}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def patch(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
|
||||||
|
return f"patch {str(data)}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def apatch(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
|
||||||
|
return f"apatch {str(data)}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def put(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
|
||||||
|
return f"put {str(data)}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def aput(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
|
||||||
|
return f"aput {str(data)}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete(url: str, **kwargs: Any) -> str:
|
||||||
|
return "delete_response"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def adelete(url: str, **kwargs: Any) -> str:
|
||||||
|
return "adelete_response"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_requests_wrapper() -> TextRequestsWrapper:
|
||||||
|
return _MockTextRequestsWrapper()
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_input() -> None:
|
||||||
|
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
|
||||||
|
expected_output = {"url": "https://example.com", "data": {"key": "value"}}
|
||||||
|
assert _parse_input(input_text) == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_requests_get_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
|
||||||
|
tool = RequestsGetTool(requests_wrapper=mock_requests_wrapper)
|
||||||
|
assert tool.run("https://example.com") == "get_response"
|
||||||
|
assert asyncio.run(tool.arun("https://example.com")) == "aget_response"
|
||||||
|
|
||||||
|
|
||||||
|
def test_requests_post_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
|
||||||
|
tool = RequestsPostTool(requests_wrapper=mock_requests_wrapper)
|
||||||
|
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
|
||||||
|
assert tool.run(input_text) == "post {'key': 'value'}"
|
||||||
|
assert asyncio.run(tool.arun(input_text)) == "apost {'key': 'value'}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_requests_patch_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
|
||||||
|
tool = RequestsPatchTool(requests_wrapper=mock_requests_wrapper)
|
||||||
|
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
|
||||||
|
assert tool.run(input_text) == "patch {'key': 'value'}"
|
||||||
|
assert asyncio.run(tool.arun(input_text)) == "apatch {'key': 'value'}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_requests_put_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
|
||||||
|
tool = RequestsPutTool(requests_wrapper=mock_requests_wrapper)
|
||||||
|
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
|
||||||
|
assert tool.run(input_text) == "put {'key': 'value'}"
|
||||||
|
assert asyncio.run(tool.arun(input_text)) == "aput {'key': 'value'}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_requests_delete_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
|
||||||
|
tool = RequestsDeleteTool(requests_wrapper=mock_requests_wrapper)
|
||||||
|
assert tool.run("https://example.com") == "delete_response"
|
||||||
|
assert asyncio.run(tool.arun("https://example.com")) == "adelete_response"
|
Loading…
Reference in New Issue