From 502ba6a0befa23912ef2bef5bd9fae4039cdb438 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20SCHILTZ?= Date: Sat, 29 Apr 2023 06:19:01 +0200 Subject: [PATCH] Fix type annotation for SQLDatabaseToolkit.llm (#3581) Currently `langchain.agents.agent_toolkits.SQLDatabaseToolkit` has a field `llm` with type `BaseLLM`. This breaks initialization for some LLMs. For example, trying to use it with GPT4: ``` from langchain.sql_database import SQLDatabase from langchain.chat_models import ChatOpenAI from langchain.agents.agent_toolkits import SQLDatabaseToolkit db = SQLDatabase.from_uri("some_db_uri") llm = ChatOpenAI(model_name="gpt-4") toolkit = SQLDatabaseToolkit(db=db, llm=llm) # pydantic.error_wrappers.ValidationError: 1 validation error for SQLDatabaseToolkit # llm # Can't instantiate abstract class BaseLLM with abstract methods _agenerate, _generate, _llm_type (type=type_error) ``` Seems like much of the rest of the codebase has switched from BaseLLM to BaseLanguageModel. This PR makes the change for SQLDatabaseToolkit as well --- langchain/agents/agent_toolkits/sql/toolkit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/langchain/agents/agent_toolkits/sql/toolkit.py b/langchain/agents/agent_toolkits/sql/toolkit.py index 491d2460..91c3de0b 100644 --- a/langchain/agents/agent_toolkits/sql/toolkit.py +++ b/langchain/agents/agent_toolkits/sql/toolkit.py @@ -4,7 +4,7 @@ from typing import List from pydantic import Field from langchain.agents.agent_toolkits.base import BaseToolkit -from langchain.llms.base import BaseLLM +from langchain.schema import BaseLanguageModel from langchain.sql_database import SQLDatabase from langchain.tools import BaseTool from langchain.tools.sql_database.tool import ( @@ -19,7 +19,7 @@ class SQLDatabaseToolkit(BaseToolkit): """Toolkit for interacting with SQL databases.""" db: SQLDatabase = Field(exclude=True) - llm: BaseLLM = Field(exclude=True) + llm: BaseLanguageModel = Field(exclude=True) @property def dialect(self) -> str: