@ -4,10 +4,12 @@ from __future__ import annotations
import re
from typing import Any , Callable , List , NamedTuple , Optional , Tuple
from langchain import LLMChain
from langchain . agents . agent import Agent , AgentExecutor
from langchain . agents . mrkl . prompt import FORMAT_INSTRUCTIONS , PREFIX , SUFFIX
from langchain . agents . mrkl . sql_prompt import SQL_PREFIX , SQL_SUFFIX
from langchain . agents . tools import Tool
from langchain . llms . base import BaseLLM
from langchain . llms . base import BaseLLM , BaseCallbackManager
from langchain . prompts import PromptTemplate
FINAL_ANSWER_ACTION = " Final Answer: "
@ -100,6 +102,53 @@ class ZeroShotAgent(Agent):
return get_action_and_input ( text )
class SQLAgent ( ZeroShotAgent ) :
@classmethod
def create_prompt (
cls ,
tools : List [ Tool ] ,
prefix : str = SQL_PREFIX ,
suffix : str = SQL_SUFFIX ,
input_variables : Optional [ List [ str ] ] = None ,
) - > PromptTemplate :
return super ( ) . create_prompt ( tools , prefix , suffix , input_variables )
@classmethod
def from_llm_and_sql_tool (
cls ,
llm : BaseLLM ,
sql_tool : Tool ,
callback_manager : Optional [ BaseCallbackManager ] = None ,
* * kwargs : Any ,
) - > Agent :
""" Construct an agent from an LLM and SQL Chain tool. """
cls . _validate_tool ( sql_tool )
llm_chain = LLMChain (
llm = llm ,
prompt = cls . create_prompt ( [ sql_tool ] ) ,
callback_manager = callback_manager ,
)
return cls ( llm_chain = llm_chain , * * kwargs )
@classmethod
def _validate_tool ( cls , tool : Tool ) - > None :
if isinstance ( tool , List ) :
raise TypeError ( " The SQLAgent must be used with only one tool. " )
if tool . func . __self__ . __class__ . __name__ != " SQLDatabaseChain " :
raise ValueError (
" The SQLAgent must be used with an ' SQLDatabaseChain ' based tool. "
)
if tool . description is None :
raise ValueError (
f " Got a tool { tool . name } without a description. For this agent, "
f " a description must always be provided. "
)
class MRKLChain ( AgentExecutor ) :
""" Chain that implements the MRKL system.
@ -109,9 +158,8 @@ class MRKLChain(AgentExecutor):
from langchain import OpenAI , MRKLChain
from langchain . chains . mrkl . base import ChainConfig
llm = OpenAI ( temperature = 0 )
prompt = PromptTemplate ( . . . )
chains = [ . . . ]
mrkl = MRKLChain . from_chains ( llm = llm , prompt= prompt )
mrkl = MRKLChain . from_chains ( llm = llm , chains= chains )
"""
@classmethod
@ -157,5 +205,6 @@ class MRKLChain(AgentExecutor):
Tool ( name = c . action_name , func = c . action , description = c . action_description )
for c in chains
]
agent = ZeroShotAgent . from_llm_and_tools ( llm , tools )
return cls ( agent = agent , tools = tools , * * kwargs )