langchain/libs/community/langchain_community/tools/sql_database/tool.py
Yuki Oshima 0758da8940
community[patch]: Set default value for _ListSQLDatabaseToolInput tool_input (#20409)
**Description:**

`_ListSQLDatabaseToolInput` raises an error if the model returns `{}`.
For example, gpt-4-turbo returns `{}` with a SQL Agent initialized by
`create_sql_agent`.

So, I set a default value of `""` for `_ListSQLDatabaseToolInput` tool_input.

This is actually a gpt-4-turbo issue, not a LangChain issue, but I
thought it would be helpful to set a default value `""`.

This problem is discussed in detail in the following Issue.

**Issue:** https://github.com/langchain-ai/langchain/issues/20405

**Dependencies:** none

Sorry, I did not add or change any test code, as tests for this
component did not exist.

However, I have tested the following code based on the [SQL Agent
Document](https://python.langchain.com/docs/use_cases/sql/agents/), to
make sure it works.

```
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_openai import ChatOpenAI

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
result = agent_executor.invoke("List the total sales per country. Which country's customers spent the most?")
print(result["output"])
```
2024-04-13 15:58:47 -07:00

160 lines
5.2 KiB
Python

# flake8: noqa
"""Tools for interacting with a SQL database."""
from typing import Any, Dict, Optional, Sequence, Type, Union
from sqlalchemy.engine import Result
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.language_models import BaseLanguageModel
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.prompts import PromptTemplate
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.tools import BaseTool
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
class BaseSQLDatabaseTool(BaseModel):
    """Common pydantic base that gives each SQL tool a ``db`` attribute."""

    # Database wrapper the tool operates on; the field is declared with
    # ``exclude=True``.
    db: SQLDatabase = Field(exclude=True)

    class Config(BaseTool.Config):
        """Inherit ``BaseTool.Config`` without overriding anything."""
class _QuerySQLDataBaseToolInput(BaseModel):
    """Argument schema with a single required ``query`` string."""

    query: str = Field(default=..., description="A detailed and correct SQL query.")
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool that passes a SQL query to ``db.run_no_throw``."""

    name: str = "sql_db_query"
    # Runtime text: reproduced verbatim, including its line breaks.
    description: str = """
Execute a SQL query against the database and get back the result..
If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again.
"""
    args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Union[str, Sequence[Dict[str, Any]], Result]:
        """Run ``query`` on ``self.db`` and return what the database wrapper yields.

        Args:
            query: SQL text to execute.
            run_manager: Accepted as part of the tool signature; not used here.

        Returns:
            Whatever ``self.db.run_no_throw(query)`` returns (per the tool
            description, either the result or an error message).
        """
        outcome = self.db.run_no_throw(query)
        return outcome
class _InfoSQLDatabaseToolInput(BaseModel):
    """Argument schema with a single required ``table_names`` string."""

    # One string holding all requested table names, separated by commas.
    table_names: str = Field(
        default=...,
        description=(
            "A comma-separated list of the table names for which to return the schema. "
            "Example input: 'table1, table2, table3'"
        ),
    )
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool that looks up table information for a list of table names."""

    name: str = "sql_db_schema"
    description: str = "Get the schema and sample rows for the specified SQL tables."
    args_schema: Type[BaseModel] = _InfoSQLDatabaseToolInput

    def _run(
        self,
        table_names: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Split ``table_names`` on commas and fetch info for each name.

        Args:
            table_names: Comma-separated table names; surrounding whitespace
                on each name is stripped.
            run_manager: Accepted as part of the tool signature; not used here.

        Returns:
            The value of ``self.db.get_table_info_no_throw`` for the parsed names.
        """
        requested = list(map(str.strip, table_names.split(",")))
        return self.db.get_table_info_no_throw(requested)
class _ListSQLDataBaseToolInput(BaseModel):
    """Argument schema whose only field, ``tool_input``, defaults to ``""``."""

    # Optional: an omitted ``tool_input`` falls back to the empty string.
    tool_input: str = Field(default="", description="An empty string")
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool that reports the usable table names of the database."""

    name: str = "sql_db_list_tables"
    description: str = "Input is an empty string, output is a comma-separated list of tables in the database."
    args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput

    def _run(
        self,
        tool_input: str = "",
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Join the usable table names with ``", "``.

        Args:
            tool_input: Ignored; present only to satisfy the argument schema.
            run_manager: Accepted as part of the tool signature; not used here.

        Returns:
            The names from ``self.db.get_usable_table_names()`` as one string.
        """
        usable_tables = self.db.get_usable_table_names()
        return ", ".join(usable_tables)
class _QuerySQLCheckerToolInput(BaseModel):
    """Argument schema with a single required ``query`` string to be checked."""

    query: str = Field(default=..., description="A detailed and SQL query to be checked.")
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
    """Use an LLM to check if a query is correct.

    Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/
    """

    # Prompt template used when the validator has to build ``llm_chain`` itself.
    template: str = QUERY_CHECKER
    llm: BaseLanguageModel
    # Populated by ``initialize_llm_chain`` unless the caller supplies one.
    llm_chain: Any = Field(init=False)
    name: str = "sql_db_query_checker"
    # Runtime text: reproduced verbatim, including its line breaks.
    description: str = """
Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with sql_db_query!
"""
    args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput

    @root_validator(pre=True)
    def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build a default ``llm_chain`` if none was given, then validate it.

        Args:
            values: Raw constructor arguments (this is a ``pre=True`` validator,
                so field defaults have not been applied yet).

        Returns:
            ``values`` with ``llm_chain`` guaranteed to be present.

        Raises:
            ValueError: If the chain's prompt does not take exactly the
                ``query`` and ``dialect`` input variables.
        """
        if "llm_chain" not in values:
            # Imported lazily, as in the original code.
            from langchain.chains.llm import LLMChain

            values["llm_chain"] = LLMChain(
                llm=values.get("llm"),
                prompt=PromptTemplate(
                    # Honour a caller-supplied ``template``; previously the
                    # field was ignored and QUERY_CHECKER was always used.
                    template=values.get("template", QUERY_CHECKER),
                    input_variables=["dialect", "query"],
                ),
            )

        # Compare order-insensitively: the error message below lists the
        # variables as ['query', 'dialect'], so either order must be accepted.
        prompt_variables = values["llm_chain"].prompt.input_variables
        if sorted(prompt_variables) != ["dialect", "query"]:
            raise ValueError(
                "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
            )

        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the LLM to check the query.

        Args:
            query: SQL text to be reviewed.
            run_manager: If given, its child callbacks are forwarded to the chain.

        Returns:
            The chain's prediction for ``query`` and the database dialect.
        """
        return self.llm_chain.predict(
            query=query,
            dialect=self.db.dialect,
            callbacks=run_manager.get_child() if run_manager else None,
        )

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async counterpart of ``_run``; awaits ``llm_chain.apredict``.

        Args:
            query: SQL text to be reviewed.
            run_manager: If given, its child callbacks are forwarded to the chain.

        Returns:
            The chain's prediction for ``query`` and the database dialect.
        """
        return await self.llm_chain.apredict(
            query=query,
            dialect=self.db.dialect,
            callbacks=run_manager.get_child() if run_manager else None,
        )