forked from Archives/langchain
Fix type annotation for QueryCheckerTool.llm
(#3237)
Currently `langchain.tools.sql_database.tool.QueryCheckerTool` has a field `llm` with type `BaseLLM`. This breaks initialization for some LLMs. For example, trying to use it with GPT4: ```python from langchain.sql_database import SQLDatabase from langchain.chat_models import ChatOpenAI from langchain.tools.sql_database.tool import QueryCheckerTool db = SQLDatabase.from_uri("some_db_uri") llm = ChatOpenAI(model_name="gpt-4") tool = QueryCheckerTool(db=db, llm=llm) # pydantic.error_wrappers.ValidationError: 1 validation error for QueryCheckerTool # llm # Can't instantiate abstract class BaseLLM with abstract methods _agenerate, _generate, _llm_type (type=type_error) ``` Seems like much of the rest of the codebase has switched from `BaseLLM` to `BaseLanguageModel`. This PR makes the change for QueryCheckerTool as well Co-authored-by: Zachary Jones <zjones@zetaglobal.com>
This commit is contained in:
parent
46542dc774
commit
d7942a9f19
@ -6,7 +6,7 @@ from typing import Any, Dict
|
|||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.sql_database import SQLDatabase
|
from langchain.sql_database import SQLDatabase
|
||||||
from langchain.llms.base import BaseLLM
|
from langchain.schema import BaseLanguageModel
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||||
|
|
||||||
@ -81,7 +81,7 @@ class QueryCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
|||||||
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
||||||
|
|
||||||
template: str = QUERY_CHECKER
|
template: str = QUERY_CHECKER
|
||||||
llm: BaseLLM
|
llm: BaseLanguageModel
|
||||||
llm_chain: LLMChain = Field(init=False)
|
llm_chain: LLMChain = Field(init=False)
|
||||||
name = "query_checker_sql_db"
|
name = "query_checker_sql_db"
|
||||||
description = """
|
description = """
|
||||||
|
Loading…
Reference in New Issue
Block a user