More missing type annotations (#9406)

This PR fills in more missing type annotations on pydantic models. 

It's OK if it missed some annotations, we just don't want it to get
annotations wrong at this stage.

I'll do a few more passes over the same files!
This commit is contained in:
Eugene Yurtsev 2023-08-17 12:19:50 -04:00 committed by GitHub
parent 7e63270e04
commit 77b359edf5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 116 additions and 116 deletions

View File

@ -66,7 +66,7 @@ def _get_default_llm_chain_factory(
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests GET tool with LLM-instructed extraction of truncated responses."""
name = "requests_get"
name: str = "requests_get"
"""Tool name."""
description = REQUESTS_GET_TOOL_DESCRIPTION
"""Tool description."""
@ -96,7 +96,7 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests POST tool with LLM-instructed extraction of truncated responses."""
name = "requests_post"
name: str = "requests_post"
"""Tool name."""
description = REQUESTS_POST_TOOL_DESCRIPTION
"""Tool description."""
@ -125,7 +125,7 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests PATCH tool with LLM-instructed extraction of truncated responses."""
name = "requests_patch"
name: str = "requests_patch"
"""Tool name."""
description = REQUESTS_PATCH_TOOL_DESCRIPTION
"""Tool description."""
@ -154,7 +154,7 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
"""A tool that sends a DELETE request and parses the response."""
name = "requests_delete"
name: str = "requests_delete"
"""The name of the tool."""
description = REQUESTS_DELETE_TOOL_DESCRIPTION
"""The description of the tool."""

View File

@ -18,8 +18,8 @@ class AgentTokenBufferMemory(BaseChatMemory):
"""The max number of tokens to keep in the buffer.
Once the buffer exceeds this many tokens, the oldest messages will be pruned."""
return_messages: bool = True
output_key = "output"
intermediate_steps_key = "intermediate_steps"
output_key: str = "output"
intermediate_steps_key: str = "intermediate_steps"
@property
def buffer(self) -> List[BaseMessage]:

View File

@ -11,8 +11,8 @@ from langchain.tools.base import BaseTool, Tool, tool
class InvalidTool(BaseTool):
"""Tool that is run when invalid tool name is encountered by agent."""
name = "invalid_tool"
description = "Called when tool name is invalid. Suggests valid tool names."
name: str = "invalid_tool"
description: str = "Called when tool name is invalid. Suggests valid tool names."
def _run(
self,

View File

@ -52,11 +52,11 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
"Argilla, no doubt about it."
"""
REPO_URL = "https://github.com/argilla-io/argilla"
ISSUES_URL = f"{REPO_URL}/issues"
BLOG_URL = "https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html" # noqa: E501
REPO_URL: str = "https://github.com/argilla-io/argilla"
ISSUES_URL: str = f"{REPO_URL}/issues"
BLOG_URL: str = "https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html" # noqa: E501
DEFAULT_API_URL = "http://localhost:6900"
DEFAULT_API_URL: str = "http://localhost:6900"
def __init__(
self,

View File

@ -97,7 +97,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
>>> llm.predict('Tell me a story about a dog.')
"""
DEFAULT_PROJECT_NAME = "LangChain-%Y-%m-%d"
DEFAULT_PROJECT_NAME: str = "LangChain-%Y-%m-%d"
def __init__(
self,

View File

@ -62,7 +62,7 @@ class EvaluatorCallbackHandler(BaseTracer):
The LangSmith project name to be organize eval chain runs under.
"""
name = "evaluator_callback_handler"
name: str = "evaluator_callback_handler"
def __init__(
self,

View File

@ -19,7 +19,7 @@ class RunCollectorCallbackHandler(BaseTracer):
The ID of the example being traced. It can be either a UUID or a string.
"""
name = "run-collector_callback_handler"
name: str = "run-collector_callback_handler"
def __init__(
self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any

View File

@ -44,7 +44,7 @@ def elapsed(run: Any) -> str:
class FunctionCallbackHandler(BaseTracer):
"""Tracer that calls a function with a single str parameter."""
name = "function_callback_handler"
name: str = "function_callback_handler"
def __init__(self, function: Callable[[str], None], **kwargs: Any) -> None:
super().__init__(**kwargs)
@ -172,7 +172,7 @@ class FunctionCallbackHandler(BaseTracer):
class ConsoleCallbackHandler(FunctionCallbackHandler):
"""Tracer that prints to the console."""
name = "console_callback_handler"
name: str = "console_callback_handler"
def __init__(self, **kwargs: Any) -> None:
super().__init__(function=print, **kwargs)

View File

@ -29,10 +29,10 @@ class ArangoGraphQAChain(Chain):
output_key: str = "result" #: :meta private:
# Specifies the maximum number of AQL Query Results to return
top_k = 10
top_k: int = 10
# Specifies the set of AQL Query Examples that promote few-shot-learning
aql_examples = ""
aql_examples: str = ""
# Specify whether to return the AQL Query in the output dictionary
return_aql_query: bool = False
@ -41,7 +41,7 @@ class ArangoGraphQAChain(Chain):
return_aql_result: bool = False
# Specify the maximum amount of AQL Generation attempts that should be made
max_aql_generation_attempts = 3
max_aql_generation_attempts: int = 3
@property
def input_keys(self) -> List[str]:

View File

@ -18,7 +18,7 @@ from langchain.utils import get_from_dict_or_env
class LlamaContentFormatter(ContentFormatterBase):
"""Content formatter for `LLaMA`."""
SUPPORTED_ROLES = ["user", "assistant", "system"]
SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
@staticmethod
def _convert_message_to_dict(message: BaseMessage) -> Dict:

View File

@ -30,7 +30,7 @@ class EdenAI(LLM):
for api reference check edenai documentation: http://docs.edenai.co.
"""
base_url = "https://api.edenai.run/v2"
base_url: str = "https://api.edenai.run/v2"
edenai_api_key: Optional[str] = None

View File

@ -23,7 +23,7 @@ def _stream_response_to_generation_chunk(
class _OllamaCommon(BaseLanguageModel):
base_url = "http://localhost:11434"
base_url: str = "http://localhost:11434"
"""Base url the model is hosted under."""
model: str = "llama2"

View File

@ -592,7 +592,7 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
def count_tokens(self, *, text: str) -> int:
return len(self._encode(text))
_max_length_equal_32_bit_integer = 2**32
_max_length_equal_32_bit_integer: int = 2**32
def _encode(self, text: str) -> List[int]:
token_ids_with_start_and_end_token_ids = self.tokenizer.encode(

View File

@ -11,8 +11,8 @@ from langchain.utilities.arxiv import ArxivAPIWrapper
class ArxivQueryRun(BaseTool):
"""Tool that searches the Arxiv API."""
name = "arxiv"
description = (
name: str = "arxiv"
description: str = (
"A wrapper around Arxiv.org "
"Useful for when you need to answer questions about Physics, Mathematics, "
"Computer Science, Quantitative Biology, Quantitative Finance, Statistics, "

View File

@ -23,8 +23,8 @@ class AzureCogsFormRecognizerTool(BaseTool):
azure_cogs_endpoint: str = "" #: :meta private:
doc_analysis_client: Any #: :meta private:
name = "azure_cognitive_services_form_recognizer"
description = (
name: str = "azure_cognitive_services_form_recognizer"
description: str = (
"A wrapper around Azure Cognitive Services Form Recognizer. "
"Useful for when you need to "
"extract text, tables, and key-value pairs from documents. "

View File

@ -24,8 +24,8 @@ class AzureCogsImageAnalysisTool(BaseTool):
vision_service: Any #: :meta private:
analysis_options: Any #: :meta private:
name = "azure_cognitive_services_image_analysis"
description = (
name: str = "azure_cognitive_services_image_analysis"
description: str = (
"A wrapper around Azure Cognitive Services Image Analysis. "
"Useful for when you need to analyze images. "
"Input should be a url to an image."

View File

@ -28,8 +28,8 @@ class AzureCogsSpeech2TextTool(BaseTool):
speech_language: str = "en-US" #: :meta private:
speech_config: Any #: :meta private:
name = "azure_cognitive_services_speech2text"
description = (
name: str = "azure_cognitive_services_speech2text"
description: str = (
"A wrapper around Azure Cognitive Services Speech2Text. "
"Useful for when you need to transcribe audio to text. "
"Input should be a url to an audio file."

View File

@ -24,8 +24,8 @@ class AzureCogsText2SpeechTool(BaseTool):
speech_language: str = "en-US" #: :meta private:
speech_config: Any #: :meta private:
name = "azure_cognitive_services_text2speech"
description = (
name: str = "azure_cognitive_services_text2speech"
description: str = (
"A wrapper around Azure Cognitive Services Text2Speech. "
"Useful for when you need to convert text to speech. "
)

View File

@ -55,8 +55,8 @@ def _get_filtered_args(
class _SchemaConfig:
"""Configuration for the pydantic model."""
extra = Extra.forbid
arbitrary_types_allowed = True
extra: Any = Extra.forbid
arbitrary_types_allowed: bool = True
def create_schema_from_function(

View File

@ -10,8 +10,8 @@ from langchain.utilities.bing_search import BingSearchAPIWrapper
class BingSearchRun(BaseTool):
"""Tool that queries the Bing search API."""
name = "bing_search"
description = (
name: str = "bing_search"
description: str = (
"A wrapper around Bing Search. "
"Useful for when you need to answer questions about current events. "
"Input should be a search query."
@ -30,8 +30,8 @@ class BingSearchRun(BaseTool):
class BingSearchResults(BaseTool):
"""Tool that queries the Bing Search API and gets back json."""
name = "Bing Search Results JSON"
description = (
name: str = "Bing Search Results JSON"
description: str = (
"A wrapper around Bing Search. "
"Useful for when you need to answer questions about current events. "
"Input should be a search query. Output is a JSON array of the query results"

View File

@ -10,8 +10,8 @@ from langchain.utilities.brave_search import BraveSearchWrapper
class BraveSearch(BaseTool):
"""Tool that queries the BraveSearch."""
name = "brave_search"
description = (
name: str = "brave_search"
description: str = (
"a search engine. "
"useful for when you need to answer questions about current events."
" input should be a search query."

View File

@ -14,8 +14,8 @@ from langchain.utilities.dataforseo_api_search import DataForSeoAPIWrapper
class DataForSeoAPISearchRun(BaseTool):
"""Tool that queries the DataForSeo Google search API."""
name = "dataforseo_api_search"
description = (
name: str = "dataforseo_api_search"
description: str = (
"A robust Google Search API provided by DataForSeo."
"This tool is handy when you need information about trending topics "
"or current events."
@ -43,8 +43,8 @@ class DataForSeoAPISearchResults(BaseTool):
"""Tool that queries the DataForSeo Google Search API
and get back json."""
name = "DataForSeo Results JSON"
description = (
name: str = "DataForSeo Results JSON"
description: str = (
"A comprehensive Google Search API provided by DataForSeo."
"This tool is useful for obtaining real-time data on current events "
"or popular searches."

View File

@ -12,8 +12,8 @@ from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
class DuckDuckGoSearchRun(BaseTool):
"""Tool that queries the DuckDuckGo search API."""
name = "duckduckgo_search"
description = (
name: str = "duckduckgo_search"
description: str = (
"A wrapper around DuckDuckGo Search. "
"Useful for when you need to answer questions about current events. "
"Input should be a search query."
@ -34,8 +34,8 @@ class DuckDuckGoSearchRun(BaseTool):
class DuckDuckGoSearchResults(BaseTool):
"""Tool that queries the DuckDuckGo search API and gets back json."""
name = "DuckDuckGo Results JSON"
description = (
name: str = "DuckDuckGo Results JSON"
description: str = (
"A wrapper around Duck Duck Go Search. "
"Useful for when you need to answer questions about current events. "
"Input should be a search query. Output is a JSON array of the query results"

View File

@ -20,8 +20,8 @@ class GitHubAction(BaseTool):
api_wrapper: GitHubAPIWrapper = Field(default_factory=GitHubAPIWrapper)
mode: str
name = ""
description = ""
name: str = ""
description: str = ""
def _run(
self,

View File

@ -10,8 +10,8 @@ from langchain.utilities.golden_query import GoldenQueryAPIWrapper
class GoldenQueryRun(BaseTool):
"""Tool that adds the capability to query using the Golden API and get back JSON."""
name = "Golden Query"
description = (
name: str = "Golden Query"
description: str = (
"A wrapper around Golden Query API."
" Useful for getting entities that match"
" a natural language query from Golden's Knowledge Base."

View File

@ -10,8 +10,8 @@ from langchain.utilities.google_search import GoogleSearchAPIWrapper
class GoogleSearchRun(BaseTool):
"""Tool that queries the Google search API."""
name = "google_search"
description = (
name: str = "google_search"
description: str = (
"A wrapper around Google Search. "
"Useful for when you need to answer questions about current events. "
"Input should be a search query."
@ -30,8 +30,8 @@ class GoogleSearchRun(BaseTool):
class GoogleSearchResults(BaseTool):
"""Tool that queries the Google Search API and gets back json."""
name = "Google Search Results JSON"
description = (
name: str = "Google Search Results JSON"
description: str = (
"A wrapper around Google Search. "
"Useful for when you need to answer questions about current events. "
"Input should be a search query. Output is a JSON array of the query results"

View File

@ -14,8 +14,8 @@ from langchain.utilities.google_serper import GoogleSerperAPIWrapper
class GoogleSerperRun(BaseTool):
"""Tool that queries the Serper.dev Google search API."""
name = "google_serper"
description = (
name: str = "google_serper"
description: str = (
"A low-cost Google Search API."
"Useful for when you need to answer questions about current events."
"Input should be a search query."
@ -43,8 +43,8 @@ class GoogleSerperResults(BaseTool):
"""Tool that queries the Serper.dev Google Search API
and get back json."""
name = "google_serrper_results_json"
description = (
name: str = "google_serrper_results_json"
description: str = (
"A low-cost Google Search API."
"Useful for when you need to answer questions about current events."
"Input should be a search query. Output is a JSON object of the query results"

View File

@ -11,8 +11,8 @@ class BaseGraphQLTool(BaseTool):
graphql_wrapper: GraphQLAPIWrapper
name = "query_graphql"
description = """\
name: str = "query_graphql"
description: str = """\
Input to this tool is a detailed and correct GraphQL query, output is a result from the API.
If the query is not correct, an error message will be returned.
If an error is returned with 'Bad request' in it, rewrite the query and try again.

View File

@ -15,8 +15,8 @@ def _print_func(text: str) -> None:
class HumanInputRun(BaseTool):
"""Tool that asks user for input."""
name = "human"
description = (
name: str = "human"
description: str = (
"You can ask a human for guidance when you think you "
"got stuck or you are not sure what to do next. "
"The input should be a question for the human."

View File

@ -41,8 +41,8 @@ class JiraAction(BaseTool):
api_wrapper: JiraAPIWrapper = Field(default_factory=JiraAPIWrapper)
mode: str
name = ""
description = ""
name: str = ""
description: str = ""
def _run(
self,

View File

@ -13,8 +13,8 @@ from langchain.utilities.metaphor_search import MetaphorSearchAPIWrapper
class MetaphorSearchResults(BaseTool):
"""Tool that queries the Metaphor Search API and gets back json."""
name = "metaphor_search_results_json"
description = (
name: str = "metaphor_search_results_json"
description: str = (
"A wrapper around Metaphor Search. "
"Input should be a Metaphor-optimized query. "
"Output is a JSON array of the query results"

View File

@ -15,8 +15,8 @@ class OpenWeatherMapQueryRun(BaseTool):
default_factory=OpenWeatherMapAPIWrapper
)
name = "OpenWeatherMap"
description = (
name: str = "OpenWeatherMap"
description: str = (
"A wrapper around OpenWeatherMap API. "
"Useful for fetching current weather information for a specified location. "
"Input should be a location string (e.g. London,GB)."

View File

@ -24,8 +24,8 @@ logger = logging.getLogger(__name__)
class QueryPowerBITool(BaseTool):
"""Tool for querying a Power BI Dataset."""
name = "query_powerbi"
description = """
name: str = "query_powerbi"
description: str = """
Input to this tool is a detailed question about the dataset, output is a result from the dataset. It will try to answer the question using the dataset, and if it cannot, it will ask for clarification.
Example Input: "How many rows are in table1?"
@ -217,8 +217,8 @@ class QueryPowerBITool(BaseTool):
class InfoPowerBITool(BaseTool):
"""Tool for getting metadata about a PowerBI Dataset."""
name = "schema_powerbi"
description = """
name: str = "schema_powerbi"
description: str = """
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
Be sure that the tables actually exist by calling list_tables_powerbi first!
@ -250,8 +250,8 @@ class InfoPowerBITool(BaseTool):
class ListPowerBITool(BaseTool):
"""Tool for getting tables names."""
name = "list_tables_powerbi"
description = "Input is an empty string, output is a comma separated list of tables in the database." # noqa: E501 # pylint: disable=C0301
name: str = "list_tables_powerbi"
description: str = "Input is an empty string, output is a comma separated list of tables in the database." # noqa: E501 # pylint: disable=C0301
powerbi: PowerBIDataset = Field(exclude=True)
class Config:

View File

@ -9,8 +9,8 @@ from langchain.utilities.pubmed import PubMedAPIWrapper
class PubmedQueryRun(BaseTool):
"""Tool that searches the PubMed API."""
name = "PubMed"
description = (
name: str = "PubMed"
description: str = (
"A wrapper around PubMed. "
"Useful for when you need to answer questions about medicine, health, "
"and biomedical topics "

View File

@ -42,8 +42,8 @@ def sanitize_input(query: str) -> str:
class PythonREPLTool(BaseTool):
"""A tool for running python code in a REPL."""
name = "Python_REPL"
description = (
name: str = "Python_REPL"
description: str = (
"A Python shell. Use this to execute python commands. "
"Input should be a valid python command. "
"If you want to see the output of a value, you should print it out "
@ -80,8 +80,8 @@ class PythonREPLTool(BaseTool):
class PythonAstREPLTool(BaseTool):
"""A tool for running python code in a REPL."""
name = "python_repl_ast"
description = (
name: str = "python_repl_ast"
description: str = (
"A Python shell. Use this to execute python commands. "
"Input should be a valid python command. "
"When using this tool, sometimes output is abbreviated - "

View File

@ -13,8 +13,8 @@ from langchain.utilities.searx_search import SearxSearchWrapper
class SearxSearchRun(BaseTool):
"""Tool that queries a Searx instance."""
name = "searx_search"
description = (
name: str = "searx_search"
description: str = (
"A meta search engine."
"Useful for when you need to answer questions about current events."
"Input should be a search query."
@ -42,8 +42,8 @@ class SearxSearchRun(BaseTool):
class SearxSearchResults(BaseTool):
"""Tool that queries a Searx instance and gets back json."""
name = "Searx Search Results"
description = (
name: str = "Searx Search Results"
description: str = (
"A meta search engine."
"Useful for when you need to answer questions about current events."
"Input should be a search query. Output is a JSON array of the query results"

View File

@ -48,8 +48,8 @@ class SteamshipImageGenerationTool(BaseTool):
steamship: Steamship
return_urls: Optional[bool] = False
name = "GenerateImage"
description = (
name: str = "GenerateImage"
description: str = (
"Useful for when you need to generate an image."
"Input: A detailed text-2-image prompt describing an image"
"Output: the UUID of a generated image"

View File

@ -10,8 +10,8 @@ from langchain.utilities.wikipedia import WikipediaAPIWrapper
class WikipediaQueryRun(BaseTool):
"""Tool that searches the Wikipedia API."""
name = "Wikipedia"
description = (
name: str = "Wikipedia"
description: str = (
"A wrapper around Wikipedia. "
"Useful for when you need to answer general questions about "
"people, places, companies, facts, historical events, or other subjects. "

View File

@ -10,8 +10,8 @@ from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
class WolframAlphaQueryRun(BaseTool):
"""Tool that queries using the Wolfram Alpha SDK."""
name = "wolfram_alpha"
description = (
name: str = "wolfram_alpha"
description: str = (
"A wrapper around Wolfram Alpha. "
"Useful for when you need to answer questions about Math, "
"Science, Technology, Culture, Society and Everyday Life. "

View File

@ -18,8 +18,8 @@ from langchain.tools import BaseTool
class YouTubeSearchTool(BaseTool):
"""Tool that queries YouTube."""
name = "youtube_search"
description = (
name: str = "youtube_search"
description: str = (
"search for youtube videos associated with a person. "
"the input to this tool should be a comma separated list, "
"the first part contains a person name and the second a "

View File

@ -108,8 +108,8 @@ class ZapierNLARunAction(BaseTool):
base_prompt: str = BASE_ZAPIER_TOOL_PROMPT
zapier_description: str
params_schema: Dict[str, str] = Field(default_factory=dict)
name = ""
description = ""
name: str = ""
description: str = ""
@root_validator
def set_name_description(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@ -167,8 +167,8 @@ class ZapierNLAListActions(BaseTool):
"""
name = "ZapierNLA_list_actions"
description = BASE_ZAPIER_TOOL_PROMPT + (
name: str = "ZapierNLA_list_actions"
description: str = BASE_ZAPIER_TOOL_PROMPT + (
"This tool returns a list of the user's exposed actions."
)
api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper)

View File

@ -43,9 +43,9 @@ class _MockSchema(BaseModel):
class _MockStructuredTool(BaseTool):
name = "structured_api"
name: str = "structured_api"
args_schema: Type[BaseModel] = _MockSchema
description = "A Structured Tool"
description: str = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
@ -69,10 +69,10 @@ def test_misannotated_base_tool_raises_error() -> None:
with pytest.raises(SchemaAnnotationError):
class _MisAnnotatedTool(BaseTool):
name = "structured_api"
name: str = "structured_api"
# This would silently be ignored without the custom metaclass
args_schema: BaseModel = _MockSchema # type: ignore
description = "A Structured Tool"
description: str = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
@ -87,9 +87,9 @@ def test_forward_ref_annotated_base_tool_accepted() -> None:
"""Test that a using forward ref annotation syntax is accepted.""" ""
class _ForwardRefAnnotatedTool(BaseTool):
name = "structured_api"
name: str = "structured_api"
args_schema: "Type[BaseModel]" = _MockSchema
description = "A Structured Tool"
description: str = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
@ -104,9 +104,9 @@ def test_subclass_annotated_base_tool_accepted() -> None:
"""Test BaseTool child w/ custom schema isn't overwritten."""
class _ForwardRefAnnotatedTool(BaseTool):
name = "structured_api"
name: str = "structured_api"
args_schema: Type[_MockSchema] = _MockSchema
description = "A Structured Tool"
description: str = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
@ -154,8 +154,8 @@ def test_decorated_function_schema_equivalent() -> None:
def test_args_kwargs_filtered() -> None:
class _SingleArgToolWithKwargs(BaseTool):
name = "single_arg_tool"
description = "A single arged tool with kwargs"
name: str = "single_arg_tool"
description: str = "A single arged tool with kwargs"
def _run(
self,
@ -177,8 +177,8 @@ def test_args_kwargs_filtered() -> None:
assert tool.is_single_input
class _VarArgToolWithKwargs(BaseTool):
name = "single_arg_tool"
description = "A single arged tool with kwargs"
name: str = "single_arg_tool"
description: str = "A single arged tool with kwargs"
def _run(
self,
@ -269,8 +269,8 @@ def test_base_tool_inheritance_base_schema() -> None:
"""Test schema is correctly inferred when inheriting from BaseTool."""
class _MockSimpleTool(BaseTool):
name = "simple_tool"
description = "A Simple Tool"
name: str = "simple_tool"
description: str = "A Simple Tool"
def _run(self, tool_input: str) -> str:
return f"{tool_input}"
@ -593,8 +593,8 @@ async def test_create_async_tool() -> None:
class _FakeExceptionTool(BaseTool):
name = "exception"
description = "an exception-throwing tool"
name: str = "exception"
description: str = "an exception-throwing tool"
exception: Exception = ToolException()
def _run(self) -> str: