diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 50bf62ff..aed79dba 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -15,11 +15,16 @@ from langchain.callbacks.base import BaseCallbackManager from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.input import get_color_mapping -from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate -from langchain.schema import AgentAction, AgentFinish, BaseMessage, BaseOutputParser +from langchain.schema import ( + AgentAction, + AgentFinish, + BaseLanguageModel, + BaseMessage, + BaseOutputParser, +) from langchain.tools.base import BaseTool logger = logging.getLogger() @@ -365,7 +370,7 @@ class Agent(BaseSingleActionAgent): @classmethod def from_llm_and_tools( cls, - llm: BaseLLM, + llm: BaseLanguageModel, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any, diff --git a/langchain/agents/conversational/base.py b/langchain/agents/conversational/base.py index e47329d2..23995797 100644 --- a/langchain/agents/conversational/base.py +++ b/langchain/agents/conversational/base.py @@ -8,8 +8,8 @@ from langchain.agents.agent import Agent from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain -from langchain.llms import BaseLLM from langchain.prompts import PromptTemplate +from langchain.schema import BaseLanguageModel from langchain.tools.base import BaseTool @@ -89,7 +89,7 @@ class ConversationalAgent(Agent): @classmethod def from_llm_and_tools( cls, - llm: BaseLLM, + llm: BaseLanguageModel, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, prefix: str = PREFIX, diff --git a/langchain/agents/initialize.py b/langchain/agents/initialize.py index 825f27d8..04277c60 100644 --- a/langchain/agents/initialize.py +++ b/langchain/agents/initialize.py @@ -4,13 +4,13 @@ from typing import Any, Optional, Sequence from langchain.agents.agent import AgentExecutor from langchain.agents.loading import AGENT_TO_CLASS, load_agent from langchain.callbacks.base import BaseCallbackManager -from langchain.llms.base import BaseLLM +from langchain.schema import BaseLanguageModel from langchain.tools.base import BaseTool def initialize_agent( tools: Sequence[BaseTool], - llm: BaseLLM, + llm: BaseLanguageModel, agent: Optional[str] = None, callback_manager: Optional[BaseCallbackManager] = None, agent_path: Optional[str] = None, diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py index 2667936b..a388a120 100644 --- a/langchain/agents/mrkl/base.py +++ b/langchain/agents/mrkl/base.py @@ -9,8 +9,8 @@ from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.tools import Tool from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain -from langchain.llms.base import BaseLLM from langchain.prompts import PromptTemplate +from langchain.schema import BaseLanguageModel from langchain.tools.base import BaseTool FINAL_ANSWER_ACTION = "Final Answer:" @@ -100,7 +100,7 @@ class ZeroShotAgent(Agent): @classmethod def from_llm_and_tools( cls, - llm: BaseLLM, + llm: BaseLanguageModel, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, prefix: str = PREFIX, @@ -155,7 +155,7 @@ class MRKLChain(AgentExecutor): @classmethod def from_chains( - cls, llm: BaseLLM, chains: List[ChainConfig], **kwargs: Any + cls, llm: BaseLanguageModel, chains: List[ChainConfig], **kwargs: Any ) -> AgentExecutor: """User friendly way to initialize the MRKL chain.