diff --git a/langchain/agents/agent_toolkits/csv/base.py b/langchain/agents/agent_toolkits/csv/base.py index 9bac5436..b74ee729 100644 --- a/langchain/agents/agent_toolkits/csv/base.py +++ b/langchain/agents/agent_toolkits/csv/base.py @@ -3,11 +3,14 @@ from typing import Any, Optional from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent -from langchain.llms.base import BaseLLM +from langchain.base_language import BaseLanguageModel def create_csv_agent( - llm: BaseLLM, path: str, pandas_kwargs: Optional[dict] = None, **kwargs: Any + llm: BaseLanguageModel, + path: str, + pandas_kwargs: Optional[dict] = None, + **kwargs: Any ) -> AgentExecutor: """Create csv agent by loading to a dataframe and using pandas agent.""" import pandas as pd diff --git a/langchain/agents/agent_toolkits/json/base.py b/langchain/agents/agent_toolkits/json/base.py index e46e994f..3a6f58e5 100644 --- a/langchain/agents/agent_toolkits/json/base.py +++ b/langchain/agents/agent_toolkits/json/base.py @@ -6,13 +6,13 @@ from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM def create_json_agent( - llm: BaseLLM, + llm: BaseLanguageModel, toolkit: JsonToolkit, callback_manager: Optional[BaseCallbackManager] = None, prefix: str = JSON_PREFIX, diff --git a/langchain/agents/agent_toolkits/nla/tool.py b/langchain/agents/agent_toolkits/nla/tool.py index d4d7e96a..2e79b678 100644 --- a/langchain/agents/agent_toolkits/nla/tool.py +++ b/langchain/agents/agent_toolkits/nla/tool.py @@ -4,8 +4,8 @@ from typing import Any, Optional from langchain.agents.tools import Tool +from langchain.base_language import BaseLanguageModel from langchain.chains.api.openapi.chain import OpenAPIEndpointChain -from langchain.llms.base import BaseLLM from langchain.requests import Requests from langchain.tools.openapi.utils.api_models import APIOperation from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec @@ -32,7 +32,7 @@ class NLATool(Tool): @classmethod def from_llm_and_method( cls, - llm: BaseLLM, + llm: BaseLanguageModel, path: str, method: str, spec: OpenAPISpec, diff --git a/langchain/agents/agent_toolkits/nla/toolkit.py b/langchain/agents/agent_toolkits/nla/toolkit.py index d1104b62..42ac0113 100644 --- a/langchain/agents/agent_toolkits/nla/toolkit.py +++ b/langchain/agents/agent_toolkits/nla/toolkit.py @@ -7,7 +7,7 @@ from pydantic import Field from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.agents.agent_toolkits.nla.tool import NLATool -from langchain.llms.base import BaseLLM +from langchain.base_language import BaseLanguageModel from langchain.requests import Requests from langchain.tools.base import BaseTool from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec @@ -26,7 +26,7 @@ class NLAToolkit(BaseToolkit): @staticmethod def _get_http_operation_tools( - llm: BaseLLM, + llm: BaseLanguageModel, spec: OpenAPISpec, requests: Optional[Requests] = None, verbose: bool = False, @@ -53,7 +53,7 @@ class NLAToolkit(BaseToolkit): @classmethod def from_llm_and_spec( cls, - llm: BaseLLM, + llm: BaseLanguageModel, spec: OpenAPISpec, requests: Optional[Requests] = None, verbose: bool = False, @@ -68,7 +68,7 @@ class NLAToolkit(BaseToolkit): @classmethod def from_llm_and_url( cls, - llm: BaseLLM, + llm: BaseLanguageModel, open_api_url: str, requests: Optional[Requests] = None, verbose: bool = False, @@ -83,7 +83,7 @@ class NLAToolkit(BaseToolkit): @classmethod def from_llm_and_ai_plugin( cls, - llm: BaseLLM, + llm: BaseLanguageModel, ai_plugin: AIPlugin, requests: Optional[Requests] = None, verbose: bool = False, @@ -103,7 +103,7 @@ class NLAToolkit(BaseToolkit): @classmethod def from_llm_and_ai_plugin_url( cls, - llm: BaseLLM, + llm: BaseLanguageModel, ai_plugin_url: str, requests: Optional[Requests] = None, verbose: bool = False, diff --git a/langchain/agents/agent_toolkits/openapi/base.py b/langchain/agents/agent_toolkits/openapi/base.py index 7e18c70b..60a7aaac 100644 --- a/langchain/agents/agent_toolkits/openapi/base.py +++ b/langchain/agents/agent_toolkits/openapi/base.py @@ -9,13 +9,13 @@ from langchain.agents.agent_toolkits.openapi.prompt import ( from langchain.agents.agent_toolkits.openapi.toolkit import OpenAPIToolkit from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM def create_openapi_agent( - llm: BaseLLM, + llm: BaseLanguageModel, toolkit: OpenAPIToolkit, callback_manager: Optional[BaseCallbackManager] = None, prefix: str = OPENAPI_PREFIX, diff --git a/langchain/agents/agent_toolkits/openapi/toolkit.py b/langchain/agents/agent_toolkits/openapi/toolkit.py index 3ae16526..8c10dad8 100644 --- a/langchain/agents/agent_toolkits/openapi/toolkit.py +++ b/langchain/agents/agent_toolkits/openapi/toolkit.py @@ -9,7 +9,7 @@ from langchain.agents.agent_toolkits.json.base import create_json_agent from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit from langchain.agents.agent_toolkits.openapi.prompt import DESCRIPTION from langchain.agents.tools import Tool -from langchain.llms.base import BaseLLM +from langchain.base_language import BaseLanguageModel from langchain.requests import TextRequestsWrapper from langchain.tools import BaseTool from langchain.tools.json.tool import JsonSpec @@ -57,7 +57,7 @@ class OpenAPIToolkit(BaseToolkit): @classmethod def from_llm( cls, - llm: BaseLLM, + llm: BaseLanguageModel, json_spec: JsonSpec, requests_wrapper: TextRequestsWrapper, **kwargs: Any, diff --git a/langchain/agents/agent_toolkits/pandas/base.py b/langchain/agents/agent_toolkits/pandas/base.py index 1d880e32..200bf453 100644 --- a/langchain/agents/agent_toolkits/pandas/base.py +++ b/langchain/agents/agent_toolkits/pandas/base.py @@ -4,14 +4,14 @@ from typing import Any, Dict, List, Optional from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.pandas.prompt import PREFIX, SUFFIX from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM from langchain.tools.python.tool import PythonAstREPLTool def create_pandas_dataframe_agent( - llm: BaseLLM, + llm: BaseLanguageModel, df: Any, callback_manager: Optional[BaseCallbackManager] = None, prefix: str = PREFIX, diff --git a/langchain/agents/agent_toolkits/powerbi/base.py b/langchain/agents/agent_toolkits/powerbi/base.py index 578b58e0..c7d8a117 100644 --- a/langchain/agents/agent_toolkits/powerbi/base.py +++ b/langchain/agents/agent_toolkits/powerbi/base.py @@ -9,14 +9,14 @@ from langchain.agents.agent_toolkits.powerbi.prompt import ( from langchain.agents.agent_toolkits.powerbi.toolkit import PowerBIToolkit from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM from langchain.utilities.powerbi import PowerBIDataset def create_pbi_agent( - llm: BaseLLM, + llm: BaseLanguageModel, toolkit: Optional[PowerBIToolkit], powerbi: Optional[PowerBIDataset] = None, callback_manager: Optional[BaseCallbackManager] = None, diff --git a/langchain/agents/agent_toolkits/python/base.py b/langchain/agents/agent_toolkits/python/base.py index a4663898..2db17642 100644 --- a/langchain/agents/agent_toolkits/python/base.py +++ b/langchain/agents/agent_toolkits/python/base.py @@ -5,14 +5,14 @@ from typing import Any, Dict, Optional from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.python.prompt import PREFIX from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM from langchain.tools.python.tool import PythonREPLTool def create_python_agent( - llm: BaseLLM, + llm: BaseLanguageModel, tool: PythonREPLTool, callback_manager: Optional[BaseCallbackManager] = None, verbose: bool = False, diff --git a/langchain/agents/agent_toolkits/sql/base.py b/langchain/agents/agent_toolkits/sql/base.py index 697c8dbd..784d3155 100644 --- a/langchain/agents/agent_toolkits/sql/base.py +++ b/langchain/agents/agent_toolkits/sql/base.py @@ -6,13 +6,13 @@ from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM def create_sql_agent( - llm: BaseLLM, + llm: BaseLanguageModel, toolkit: SQLDatabaseToolkit, callback_manager: Optional[BaseCallbackManager] = None, prefix: str = SQL_PREFIX, diff --git a/langchain/agents/agent_toolkits/vectorstore/base.py b/langchain/agents/agent_toolkits/vectorstore/base.py index 52497625..c3fd97e8 100644 --- a/langchain/agents/agent_toolkits/vectorstore/base.py +++ b/langchain/agents/agent_toolkits/vectorstore/base.py @@ -8,13 +8,13 @@ from langchain.agents.agent_toolkits.vectorstore.toolkit import ( VectorStoreToolkit, ) from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM def create_vectorstore_agent( - llm: BaseLLM, + llm: BaseLanguageModel, toolkit: VectorStoreToolkit, callback_manager: Optional[BaseCallbackManager] = None, prefix: str = PREFIX, @@ -42,7 +42,7 @@ def create_vectorstore_agent( def create_vectorstore_router_agent( - llm: BaseLLM, + llm: BaseLanguageModel, toolkit: VectorStoreRouterToolkit, callback_manager: Optional[BaseCallbackManager] = None, prefix: str = ROUTER_PREFIX, diff --git a/langchain/agents/load_tools.py b/langchain/agents/load_tools.py index be8bba2b..46cab1b0 100644 --- a/langchain/agents/load_tools.py +++ b/langchain/agents/load_tools.py @@ -5,12 +5,12 @@ from typing import Any, Dict, List, Optional, Callable, Tuple from mypy_extensions import Arg, KwArg from langchain.agents.tools import Tool +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs from langchain.chains.api.base import APIChain from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.pal.base import PALChain -from langchain.llms.base import BaseLLM from langchain.requests import TextRequestsWrapper from langchain.tools.arxiv.tool import ArxivQueryRun from langchain.tools.base import BaseTool @@ -32,8 +32,6 @@ from langchain.tools.shell.tool import ShellTool from langchain.tools.wikipedia.tool import WikipediaQueryRun from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun from langchain.utilities import ArxivAPIWrapper -from langchain.utilities.apify import ApifyWrapper -from langchain.utilities.bash import BashProcess from langchain.utilities.bing_search import BingSearchAPIWrapper from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper from langchain.utilities.google_search import GoogleSearchAPIWrapper @@ -85,7 +83,7 @@ _BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = { } -def _get_pal_math(llm: BaseLLM) -> BaseTool: +def _get_pal_math(llm: BaseLanguageModel) -> BaseTool: return Tool( name="PAL-MATH", description="A language model that is really good at solving complex word math problems. Input should be a fully worded hard word math problem.", @@ -93,7 +91,7 @@ def _get_pal_math(llm: BaseLLM) -> BaseTool: ) -def _get_pal_colored_objects(llm: BaseLLM) -> BaseTool: +def _get_pal_colored_objects(llm: BaseLanguageModel) -> BaseTool: return Tool( name="PAL-COLOR-OBJ", description="A language model that is really good at reasoning about position and the color attributes of objects. Input should be a fully worded hard reasoning problem. Make sure to include all information about the objects AND the final question you want to answer.", @@ -101,7 +99,7 @@ def _get_pal_colored_objects(llm: BaseLLM) -> BaseTool: ) -def _get_llm_math(llm: BaseLLM) -> BaseTool: +def _get_llm_math(llm: BaseLanguageModel) -> BaseTool: return Tool( name="Calculator", description="Useful for when you need to answer questions about math.", @@ -110,7 +108,7 @@ def _get_llm_math(llm: BaseLLM) -> BaseTool: ) -def _get_open_meteo_api(llm: BaseLLM) -> BaseTool: +def _get_open_meteo_api(llm: BaseLanguageModel) -> BaseTool: chain = APIChain.from_llm_and_api_docs(llm, open_meteo_docs.OPEN_METEO_DOCS) return Tool( name="Open Meteo API", @@ -119,7 +117,7 @@ def _get_open_meteo_api(llm: BaseLLM) -> BaseTool: ) -_LLM_TOOLS: Dict[str, Callable[[BaseLLM], BaseTool]] = { +_LLM_TOOLS: Dict[str, Callable[[BaseLanguageModel], BaseTool]] = { "pal-math": _get_pal_math, "pal-colored-objects": _get_pal_colored_objects, "llm-math": _get_llm_math, @@ -127,7 +125,7 @@ _LLM_TOOLS: Dict[str, Callable[[BaseLLM], BaseTool]] = { } -def _get_news_api(llm: BaseLLM, **kwargs: Any) -> BaseTool: +def _get_news_api(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool: news_api_key = kwargs["news_api_key"] chain = APIChain.from_llm_and_api_docs( llm, news_docs.NEWS_DOCS, headers={"X-Api-Key": news_api_key} @@ -139,7 +137,7 @@ def _get_news_api(llm: BaseLLM, **kwargs: Any) -> BaseTool: ) -def _get_tmdb_api(llm: BaseLLM, **kwargs: Any) -> BaseTool: +def _get_tmdb_api(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool: tmdb_bearer_token = kwargs["tmdb_bearer_token"] chain = APIChain.from_llm_and_api_docs( llm, @@ -153,7 +151,7 @@ def _get_tmdb_api(llm: BaseLLM, **kwargs: Any) -> BaseTool: ) -def _get_podcast_api(llm: BaseLLM, **kwargs: Any) -> BaseTool: +def _get_podcast_api(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool: listen_api_key = kwargs["listen_api_key"] chain = APIChain.from_llm_and_api_docs( llm, @@ -238,7 +236,8 @@ def _get_scenexplain(**kwargs: Any) -> BaseTool: _EXTRA_LLM_TOOLS: Dict[ - str, Tuple[Callable[[Arg(BaseLLM, "llm"), KwArg(Any)], BaseTool], List[str]] + str, + Tuple[Callable[[Arg(BaseLanguageModel, "llm"), KwArg(Any)], BaseTool], List[str]], ] = { "news-api": (_get_news_api, ["news_api_key"]), "tmdb-api": (_get_tmdb_api, ["tmdb_bearer_token"]), @@ -277,7 +276,7 @@ _EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[st def load_tools( tool_names: List[str], - llm: Optional[BaseLLM] = None, + llm: Optional[BaseLanguageModel] = None, callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any, ) -> List[BaseTool]: diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py index 50d24d7d..d7702fbc 100644 --- a/langchain/agents/loading.py +++ b/langchain/agents/loading.py @@ -8,15 +8,15 @@ import yaml from langchain.agents.agent import BaseSingleActionAgent from langchain.agents.tools import Tool from langchain.agents.types import AGENT_TO_CLASS +from langchain.base_language import BaseLanguageModel from langchain.chains.loading import load_chain, load_chain_from_config -from langchain.llms.base import BaseLLM from langchain.utilities.loading import try_load_from_hub URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/" def _load_agent_from_tools( - config: dict, llm: BaseLLM, tools: List[Tool], **kwargs: Any + config: dict, llm: BaseLanguageModel, tools: List[Tool], **kwargs: Any ) -> BaseSingleActionAgent: config_type = config.pop("_type") if config_type not in AGENT_TO_CLASS: @@ -29,7 +29,7 @@ def _load_agent_from_tools( def load_agent_from_config( config: dict, - llm: Optional[BaseLLM] = None, + llm: Optional[BaseLanguageModel] = None, tools: Optional[List[Tool]] = None, **kwargs: Any, ) -> BaseSingleActionAgent: diff --git a/langchain/agents/react/base.py b/langchain/agents/react/base.py index a1210be9..38b01d28 100644 --- a/langchain/agents/react/base.py +++ b/langchain/agents/react/base.py @@ -10,9 +10,9 @@ from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT from langchain.agents.react.wiki_prompt import WIKI_PROMPT from langchain.agents.tools import Tool from langchain.agents.utils import validate_tools_single_input +from langchain.base_language import BaseLanguageModel from langchain.docstore.base import Docstore from langchain.docstore.document import Document -from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.tools.base import BaseTool @@ -141,7 +141,7 @@ class ReActChain(AgentExecutor): react = ReAct(llm=OpenAI()) """ - def __init__(self, llm: BaseLLM, docstore: Docstore, **kwargs: Any): + def __init__(self, llm: BaseLanguageModel, docstore: Docstore, **kwargs: Any): """Initialize with the LLM and a docstore.""" docstore_explorer = DocstoreExplorer(docstore) tools = [ diff --git a/langchain/agents/self_ask_with_search/base.py b/langchain/agents/self_ask_with_search/base.py index a445e07e..a4065f04 100644 --- a/langchain/agents/self_ask_with_search/base.py +++ b/langchain/agents/self_ask_with_search/base.py @@ -9,7 +9,7 @@ from langchain.agents.self_ask_with_search.output_parser import SelfAskOutputPar from langchain.agents.self_ask_with_search.prompt import PROMPT from langchain.agents.tools import Tool from langchain.agents.utils import validate_tools_single_input -from langchain.llms.base import BaseLLM +from langchain.base_language import BaseLanguageModel from langchain.prompts.base import BasePromptTemplate from langchain.tools.base import BaseTool from langchain.utilities.google_serper import GoogleSerperAPIWrapper @@ -71,7 +71,7 @@ class SelfAskWithSearchChain(AgentExecutor): def __init__( self, - llm: BaseLLM, + llm: BaseLanguageModel, search_chain: Union[GoogleSerperAPIWrapper, SerpAPIWrapper], **kwargs: Any, ): diff --git a/langchain/chains/api/openapi/chain.py b/langchain/chains/api/openapi/chain.py index 8f192271..57e9ee5f 100644 --- a/langchain/chains/api/openapi/chain.py +++ b/langchain/chains/api/openapi/chain.py @@ -7,12 +7,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, cast from pydantic import BaseModel, Field from requests import Response +from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks from langchain.chains.api.openapi.requests_chain import APIRequesterChain from langchain.chains.api.openapi.response_chain import APIResponderChain from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM from langchain.requests import Requests from langchain.tools.openapi.utils.api_models import APIOperation @@ -168,7 +168,7 @@ class OpenAPIEndpointChain(Chain, BaseModel): spec_url: str, path: str, method: str, - llm: BaseLLM, + llm: BaseLanguageModel, requests: Optional[Requests] = None, return_intermediate_steps: bool = False, **kwargs: Any @@ -188,7 +188,7 @@ class OpenAPIEndpointChain(Chain, BaseModel): def from_api_operation( cls, operation: APIOperation, - llm: BaseLLM, + llm: BaseLanguageModel, requests: Optional[Requests] = None, verbose: bool = False, return_intermediate_steps: bool = False, diff --git a/langchain/chains/api/openapi/requests_chain.py b/langchain/chains/api/openapi/requests_chain.py index 4bc8bd83..e26ce296 100644 --- a/langchain/chains/api/openapi/requests_chain.py +++ b/langchain/chains/api/openapi/requests_chain.py @@ -4,9 +4,9 @@ import json import re from typing import Any +from langchain.base_language import BaseLanguageModel from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM from langchain.prompts.prompt import PromptTemplate from langchain.schema import BaseOutputParser @@ -38,7 +38,7 @@ class APIRequesterChain(LLMChain): @classmethod def from_llm_and_typescript( cls, - llm: BaseLLM, + llm: BaseLanguageModel, typescript_definition: str, verbose: bool = True, **kwargs: Any, diff --git a/langchain/chains/api/openapi/response_chain.py b/langchain/chains/api/openapi/response_chain.py index a1d7c5a1..325797d3 100644 --- a/langchain/chains/api/openapi/response_chain.py +++ b/langchain/chains/api/openapi/response_chain.py @@ -4,9 +4,9 @@ import json import re from typing import Any +from langchain.base_language import BaseLanguageModel from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM from langchain.prompts.prompt import PromptTemplate from langchain.schema import BaseOutputParser @@ -36,7 +36,9 @@ class APIResponderChain(LLMChain): """Get the response parser.""" @classmethod - def from_llm(cls, llm: BaseLLM, verbose: bool = True, **kwargs: Any) -> LLMChain: + def from_llm( + cls, llm: BaseLanguageModel, verbose: bool = True, **kwargs: Any + ) -> LLMChain: """Get the response parser.""" output_parser = APIResponderOutputParser() prompt = PromptTemplate( diff --git a/langchain/chains/graph_qa/base.py b/langchain/chains/graph_qa/base.py index 112338ae..36cff24d 100644 --- a/langchain/chains/graph_qa/base.py +++ b/langchain/chains/graph_qa/base.py @@ -5,12 +5,12 @@ from typing import Any, Dict, List, Optional from pydantic import Field +from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT from langchain.chains.llm import LLMChain from langchain.graphs.networkx_graph import NetworkxEntityGraph, get_entities -from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate @@ -43,7 +43,7 @@ class GraphQAChain(Chain): @classmethod def from_llm( cls, - llm: BaseLLM, + llm: BaseLanguageModel, qa_prompt: BasePromptTemplate = PROMPT, entity_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT, **kwargs: Any, diff --git a/langchain/chains/hyde/base.py b/langchain/chains/hyde/base.py index 3cd6170e..7764c854 100644 --- a/langchain/chains/hyde/base.py +++ b/langchain/chains/hyde/base.py @@ -9,12 +9,12 @@ from typing import Any, Dict, List, Optional import numpy as np from pydantic import Extra +from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.hyde.prompts import PROMPT_MAP from langchain.chains.llm import LLMChain from langchain.embeddings.base import Embeddings -from langchain.llms.base import BaseLLM class HypotheticalDocumentEmbedder(Chain, Embeddings): @@ -70,7 +70,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): @classmethod def from_llm( cls, - llm: BaseLLM, + llm: BaseLanguageModel, base_embeddings: Embeddings, prompt_key: str, **kwargs: Any, diff --git a/langchain/chains/llm_checker/base.py b/langchain/chains/llm_checker/base.py index ae2101e0..080b1e96 100644 --- a/langchain/chains/llm_checker/base.py +++ b/langchain/chains/llm_checker/base.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional from pydantic import Extra, root_validator +from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -16,12 +17,11 @@ from langchain.chains.llm_checker.prompt import ( REVISED_ANSWER_PROMPT, ) from langchain.chains.sequential import SequentialChain -from langchain.llms.base import BaseLLM from langchain.prompts import PromptTemplate def _load_question_to_checked_assertions_chain( - llm: BaseLLM, + llm: BaseLanguageModel, create_draft_answer_prompt: PromptTemplate, list_assertions_prompt: PromptTemplate, check_assertions_prompt: PromptTemplate, @@ -75,7 +75,7 @@ class LLMCheckerChain(Chain): question_to_checked_assertions_chain: SequentialChain - llm: Optional[BaseLLM] = None + llm: Optional[BaseLanguageModel] = None """[Deprecated] LLM wrapper to use.""" create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT """[Deprecated]""" @@ -158,7 +158,7 @@ class LLMCheckerChain(Chain): @classmethod def from_llm( cls, - llm: BaseLLM, + llm: BaseLanguageModel, create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT, list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT, check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT, diff --git a/langchain/chains/llm_summarization_checker/base.py b/langchain/chains/llm_summarization_checker/base.py index e44a5cc7..3ae9c707 100644 --- a/langchain/chains/llm_summarization_checker/base.py +++ b/langchain/chains/llm_summarization_checker/base.py @@ -8,11 +8,11 @@ from typing import Any, Dict, List, Optional from pydantic import Extra, root_validator +from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sequential import SequentialChain -from langchain.llms.base import BaseLLM from langchain.prompts.prompt import PromptTemplate PROMPTS_DIR = Path(__file__).parent / "prompts" @@ -32,7 +32,7 @@ ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file( def _load_sequential_chain( - llm: BaseLLM, + llm: BaseLanguageModel, create_assertions_prompt: PromptTemplate, check_assertions_prompt: PromptTemplate, revised_summary_prompt: PromptTemplate, @@ -85,7 +85,7 @@ class LLMSummarizationCheckerChain(Chain): """ sequential_chain: SequentialChain - llm: Optional[BaseLLM] = None + llm: Optional[BaseLanguageModel] = None """[Deprecated] LLM wrapper to use.""" create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT @@ -180,7 +180,7 @@ class LLMSummarizationCheckerChain(Chain): @classmethod def from_llm( cls, - llm: BaseLLM, + llm: BaseLanguageModel, create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT, check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT, revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT, diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index f1b66b49..768342d1 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional from pydantic import Extra +from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -16,7 +17,6 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.docstore.document import Document -from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.text_splitter import TextSplitter @@ -34,7 +34,7 @@ class MapReduceChain(Chain): @classmethod def from_params( cls, - llm: BaseLLM, + llm: BaseLanguageModel, prompt: BasePromptTemplate, text_splitter: TextSplitter, callbacks: Callbacks = None, diff --git a/langchain/chains/natbot/base.py b/langchain/chains/natbot/base.py index 452f7860..b510f89b 100644 --- a/langchain/chains/natbot/base.py +++ b/langchain/chains/natbot/base.py @@ -6,11 +6,11 @@ from typing import Any, Dict, List, Optional from pydantic import Extra, root_validator +from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.natbot.prompt import PROMPT -from langchain.llms.base import BaseLLM from langchain.llms.openai import OpenAI @@ -27,7 +27,7 @@ class NatBotChain(Chain): llm_chain: LLMChain objective: str """Objective that NatBot is tasked with completing.""" - llm: Optional[BaseLLM] = None + llm: Optional[BaseLanguageModel] = None """[Deprecated] LLM wrapper to use.""" input_url_key: str = "url" #: :meta private: input_browser_content_key: str = "browser_content" #: :meta private: @@ -59,7 +59,9 @@ class NatBotChain(Chain): return cls.from_llm(llm, objective, **kwargs) @classmethod - def from_llm(cls, llm: BaseLLM, objective: str, **kwargs: Any) -> NatBotChain: + def from_llm( + cls, llm: BaseLanguageModel, objective: str, **kwargs: Any + ) -> NatBotChain: """Load from LLM.""" llm_chain = LLMChain(llm=llm, prompt=PROMPT) return cls(llm_chain=llm_chain, objective=objective, **kwargs) diff --git a/langchain/evaluation/qa/eval_chain.py b/langchain/evaluation/qa/eval_chain.py index 43e3be65..9a7a0bf8 100644 --- a/langchain/evaluation/qa/eval_chain.py +++ b/langchain/evaluation/qa/eval_chain.py @@ -4,9 +4,9 @@ from __future__ import annotations from typing import Any, List from langchain import PromptTemplate +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT -from langchain.llms.base import BaseLLM class QAEvalChain(LLMChain): @@ -14,12 +14,12 @@ class QAEvalChain(LLMChain): @classmethod def from_llm( - cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any + cls, llm: BaseLanguageModel, prompt: PromptTemplate = PROMPT, **kwargs: Any ) -> QAEvalChain: """Load QA Eval Chain from LLM. Args: - llm (BaseLLM): the base language model to use. + llm (BaseLanguageModel): the base language model to use. prompt (PromptTemplate): A prompt template containing the input_variables: 'input', 'answer' and 'result' that will be used as the prompt @@ -74,12 +74,15 @@ class ContextQAEvalChain(LLMChain): @classmethod def from_llm( - cls, llm: BaseLLM, prompt: PromptTemplate = CONTEXT_PROMPT, **kwargs: Any + cls, + llm: BaseLanguageModel, + prompt: PromptTemplate = CONTEXT_PROMPT, + **kwargs: Any, ) -> ContextQAEvalChain: """Load QA Eval Chain from LLM. Args: - llm (BaseLLM): the base language model to use. + llm (BaseLanguageModel): the base language model to use. prompt (PromptTemplate): A prompt template containing the input_variables: 'query', 'context' and 'result' that will be used as the prompt @@ -120,7 +123,7 @@ class CotQAEvalChain(ContextQAEvalChain): @classmethod def from_llm( - cls, llm: BaseLLM, prompt: PromptTemplate = COT_PROMPT, **kwargs: Any + cls, llm: BaseLanguageModel, prompt: PromptTemplate = COT_PROMPT, **kwargs: Any ) -> CotQAEvalChain: cls._validate_input_vars(prompt) return cls(llm=llm, prompt=prompt, **kwargs) diff --git a/langchain/evaluation/qa/generate_chain.py b/langchain/evaluation/qa/generate_chain.py index 62941462..8fd01d6a 100644 --- a/langchain/evaluation/qa/generate_chain.py +++ b/langchain/evaluation/qa/generate_chain.py @@ -3,15 +3,15 @@ from __future__ import annotations from typing import Any +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.evaluation.qa.generate_prompt import PROMPT -from langchain.llms.base import BaseLLM class QAGenerateChain(LLMChain): """LLM Chain specifically for generating examples for question answering.""" @classmethod - def from_llm(cls, llm: BaseLLM, **kwargs: Any) -> QAGenerateChain: + def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> QAGenerateChain: """Load QA Generate Chain from LLM.""" return cls(llm=llm, prompt=PROMPT, **kwargs) diff --git a/langchain/example_generator.py b/langchain/example_generator.py index 7c309d05..e1ce34d8 100644 --- a/langchain/example_generator.py +++ b/langchain/example_generator.py @@ -1,8 +1,8 @@ """Utility functions for working with prompts.""" from typing import List +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate @@ -10,7 +10,7 @@ TEST_GEN_TEMPLATE_SUFFIX = "Add another example." def generate_example( - examples: List[dict], llm: BaseLLM, prompt_template: PromptTemplate + examples: List[dict], llm: BaseLanguageModel, prompt_template: PromptTemplate ) -> str: """Return another example given a list of examples for a prompt.""" prompt = FewShotPromptTemplate( diff --git a/langchain/indexes/graph.py b/langchain/indexes/graph.py index 519e4ac4..81fabdc3 100644 --- a/langchain/indexes/graph.py +++ b/langchain/indexes/graph.py @@ -3,18 +3,18 @@ from typing import Optional, Type from pydantic import BaseModel +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.graphs.networkx_graph import NetworkxEntityGraph, parse_triples from langchain.indexes.prompts.knowledge_triplet_extraction import ( KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT, ) -from langchain.llms.base import BaseLLM class GraphIndexCreator(BaseModel): """Functionality to create graph index.""" - llm: Optional[BaseLLM] = None + llm: Optional[BaseLanguageModel] = None graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph def from_text(self, text: str) -> NetworkxEntityGraph: diff --git a/langchain/indexes/vectorstore.py b/langchain/indexes/vectorstore.py index b551bd9a..f07d01a6 100644 --- a/langchain/indexes/vectorstore.py +++ b/langchain/indexes/vectorstore.py @@ -2,12 +2,12 @@ from typing import Any, List, Optional, Type from pydantic import BaseModel, Extra, Field +from langchain.base_language import BaseLanguageModel from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain from langchain.chains.retrieval_qa.base import RetrievalQA from langchain.document_loaders.base import BaseLoader from langchain.embeddings.base import Embeddings from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.llms.base import BaseLLM from langchain.llms.openai import OpenAI from langchain.schema import Document from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter @@ -30,7 +30,9 @@ class VectorStoreIndexWrapper(BaseModel): extra = Extra.forbid arbitrary_types_allowed = True - def query(self, question: str, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str: + def query( + self, question: str, llm: Optional[BaseLanguageModel] = None, **kwargs: Any + ) -> str: """Query the vectorstore.""" llm = llm or OpenAI(temperature=0) chain = RetrievalQA.from_chain_type( @@ -39,7 +41,7 @@ class VectorStoreIndexWrapper(BaseModel): return chain.run(question) def query_with_sources( - self, question: str, llm: Optional[BaseLLM] = None, **kwargs: Any + self, question: str, llm: Optional[BaseLanguageModel] = None, **kwargs: Any ) -> dict: """Query the vectorstore and get back sources.""" llm = llm or OpenAI(temperature=0)