From 705596b46a9ad2a5c6ec28a85678493d2dfbea14 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 13 Apr 2023 22:07:58 -0700 Subject: [PATCH] Harrison/fix create sql agent (#2870) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Timothé Pearce --- langchain/document_loaders/git.py | 2 +- langchain/tools/sql_database/tool.py | 30 +++++++++++++++------------- tests/unit_tests/agents/test_sql.py | 18 +++++++++++++++++ tests/unit_tests/llms/fake_llm.py | 27 +++++++++++++++++++++++-- 4 files changed, 60 insertions(+), 17 deletions(-) create mode 100644 tests/unit_tests/agents/test_sql.py diff --git a/langchain/document_loaders/git.py b/langchain/document_loaders/git.py index 39a9235e26..155767629e 100644 --- a/langchain/document_loaders/git.py +++ b/langchain/document_loaders/git.py @@ -28,7 +28,7 @@ class GitLoader(BaseLoader): def load(self) -> List[Document]: try: - from git import Blob, Repo + from git import Blob, Repo # type: ignore except ImportError as ex: raise ImportError( "Could not import git python package. " diff --git a/langchain/tools/sql_database/tool.py b/langchain/tools/sql_database/tool.py index a9ac6981b5..3921b43a2f 100644 --- a/langchain/tools/sql_database/tool.py +++ b/langchain/tools/sql_database/tool.py @@ -1,6 +1,7 @@ # flake8: noqa """Tools for interacting with a SQL database.""" -from pydantic import BaseModel, Extra, Field, validator +from pydantic import BaseModel, Extra, Field, validator, root_validator +from typing import Any, Dict from langchain.chains.llm import LLMChain from langchain.prompts import PromptTemplate @@ -81,28 +82,29 @@ class QueryCheckerTool(BaseSQLDatabaseTool, BaseTool): template: str = QUERY_CHECKER llm: BaseLLM - llm_chain: LLMChain = Field( - default_factory=lambda: LLMChain( - llm=QueryCheckerTool.llm, - prompt=PromptTemplate( - template=QueryCheckerTool.template, input_variables=["query", "dialect"] - ), - ) - ) + llm_chain: LLMChain = Field(init=False) name = "query_checker_sql_db" description = """ Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with query_sql_db! """ - @validator("llm_chain") - def validate_llm_chain_input_variables(cls, llm_chain: LLMChain) -> LLMChain: - """Make sure the LLM chain has the correct input variables.""" - if llm_chain.prompt.input_variables != ["query", "dialect"]: + @root_validator(pre=True) + def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if "llm_chain" not in values: + values["llm_chain"] = LLMChain( + llm=values.get("llm"), + prompt=PromptTemplate( + template=QUERY_CHECKER, input_variables=["query", "dialect"] + ), + ) + + if values["llm_chain"].prompt.input_variables != ["query", "dialect"]: raise ValueError( "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']" ) - return llm_chain + + return values def _run(self, query: str) -> str: """Use the LLM to check the query.""" diff --git a/tests/unit_tests/agents/test_sql.py b/tests/unit_tests/agents/test_sql.py new file mode 100644 index 0000000000..89b8f90df2 --- /dev/null +++ b/tests/unit_tests/agents/test_sql.py @@ -0,0 +1,18 @@ +from langchain.agents import create_sql_agent +from langchain.agents.agent_toolkits import SQLDatabaseToolkit +from langchain.sql_database import SQLDatabase +from tests.unit_tests.llms.fake_llm import FakeLLM + + +def test_create_sql_agent() -> None: + db = SQLDatabase.from_uri("sqlite:///:memory:") + queries = {"foo": "Final Answer: baz"} + llm = FakeLLM(queries=queries, sequential_responses=True) + toolkit = SQLDatabaseToolkit(db=db, llm=llm) + + agent_executor = create_sql_agent( + llm=llm, + toolkit=toolkit, + ) + + assert agent_executor.run("hello") == "baz" diff --git a/tests/unit_tests/llms/fake_llm.py b/tests/unit_tests/llms/fake_llm.py index 263bc2b630..cc12a7cab7 100644 --- a/tests/unit_tests/llms/fake_llm.py +++ b/tests/unit_tests/llms/fake_llm.py @@ -1,5 +1,7 @@ """Fake LLM wrapper for testing purposes.""" -from typing import Any, List, Mapping, Optional +from typing import Any, List, Mapping, Optional, cast + +from pydantic import validator from langchain.llms.base import LLM @@ -8,6 +10,18 @@ class FakeLLM(LLM): """Fake LLM wrapper for testing purposes.""" queries: Optional[Mapping] = None + sequential_responses: Optional[bool] = False + response_index: int = 0 + + @validator("queries", always=True) + def check_queries_required( + cls, queries: Optional[Mapping], values: Mapping[str, Any] + ) -> Optional[Mapping]: + if values.get("sequential_response") and not queries: + raise ValueError( + "queries is required when sequential_response is set to True" + ) + return queries @property def _llm_type(self) -> str: @@ -15,7 +29,9 @@ class FakeLLM(LLM): return "fake" def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: - """First try to lookup in queries, else return 'foo' or 'bar'.""" + if self.sequential_responses: + return self._get_next_response_in_sequence + if self.queries is not None: return self.queries[prompt] if stop is None: @@ -26,3 +42,10 @@ class FakeLLM(LLM): @property def _identifying_params(self) -> Mapping[str, Any]: return {} + + @property + def _get_next_response_in_sequence(self) -> str: + queries = cast(Mapping, self.queries) + response = queries[list(queries.keys())[self.response_index]] + self.response_index = self.response_index + 1 + return response