Adding missing types in some pydantic models (#9355)

* Adding missing types in some pydantic models -- this change is
required for making the code work with pydantic v2.
pull/9358/head
Eugene Yurtsev 1 year ago committed by GitHub
parent 1c089cadd7
commit 4c2de2a7f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -615,9 +615,9 @@ class Agent(BaseSingleActionAgent):
class ExceptionTool(BaseTool):
"""Tool that just returns the query."""
name = "_Exception"
name: str = "_Exception"
"""Name of the tool."""
description = "Exception tool"
description: str = "Exception tool"
"""Description of the tool."""
def _run(

@ -182,7 +182,7 @@ class EmbeddingsClusteringFilter(BaseDocumentTransformer, BaseModel):
"""By default results are re-ordered "grouping" them by cluster, if sorted is true
result will be ordered by the original position from the retriever"""
remove_duplicates = False
remove_duplicates: bool = False
""" By default duplicated results are skipped and replaced by the next closest
vector in the cluster. If remove_duplicates is true no replacement will be done:
This could dramatically reduce results when there is a lot of overlap between

@ -37,9 +37,9 @@ class RocksetChatMessageHistory(BaseChatMessageHistory):
# These values are configured for the typical
# free VI. Read more about VIs here:
# https://rockset.com/docs/instances
SLEEP_INTERVAL_MS = 5
ADD_TIMEOUT_MS = 5000
CREATE_TIMEOUT_MS = 20000
SLEEP_INTERVAL_MS: int = 5
ADD_TIMEOUT_MS: int = 5000
CREATE_TIMEOUT_MS: int = 20000
def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None:
"""Sleeps until meth() evaluates to true. Passes kwargs into

@ -13,8 +13,8 @@ class MotorheadMemory(BaseChatMemory):
"""Chat message memory backed by Motorhead service."""
url: str = MANAGED_URL
timeout = 3000
memory_key = "history"
timeout: int = 3000
memory_key: str = "history"
session_id: str
context: Optional[str] = None

@ -18,8 +18,8 @@ class GooglePlacesSchema(BaseModel):
class GooglePlacesTool(BaseTool):
"""Tool that queries the Google places API."""
name = "google_places"
description = (
name: str = "google_places"
description: str = (
"A wrapper around Google Places. "
"Useful for when you need to validate or "
"discover addressed from ambiguous text. "

@ -84,8 +84,8 @@ class JsonSpec(BaseModel):
class JsonListKeysTool(BaseTool):
"""Tool for listing keys in a JSON spec."""
name = "json_spec_list_keys"
description = """
name: str = "json_spec_list_keys"
description: str = """
Can be used to list all keys at a given path.
Before calling this you should be SURE that the path to this exists.
The input is a text representation of the path to the dict in Python syntax (e.g. data["key1"][0]["key2"]).
@ -110,8 +110,8 @@ class JsonListKeysTool(BaseTool):
class JsonGetValueTool(BaseTool):
"""Tool for getting a value in a JSON spec."""
name = "json_spec_get_value"
description = """
name: str = "json_spec_get_value"
description: str = """
Can be used to see value in string format at a given path.
Before calling this you should be SURE that the path to this exists.
The input is a text representation of the path to the dict in Python syntax (e.g. data["key1"][0]["key2"]).

@ -58,8 +58,8 @@ class NUASchema(BaseModel):
class NucliaUnderstandingAPI(BaseTool):
"""Tool to process files with the Nuclia Understanding API."""
name = "nuclia_understanding_api"
description = (
name: str = "nuclia_understanding_api"
description: str = (
"A wrapper around Nuclia Understanding API endpoints. "
"Useful for when you need to extract text from any kind of files. "
)

@ -32,8 +32,8 @@ class BaseRequestsTool(BaseModel):
class RequestsGetTool(BaseRequestsTool, BaseTool):
"""Tool for making a GET request to an API endpoint."""
name = "requests_get"
description = "A portal to the internet. Use this when you need to get specific content from a website. Input should be a url (i.e. https://www.google.com). The output will be the text response of the GET request."
name: str = "requests_get"
description: str = "A portal to the internet. Use this when you need to get specific content from a website. Input should be a url (i.e. https://www.google.com). The output will be the text response of the GET request."
def _run(
self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None
@ -53,8 +53,8 @@ class RequestsGetTool(BaseRequestsTool, BaseTool):
class RequestsPostTool(BaseRequestsTool, BaseTool):
"""Tool for making a POST request to an API endpoint."""
name = "requests_post"
description = """Use this when you want to POST to a website.
name: str = "requests_post"
description: str = """Use this when you want to POST to a website.
Input should be a json string with two keys: "url" and "data".
The value of "url" should be a string, and the value of "data" should be a dictionary of
key-value pairs you want to POST to the url.
@ -90,8 +90,8 @@ class RequestsPostTool(BaseRequestsTool, BaseTool):
class RequestsPatchTool(BaseRequestsTool, BaseTool):
"""Tool for making a PATCH request to an API endpoint."""
name = "requests_patch"
description = """Use this when you want to PATCH to a website.
name: str = "requests_patch"
description: str = """Use this when you want to PATCH to a website.
Input should be a json string with two keys: "url" and "data".
The value of "url" should be a string, and the value of "data" should be a dictionary of
key-value pairs you want to PATCH to the url.
@ -127,8 +127,8 @@ class RequestsPatchTool(BaseRequestsTool, BaseTool):
class RequestsPutTool(BaseRequestsTool, BaseTool):
"""Tool for making a PUT request to an API endpoint."""
name = "requests_put"
description = """Use this when you want to PUT to a website.
name: str = "requests_put"
description: str = """Use this when you want to PUT to a website.
Input should be a json string with two keys: "url" and "data".
The value of "url" should be a string, and the value of "data" should be a dictionary of
key-value pairs you want to PUT to the url.
@ -164,8 +164,8 @@ class RequestsPutTool(BaseRequestsTool, BaseTool):
class RequestsDeleteTool(BaseRequestsTool, BaseTool):
"""Tool for making a DELETE request to an API endpoint."""
name = "requests_delete"
description = "A portal to the internet. Use this when you need to make a DELETE request to a URL. Input should be a specific url, and the output will be the text response of the DELETE request."
name: str = "requests_delete"
description: str = "A portal to the internet. Use this when you need to make a DELETE request to a URL. Input should be a specific url, and the output will be the text response of the DELETE request."
def _run(
self,

@ -17,8 +17,8 @@ class SceneXplainInput(BaseModel):
class SceneXplainTool(BaseTool):
"""Tool that explains images."""
name = "image_explainer"
description = (
name: str = "image_explainer"
description: str = (
"An Image Captioning Tool: Use this tool to generate a detailed caption "
"for an image. The input can be an image file of any format, and "
"the output will be a text description that covers every detail of the image."

@ -21,9 +21,9 @@ class SleepInput(BaseModel):
class SleepTool(BaseTool):
"""Tool that adds the capability to sleep."""
name = "sleep"
name: str = "sleep"
args_schema: Type[BaseModel] = SleepInput
description = "Make agent sleep for a specified number of seconds."
description: str = "Make agent sleep for a specified number of seconds."
def _run(
self,

@ -33,8 +33,8 @@ class BaseSparkSQLTool(BaseModel):
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
"""Tool for querying a Spark SQL."""
name = "query_sql_db"
description = """
name: str = "query_sql_db"
description: str = """
Input to this tool is a detailed and correct SQL query, output is a result from the Spark SQL.
If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again.
@ -52,8 +52,8 @@ class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool):
"""Tool for getting metadata about a Spark SQL."""
name = "schema_sql_db"
description = """
name: str = "schema_sql_db"
description: str = """
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
Be sure that the tables actually exist by calling list_tables_sql_db first!
@ -72,8 +72,8 @@ class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool):
class ListSparkSQLTool(BaseSparkSQLTool, BaseTool):
"""Tool for getting tables names."""
name = "list_tables_sql_db"
description = "Input is an empty string, output is a comma separated list of tables in the Spark SQL."
name: str = "list_tables_sql_db"
description: str = "Input is an empty string, output is a comma separated list of tables in the Spark SQL."
def _run(
self,
@ -91,8 +91,8 @@ class QueryCheckerTool(BaseSparkSQLTool, BaseTool):
template: str = QUERY_CHECKER
llm: BaseLanguageModel
llm_chain: LLMChain = Field(init=False)
name = "query_checker_sql_db"
description = """
name: str = "query_checker_sql_db"
description: str = """
Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with query_sql_db!
"""

@ -33,8 +33,8 @@ class BaseSQLDatabaseTool(BaseModel):
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for querying a SQL database."""
name = "sql_db_query"
description = """
name: str = "sql_db_query"
description: str = """
Input to this tool is a detailed and correct SQL query, output is a result from the database.
If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again.
@ -52,8 +52,8 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting metadata about a SQL database."""
name = "sql_db_schema"
description = """
name: str = "sql_db_schema"
description: str = """
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
Example Input: "table1, table2, table3"
@ -71,8 +71,8 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting tables names."""
name = "sql_db_list_tables"
description = "Input is an empty string, output is a comma separated list of tables in the database."
name: str = "sql_db_list_tables"
description: str = "Input is an empty string, output is a comma separated list of tables in the database."
def _run(
self,
@ -90,8 +90,8 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
template: str = QUERY_CHECKER
llm: BaseLanguageModel
llm_chain: LLMChain = Field(init=False)
name = "sql_db_query_checker"
description = """
name: str = "sql_db_query_checker"
description: str = """
Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with query_sql_db!
"""

@ -50,7 +50,7 @@ class ArxivAPIWrapper(BaseModel):
arxiv_search: Any #: :meta private:
arxiv_exceptions: Any # :meta private:
top_k_results: int = 3
ARXIV_MAX_QUERY_LENGTH = 300
ARXIV_MAX_QUERY_LENGTH: int = 300
load_max_docs: int = 100
load_all_available_meta: bool = False
doc_content_chars_max: Optional[int] = 4000

@ -14,7 +14,7 @@ class BraveSearchWrapper(BaseModel):
"""The API key to use for the Brave search engine."""
search_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the search request."""
base_url = "https://api.search.brave.com/res/v1/web/search"
base_url: str = "https://api.search.brave.com/res/v1/web/search"
"""The base URL for the Brave search engine."""
def run(self, query: str) -> str:

@ -36,14 +36,16 @@ class PubMedAPIWrapper(BaseModel):
parse: Any #: :meta private:
base_url_esearch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
base_url_efetch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
max_retry = 5
sleep_time = 0.2
base_url_esearch: str = (
"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
)
base_url_efetch: str = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
max_retry: int = 5
sleep_time: float = 0.2
# Default values for the parameters
top_k_results: int = 3
MAX_QUERY_LENGTH = 300
MAX_QUERY_LENGTH: int = 300
doc_content_chars_max: int = 2000
email: str = "your_email@example.com"

@ -144,7 +144,7 @@ def _get_default_params() -> dict:
class SearxResults(dict):
"""Dict like wrapper around search api results."""
_data = ""
_data: str = ""
def __init__(self, data: str):
"""Take a raw result from Searx and make it into a dict like object."""

Loading…
Cancel
Save