From 64501329ab17edb8a30f63ed829693912ec45108 Mon Sep 17 00:00:00 2001 From: Mike Wang <62768671+skcoirz@users.noreply.github.com> Date: Tue, 25 Apr 2023 23:30:49 -0700 Subject: [PATCH] [simple] updated annotation in load_tools.py (#3544) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - added a few missing annotation for complex local variables. - auto formatted. - I also went through all other files in agent directory. no seeing any other missing piece. (there are several prompt strings not annotated, but I think it’s trivial. Also adding annotation will make it harder to read in terms of indents.) Anyway, I think this is the last PR in agent/annotation. --- langchain/agents/load_tools.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/langchain/agents/load_tools.py b/langchain/agents/load_tools.py index da01bd70..de057275 100644 --- a/langchain/agents/load_tools.py +++ b/langchain/agents/load_tools.py @@ -2,7 +2,7 @@ """Load tools.""" import warnings from typing import Any, Dict, List, Optional, Callable, Tuple -from mypy_extensions import KwArg +from mypy_extensions import Arg, KwArg from langchain.agents.tools import Tool from langchain.callbacks.base import BaseCallbackManager @@ -74,7 +74,7 @@ def _get_terminal() -> BaseTool: ) -_BASE_TOOLS = { +_BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = { "python_repl": _get_python_repl, "requests": _get_tools_requests_get, # preserved for backwards compatability "requests_get": _get_tools_requests_get, @@ -120,7 +120,7 @@ def _get_open_meteo_api(llm: BaseLLM) -> BaseTool: ) -_LLM_TOOLS = { +_LLM_TOOLS: Dict[str, Callable[[BaseLLM], BaseTool]] = { "pal-math": _get_pal_math, "pal-colored-objects": _get_pal_colored_objects, "llm-math": _get_llm_math, @@ -226,7 +226,9 @@ def _get_human_tool(**kwargs: Any) -> BaseTool: return HumanInputRun(**kwargs) -_EXTRA_LLM_TOOLS = { +_EXTRA_LLM_TOOLS: Dict[ + str, Tuple[Callable[[Arg(BaseLLM, "llm"), KwArg(Any)], BaseTool], List[str]] +] = { "news-api": (_get_news_api, ["news_api_key"]), "tmdb-api": (_get_tmdb_api, ["tmdb_bearer_token"]), "podcast-api": (_get_podcast_api, ["listen_api_key"]),