Harrison/fix create sql agent (#2870)

Co-authored-by: Timothé Pearce <timothe.pearce@gmail.com>
This commit is contained in:
Harrison Chase 2023-04-13 22:07:58 -07:00 committed by GitHub
parent 8a98e5b50b
commit 705596b46a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 60 additions and 17 deletions

View File

@ -28,7 +28,7 @@ class GitLoader(BaseLoader):
def load(self) -> List[Document]: def load(self) -> List[Document]:
try: try:
from git import Blob, Repo from git import Blob, Repo # type: ignore
except ImportError as ex: except ImportError as ex:
raise ImportError( raise ImportError(
"Could not import git python package. " "Could not import git python package. "

View File

@ -1,6 +1,7 @@
# flake8: noqa # flake8: noqa
"""Tools for interacting with a SQL database.""" """Tools for interacting with a SQL database."""
from pydantic import BaseModel, Extra, Field, validator from pydantic import BaseModel, Extra, Field, validator, root_validator
from typing import Any, Dict
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
@ -81,28 +82,29 @@ class QueryCheckerTool(BaseSQLDatabaseTool, BaseTool):
template: str = QUERY_CHECKER template: str = QUERY_CHECKER
llm: BaseLLM llm: BaseLLM
llm_chain: LLMChain = Field( llm_chain: LLMChain = Field(init=False)
default_factory=lambda: LLMChain(
llm=QueryCheckerTool.llm,
prompt=PromptTemplate(
template=QueryCheckerTool.template, input_variables=["query", "dialect"]
),
)
)
name = "query_checker_sql_db" name = "query_checker_sql_db"
description = """ description = """
Use this tool to double check if your query is correct before executing it. Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with query_sql_db! Always use this tool before executing a query with query_sql_db!
""" """
@validator("llm_chain") @root_validator(pre=True)
def validate_llm_chain_input_variables(cls, llm_chain: LLMChain) -> LLMChain: def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Make sure the LLM chain has the correct input variables.""" if "llm_chain" not in values:
if llm_chain.prompt.input_variables != ["query", "dialect"]: values["llm_chain"] = LLMChain(
llm=values.get("llm"),
prompt=PromptTemplate(
template=QUERY_CHECKER, input_variables=["query", "dialect"]
),
)
if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
raise ValueError( raise ValueError(
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']" "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
) )
return llm_chain
return values
def _run(self, query: str) -> str: def _run(self, query: str) -> str:
"""Use the LLM to check the query.""" """Use the LLM to check the query."""

View File

@ -0,0 +1,18 @@
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_create_sql_agent() -> None:
    """Check that a SQL agent built from the toolkit returns the fake LLM's answer.

    The agent runs against an in-memory SQLite database and a FakeLLM in
    sequential-response mode whose only canned reply is a final answer.
    """
    database = SQLDatabase.from_uri("sqlite:///:memory:")
    fake_llm = FakeLLM(
        queries={"foo": "Final Answer: baz"},
        sequential_responses=True,
    )
    sql_toolkit = SQLDatabaseToolkit(db=database, llm=fake_llm)
    executor = create_sql_agent(llm=fake_llm, toolkit=sql_toolkit)
    answer = executor.run("hello")
    assert answer == "baz"

View File

@ -1,5 +1,7 @@
"""Fake LLM wrapper for testing purposes.""" """Fake LLM wrapper for testing purposes."""
from typing import Any, List, Mapping, Optional from typing import Any, List, Mapping, Optional, cast
from pydantic import validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
@ -8,6 +10,18 @@ class FakeLLM(LLM):
"""Fake LLM wrapper for testing purposes.""" """Fake LLM wrapper for testing purposes."""
queries: Optional[Mapping] = None queries: Optional[Mapping] = None
sequential_responses: Optional[bool] = False
response_index: int = 0
# Attached to ``sequential_responses`` rather than ``queries``: pydantic only
# passes already-validated (earlier-declared) fields in ``values``, and
# ``queries`` is declared before ``sequential_responses``. Validating the
# later field is what lets this check see both values.
@validator("sequential_responses", always=True)
def check_queries_required(
    cls, sequential_responses: Optional[bool], values: Mapping[str, Any]
) -> Optional[bool]:
    """Require ``queries`` whenever sequential responses are enabled.

    Args:
        sequential_responses: The value supplied for the flag being validated.
        values: Fields validated so far; expected to hold ``queries``.

    Returns:
        The unchanged ``sequential_responses`` value.

    Raises:
        ValueError: If sequential mode is requested without any queries.
    """
    if sequential_responses and not values.get("queries"):
        raise ValueError(
            "queries is required when sequential_responses is set to True"
        )
    return sequential_responses
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
@ -15,7 +29,9 @@ class FakeLLM(LLM):
return "fake" return "fake"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""First try to lookup in queries, else return 'foo' or 'bar'.""" if self.sequential_responses:
return self._get_next_response_in_sequence
if self.queries is not None: if self.queries is not None:
return self.queries[prompt] return self.queries[prompt]
if stop is None: if stop is None:
@ -26,3 +42,10 @@ class FakeLLM(LLM):
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Mapping[str, Any]:
return {} return {}
@property
def _get_next_response_in_sequence(self) -> str:
    """Return the next canned response and advance the response cursor.

    Responses are the values of ``self.queries``, visited in the mapping's
    key order and selected by ``self.response_index``, which is incremented
    by one after a successful lookup.
    """
    canned = cast(Mapping, self.queries)
    ordered_keys = list(canned.keys())
    position = self.response_index
    reply = canned[ordered_keys[position]]
    self.response_index = position + 1
    return reply