SqlDatabaseToolkit should have custom llm for QueryChecke… (#2676)

…rTool (#2655)

---------

Co-authored-by: Rushabh Agarwal <26388764+rushout09@users.noreply.github.com>
fix_agent_callbacks
Ankush Gola 1 year ago committed by GitHub
parent 8d3b059332
commit e23a596a18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -44,7 +44,7 @@ def create_sql_agent(
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools( return AgentExecutor.from_agent_and_tools(
agent=agent, agent=agent,
tools=toolkit.get_tools(), tools=tools,
verbose=verbose, verbose=verbose,
max_iterations=max_iterations, max_iterations=max_iterations,
early_stopping_method=early_stopping_method, early_stopping_method=early_stopping_method,

@ -4,6 +4,8 @@ from typing import List
from pydantic import Field from pydantic import Field
from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.llms.base import BaseLLM
from langchain.llms.openai import OpenAI
from langchain.sql_database import SQLDatabase from langchain.sql_database import SQLDatabase
from langchain.tools import BaseTool from langchain.tools import BaseTool
from langchain.tools.sql_database.tool import ( from langchain.tools.sql_database.tool import (
@ -18,6 +20,7 @@ class SQLDatabaseToolkit(BaseToolkit):
"""Toolkit for interacting with SQL databases.""" """Toolkit for interacting with SQL databases."""
db: SQLDatabase = Field(exclude=True) db: SQLDatabase = Field(exclude=True)
llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0))
@property @property
def dialect(self) -> str: def dialect(self) -> str:
@ -35,5 +38,5 @@ class SQLDatabaseToolkit(BaseToolkit):
QuerySQLDataBaseTool(db=self.db), QuerySQLDataBaseTool(db=self.db),
InfoSQLDatabaseTool(db=self.db), InfoSQLDatabaseTool(db=self.db),
ListSQLDatabaseTool(db=self.db), ListSQLDatabaseTool(db=self.db),
QueryCheckerTool(db=self.db), QueryCheckerTool(db=self.db, llm=self.llm),
] ]

@ -3,9 +3,9 @@
from pydantic import BaseModel, Extra, Field, validator from pydantic import BaseModel, Extra, Field, validator
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.sql_database import SQLDatabase from langchain.sql_database import SQLDatabase
from langchain.llms.base import BaseLLM
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.tools.sql_database.prompt import QUERY_CHECKER from langchain.tools.sql_database.prompt import QUERY_CHECKER
@ -80,11 +80,12 @@ class QueryCheckerTool(BaseSQLDatabaseTool, BaseTool):
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/""" Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
template: str = QUERY_CHECKER template: str = QUERY_CHECKER
llm: BaseLLM
llm_chain: LLMChain = Field( llm_chain: LLMChain = Field(
default_factory=lambda: LLMChain( default_factory=lambda: LLMChain(
llm=OpenAI(temperature=0), llm=QueryCheckerTool.llm,
prompt=PromptTemplate( prompt=PromptTemplate(
template=QUERY_CHECKER, input_variables=["query", "dialect"] template=QueryCheckerTool.template, input_variables=["query", "dialect"]
), ),
) )
) )

Loading…
Cancel
Save