diff --git a/langchain/agents/__init__.py b/langchain/agents/__init__.py index 1e8ad771..3a94abeb 100644 --- a/langchain/agents/__init__.py +++ b/langchain/agents/__init__.py @@ -3,7 +3,7 @@ from langchain.agents.agent import Agent, AgentExecutor from langchain.agents.conversational.base import ConversationalAgent from langchain.agents.load_tools import get_all_tool_names, load_tools from langchain.agents.loading import initialize_agent -from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent +from langchain.agents.mrkl.base import MRKLChain, SQLAgent, ZeroShotAgent from langchain.agents.react.base import ReActChain, ReActTextWorldAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain from langchain.agents.tools import Tool @@ -17,6 +17,7 @@ __all__ = [ "Tool", "initialize_agent", "ZeroShotAgent", + "SQLAgent", "ReActTextWorldAgent", "load_tools", "get_all_tool_names", diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py index 76af6dc7..5895aeb8 100644 --- a/langchain/agents/mrkl/base.py +++ b/langchain/agents/mrkl/base.py @@ -4,10 +4,12 @@ from __future__ import annotations import re from typing import Any, Callable, List, NamedTuple, Optional, Tuple +from langchain import LLMChain from langchain.agents.agent import Agent, AgentExecutor from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX +from langchain.agents.mrkl.sql_prompt import SQL_PREFIX, SQL_SUFFIX from langchain.agents.tools import Tool -from langchain.llms.base import BaseLLM +from langchain.llms.base import BaseLLM, BaseCallbackManager from langchain.prompts import PromptTemplate FINAL_ANSWER_ACTION = "Final Answer:" @@ -100,6 +102,53 @@ class ZeroShotAgent(Agent): return get_action_and_input(text) +class SQLAgent(ZeroShotAgent): + @classmethod + def create_prompt( + cls, + tools: List[Tool], + prefix: str = SQL_PREFIX, + suffix: str = SQL_SUFFIX, + input_variables: Optional[List[str]] = None, + ) -> PromptTemplate: + return super().create_prompt(tools, prefix, suffix, input_variables) + + @classmethod + def from_llm_and_sql_tool( + cls, + llm: BaseLLM, + sql_tool: Tool, + callback_manager: Optional[BaseCallbackManager] = None, + **kwargs: Any, + ) -> Agent: + """Construct an agent from an LLM and SQL Chain tool.""" + + cls._validate_tool(sql_tool) + llm_chain = LLMChain( + llm=llm, + prompt=cls.create_prompt([sql_tool]), + callback_manager=callback_manager, + ) + return cls(llm_chain=llm_chain, **kwargs) + + @classmethod + def _validate_tool(cls, tool: Tool) -> None: + + if isinstance(tool, List): + raise TypeError("The SQLAgent must be used with only one tool.") + + if tool.func.__self__.__class__.__name__ != "SQLDatabaseChain": + raise ValueError( + "The SQLAgent must be used with an 'SQLDatabaseChain' based tool." + ) + + if tool.description is None: + raise ValueError( + f"Got a tool {tool.name} without a description. For this agent, " + f"a description must always be provided." + ) + + class MRKLChain(AgentExecutor): """Chain that implements the MRKL system. @@ -109,9 +158,8 @@ class MRKLChain(AgentExecutor): from langchain import OpenAI, MRKLChain from langchain.chains.mrkl.base import ChainConfig llm = OpenAI(temperature=0) - prompt = PromptTemplate(...) chains = [...] - mrkl = MRKLChain.from_chains(llm=llm, prompt=prompt) + mrkl = MRKLChain.from_chains(llm=llm, chains=chains) """ @classmethod @@ -157,5 +205,6 @@ class MRKLChain(AgentExecutor): Tool(name=c.action_name, func=c.action, description=c.action_description) for c in chains ] + agent = ZeroShotAgent.from_llm_and_tools(llm, tools) return cls(agent=agent, tools=tools, **kwargs) diff --git a/langchain/agents/mrkl/sql_prompt.py b/langchain/agents/mrkl/sql_prompt.py new file mode 100644 index 00000000..10b2a696 --- /dev/null +++ b/langchain/agents/mrkl/sql_prompt.py @@ -0,0 +1,13 @@ +# flake8: noqa +SQL_PREFIX = """Answer the question as best you can. +You should only use data in the SQL database to answer the query. The answer you return should come directly from the database. If you don't find an answer, say "There is not enough information in the DB to answer the question." +Your first query can be exploratory, to understand the data in the table. As an example, you can query what the first 5 examples of a column are before querying that column. +When possible, don't query exactly but always use 'LIKE' to make your queries more robust. +Finally, be mindful of not repeating queries. + +You have access to the following DB:""" + +SQL_SUFFIX = """Begin! + +Question: {input} +Thought:{agent_scratchpad}"""