diff --git a/langchain/agents/agent_toolkits/sql/base.py b/langchain/agents/agent_toolkits/sql/base.py index b259f5af..6a24a430 100644 --- a/langchain/agents/agent_toolkits/sql/base.py +++ b/langchain/agents/agent_toolkits/sql/base.py @@ -44,7 +44,7 @@ def create_sql_agent( agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) return AgentExecutor.from_agent_and_tools( agent=agent, - tools=toolkit.get_tools(), + tools=tools, verbose=verbose, max_iterations=max_iterations, early_stopping_method=early_stopping_method, diff --git a/langchain/agents/agent_toolkits/sql/toolkit.py b/langchain/agents/agent_toolkits/sql/toolkit.py index 2b662d6b..1e32cb64 100644 --- a/langchain/agents/agent_toolkits/sql/toolkit.py +++ b/langchain/agents/agent_toolkits/sql/toolkit.py @@ -4,6 +4,8 @@ from typing import List from pydantic import Field from langchain.agents.agent_toolkits.base import BaseToolkit +from langchain.llms.base import BaseLLM +from langchain.llms.openai import OpenAI from langchain.sql_database import SQLDatabase from langchain.tools import BaseTool from langchain.tools.sql_database.tool import ( @@ -18,6 +20,7 @@ class SQLDatabaseToolkit(BaseToolkit): """Toolkit for interacting with SQL databases.""" db: SQLDatabase = Field(exclude=True) + llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0)) @property def dialect(self) -> str: @@ -35,5 +38,5 @@ class SQLDatabaseToolkit(BaseToolkit): QuerySQLDataBaseTool(db=self.db), InfoSQLDatabaseTool(db=self.db), ListSQLDatabaseTool(db=self.db), - QueryCheckerTool(db=self.db), + QueryCheckerTool(db=self.db, llm=self.llm), ] diff --git a/langchain/tools/sql_database/tool.py b/langchain/tools/sql_database/tool.py index d4b72cd4..a9ac6981 100644 --- a/langchain/tools/sql_database/tool.py +++ b/langchain/tools/sql_database/tool.py @@ -3,9 +3,9 @@ from pydantic import BaseModel, Extra, Field, validator from langchain.chains.llm import LLMChain -from langchain.llms.openai import OpenAI from langchain.prompts import PromptTemplate from langchain.sql_database import SQLDatabase +from langchain.llms.base import BaseLLM from langchain.tools.base import BaseTool from langchain.tools.sql_database.prompt import QUERY_CHECKER @@ -80,11 +80,12 @@ class QueryCheckerTool(BaseSQLDatabaseTool, BaseTool): Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/""" template: str = QUERY_CHECKER + llm: BaseLLM llm_chain: LLMChain = Field( default_factory=lambda: LLMChain( - llm=OpenAI(temperature=0), + llm=QueryCheckerTool.llm, prompt=PromptTemplate( - template=QUERY_CHECKER, input_variables=["query", "dialect"] + template=QueryCheckerTool.template, input_variables=["query", "dialect"] ), ) )