Fix type annotation for QueryCheckerTool.llm (#3237)

Currently `langchain.tools.sql_database.tool.QueryCheckerTool` has a
field `llm` with type `BaseLLM`. This breaks initialization for some
LLMs. For example, trying to use it with GPT4:

```python
from langchain.sql_database import SQLDatabase
from langchain.chat_models import ChatOpenAI
from langchain.tools.sql_database.tool import QueryCheckerTool


db = SQLDatabase.from_uri("some_db_uri")
llm = ChatOpenAI(model_name="gpt-4")
tool = QueryCheckerTool(db=db, llm=llm)

# pydantic.error_wrappers.ValidationError: 1 validation error for QueryCheckerTool
# llm
#   Can't instantiate abstract class BaseLLM with abstract methods _agenerate, _generate, _llm_type (type=type_error)
```

Seems like much of the rest of the codebase has switched from `BaseLLM`
to `BaseLanguageModel`. This PR makes the change for QueryCheckerTool as
well

Co-authored-by: Zachary Jones <zjones@zetaglobal.com>
This commit is contained in:
Zach Jones 2023-04-20 21:50:59 -04:00 committed by GitHub
parent 46542dc774
commit d7942a9f19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6,7 +6,7 @@ from typing import Any, Dict
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.sql_database import SQLDatabase from langchain.sql_database import SQLDatabase
from langchain.llms.base import BaseLLM from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.tools.sql_database.prompt import QUERY_CHECKER from langchain.tools.sql_database.prompt import QUERY_CHECKER
@ -81,7 +81,7 @@ class QueryCheckerTool(BaseSQLDatabaseTool, BaseTool):
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/""" Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
template: str = QUERY_CHECKER template: str = QUERY_CHECKER
llm: BaseLLM llm: BaseLanguageModel
llm_chain: LLMChain = Field(init=False) llm_chain: LLMChain = Field(init=False)
name = "query_checker_sql_db" name = "query_checker_sql_db"
description = """ description = """