SqlDatabaseToolkit should have custom llm for QueryChecke… (#2676)

…rTool (#2655)

---------

Co-authored-by: Rushabh Agarwal <26388764+rushout09@users.noreply.github.com>
fix_agent_callbacks
Ankush Gola 1 year ago committed by GitHub
parent 8d3b059332
commit e23a596a18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -44,7 +44,7 @@ def create_sql_agent(
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=toolkit.get_tools(),
tools=tools,
verbose=verbose,
max_iterations=max_iterations,
early_stopping_method=early_stopping_method,

@ -4,6 +4,8 @@ from typing import List
from pydantic import Field
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.llms.base import BaseLLM
from langchain.llms.openai import OpenAI
from langchain.sql_database import SQLDatabase
from langchain.tools import BaseTool
from langchain.tools.sql_database.tool import (
@ -18,6 +20,7 @@ class SQLDatabaseToolkit(BaseToolkit):
"""Toolkit for interacting with SQL databases."""
db: SQLDatabase = Field(exclude=True)
llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0))
@property
def dialect(self) -> str:
@ -35,5 +38,5 @@ class SQLDatabaseToolkit(BaseToolkit):
QuerySQLDataBaseTool(db=self.db),
InfoSQLDatabaseTool(db=self.db),
ListSQLDatabaseTool(db=self.db),
QueryCheckerTool(db=self.db),
QueryCheckerTool(db=self.db, llm=self.llm),
]

@ -3,9 +3,9 @@
from pydantic import BaseModel, Extra, Field, validator
from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.prompts import PromptTemplate
from langchain.sql_database import SQLDatabase
from langchain.llms.base import BaseLLM
from langchain.tools.base import BaseTool
from langchain.tools.sql_database.prompt import QUERY_CHECKER
@ -80,11 +80,12 @@ class QueryCheckerTool(BaseSQLDatabaseTool, BaseTool):
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
template: str = QUERY_CHECKER
llm: BaseLLM
llm_chain: LLMChain = Field(
default_factory=lambda: LLMChain(
llm=OpenAI(temperature=0),
llm=QueryCheckerTool.llm,
prompt=PromptTemplate(
template=QUERY_CHECKER, input_variables=["query", "dialect"]
template=QueryCheckerTool.template, input_variables=["query", "dialect"]
),
)
)

Loading…
Cancel
Save