From 4c2de2a7f28442896e06a6fc7c807cf611b0c780 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 16 Aug 2023 23:10:34 -0400 Subject: [PATCH] Adding missing types in some pydantic models (#9355) * Adding missing types in some pydantic models -- this change is required for making the code work with pydantic v2. --- libs/langchain/langchain/agents/agent.py | 4 ++-- .../embeddings_redundant_filter.py | 2 +- .../chat_message_histories/rocksetdb.py | 6 +++--- .../langchain/memory/motorhead_memory.py | 4 ++-- .../langchain/tools/google_places/tool.py | 4 ++-- libs/langchain/langchain/tools/json/tool.py | 8 ++++---- libs/langchain/langchain/tools/nuclia/tool.py | 4 ++-- .../langchain/tools/requests/tool.py | 20 +++++++++---------- .../langchain/tools/scenexplain/tool.py | 4 ++-- libs/langchain/langchain/tools/sleep/tool.py | 4 ++-- .../langchain/tools/spark_sql/tool.py | 16 +++++++-------- .../langchain/tools/sql_database/tool.py | 16 +++++++-------- libs/langchain/langchain/utilities/arxiv.py | 2 +- .../langchain/utilities/brave_search.py | 2 +- libs/langchain/langchain/utilities/pubmed.py | 12 ++++++----- .../langchain/utilities/searx_search.py | 2 +- 16 files changed, 56 insertions(+), 54 deletions(-) diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index 3935d05252..8791bce548 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -615,9 +615,9 @@ class Agent(BaseSingleActionAgent): class ExceptionTool(BaseTool): """Tool that just returns the query.""" - name = "_Exception" + name: str = "_Exception" """Name of the tool.""" - description = "Exception tool" + description: str = "Exception tool" """Description of the tool.""" def _run( diff --git a/libs/langchain/langchain/document_transformers/embeddings_redundant_filter.py b/libs/langchain/langchain/document_transformers/embeddings_redundant_filter.py index f3afb2871d..c274c2a1c0 100644 --- a/libs/langchain/langchain/document_transformers/embeddings_redundant_filter.py +++ b/libs/langchain/langchain/document_transformers/embeddings_redundant_filter.py @@ -182,7 +182,7 @@ class EmbeddingsClusteringFilter(BaseDocumentTransformer, BaseModel): """By default results are re-ordered "grouping" them by cluster, if sorted is true result will be ordered by the original position from the retriever""" - remove_duplicates = False + remove_duplicates: bool = False """ By default duplicated results are skipped and replaced by the next closest vector in the cluster. If remove_duplicates is true no replacement will be done: This could dramatically reduce results when there is a lot of overlap between diff --git a/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py b/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py index 9b3f35b06b..2e69073691 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py +++ b/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py @@ -37,9 +37,9 @@ class RocksetChatMessageHistory(BaseChatMessageHistory): # These values are configured for the typical # free VI. Read more about VIs here: # https://rockset.com/docs/instances - SLEEP_INTERVAL_MS = 5 - ADD_TIMEOUT_MS = 5000 - CREATE_TIMEOUT_MS = 20000 + SLEEP_INTERVAL_MS: int = 5 + ADD_TIMEOUT_MS: int = 5000 + CREATE_TIMEOUT_MS: int = 20000 def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None: """Sleeps until meth() evaluates to true. Passes kwargs into diff --git a/libs/langchain/langchain/memory/motorhead_memory.py b/libs/langchain/langchain/memory/motorhead_memory.py index 30dfbb7cea..da9d9d82a3 100644 --- a/libs/langchain/langchain/memory/motorhead_memory.py +++ b/libs/langchain/langchain/memory/motorhead_memory.py @@ -13,8 +13,8 @@ class MotorheadMemory(BaseChatMemory): """Chat message memory backed by Motorhead service.""" url: str = MANAGED_URL - timeout = 3000 - memory_key = "history" + timeout: int = 3000 + memory_key: str = "history" session_id: str context: Optional[str] = None diff --git a/libs/langchain/langchain/tools/google_places/tool.py b/libs/langchain/langchain/tools/google_places/tool.py index c738c5eb8b..4414332581 100644 --- a/libs/langchain/langchain/tools/google_places/tool.py +++ b/libs/langchain/langchain/tools/google_places/tool.py @@ -18,8 +18,8 @@ class GooglePlacesSchema(BaseModel): class GooglePlacesTool(BaseTool): """Tool that queries the Google places API.""" - name = "google_places" - description = ( + name: str = "google_places" + description: str = ( "A wrapper around Google Places. " "Useful for when you need to validate or " "discover addressed from ambiguous text. " diff --git a/libs/langchain/langchain/tools/json/tool.py b/libs/langchain/langchain/tools/json/tool.py index ad909a1471..e504816795 100644 --- a/libs/langchain/langchain/tools/json/tool.py +++ b/libs/langchain/langchain/tools/json/tool.py @@ -84,8 +84,8 @@ class JsonSpec(BaseModel): class JsonListKeysTool(BaseTool): """Tool for listing keys in a JSON spec.""" - name = "json_spec_list_keys" - description = """ + name: str = "json_spec_list_keys" + description: str = """ Can be used to list all keys at a given path. Before calling this you should be SURE that the path to this exists. The input is a text representation of the path to the dict in Python syntax (e.g. data["key1"][0]["key2"]). @@ -110,8 +110,8 @@ class JsonListKeysTool(BaseTool): class JsonGetValueTool(BaseTool): """Tool for getting a value in a JSON spec.""" - name = "json_spec_get_value" - description = """ + name: str = "json_spec_get_value" + description: str = """ Can be used to see value in string format at a given path. Before calling this you should be SURE that the path to this exists. The input is a text representation of the path to the dict in Python syntax (e.g. data["key1"][0]["key2"]). diff --git a/libs/langchain/langchain/tools/nuclia/tool.py b/libs/langchain/langchain/tools/nuclia/tool.py index 0d40a43b3d..cacb164f89 100644 --- a/libs/langchain/langchain/tools/nuclia/tool.py +++ b/libs/langchain/langchain/tools/nuclia/tool.py @@ -58,8 +58,8 @@ class NUASchema(BaseModel): class NucliaUnderstandingAPI(BaseTool): """Tool to process files with the Nuclia Understanding API.""" - name = "nuclia_understanding_api" - description = ( + name: str = "nuclia_understanding_api" + description: str = ( "A wrapper around Nuclia Understanding API endpoints. " "Useful for when you need to extract text from any kind of files. " ) diff --git a/libs/langchain/langchain/tools/requests/tool.py b/libs/langchain/langchain/tools/requests/tool.py index e3abf8a0a3..db74bc055e 100644 --- a/libs/langchain/langchain/tools/requests/tool.py +++ b/libs/langchain/langchain/tools/requests/tool.py @@ -32,8 +32,8 @@ class BaseRequestsTool(BaseModel): class RequestsGetTool(BaseRequestsTool, BaseTool): """Tool for making a GET request to an API endpoint.""" - name = "requests_get" - description = "A portal to the internet. Use this when you need to get specific content from a website. Input should be a url (i.e. https://www.google.com). The output will be the text response of the GET request." + name: str = "requests_get" + description: str = "A portal to the internet. Use this when you need to get specific content from a website. Input should be a url (i.e. https://www.google.com). The output will be the text response of the GET request." def _run( self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None @@ -53,8 +53,8 @@ class RequestsGetTool(BaseRequestsTool, BaseTool): class RequestsPostTool(BaseRequestsTool, BaseTool): """Tool for making a POST request to an API endpoint.""" - name = "requests_post" - description = """Use this when you want to POST to a website. + name: str = "requests_post" + description: str = """Use this when you want to POST to a website. Input should be a json string with two keys: "url" and "data". The value of "url" should be a string, and the value of "data" should be a dictionary of key-value pairs you want to POST to the url. @@ -90,8 +90,8 @@ class RequestsPostTool(BaseRequestsTool, BaseTool): class RequestsPatchTool(BaseRequestsTool, BaseTool): """Tool for making a PATCH request to an API endpoint.""" - name = "requests_patch" - description = """Use this when you want to PATCH to a website. + name: str = "requests_patch" + description: str = """Use this when you want to PATCH to a website. Input should be a json string with two keys: "url" and "data". The value of "url" should be a string, and the value of "data" should be a dictionary of key-value pairs you want to PATCH to the url. @@ -127,8 +127,8 @@ class RequestsPatchTool(BaseRequestsTool, BaseTool): class RequestsPutTool(BaseRequestsTool, BaseTool): """Tool for making a PUT request to an API endpoint.""" - name = "requests_put" - description = """Use this when you want to PUT to a website. + name: str = "requests_put" + description: str = """Use this when you want to PUT to a website. Input should be a json string with two keys: "url" and "data". The value of "url" should be a string, and the value of "data" should be a dictionary of key-value pairs you want to PUT to the url. @@ -164,8 +164,8 @@ class RequestsPutTool(BaseRequestsTool, BaseTool): class RequestsDeleteTool(BaseRequestsTool, BaseTool): """Tool for making a DELETE request to an API endpoint.""" - name = "requests_delete" - description = "A portal to the internet. Use this when you need to make a DELETE request to a URL. Input should be a specific url, and the output will be the text response of the DELETE request." + name: str = "requests_delete" + description: str = "A portal to the internet. Use this when you need to make a DELETE request to a URL. Input should be a specific url, and the output will be the text response of the DELETE request." def _run( self, diff --git a/libs/langchain/langchain/tools/scenexplain/tool.py b/libs/langchain/langchain/tools/scenexplain/tool.py index 9858e2b1e1..87a09a2e6a 100644 --- a/libs/langchain/langchain/tools/scenexplain/tool.py +++ b/libs/langchain/langchain/tools/scenexplain/tool.py @@ -17,8 +17,8 @@ class SceneXplainInput(BaseModel): class SceneXplainTool(BaseTool): """Tool that explains images.""" - name = "image_explainer" - description = ( + name: str = "image_explainer" + description: str = ( "An Image Captioning Tool: Use this tool to generate a detailed caption " "for an image. The input can be an image file of any format, and " "the output will be a text description that covers every detail of the image." diff --git a/libs/langchain/langchain/tools/sleep/tool.py b/libs/langchain/langchain/tools/sleep/tool.py index 348b9366c7..ef7bcfe8e4 100644 --- a/libs/langchain/langchain/tools/sleep/tool.py +++ b/libs/langchain/langchain/tools/sleep/tool.py @@ -21,9 +21,9 @@ class SleepInput(BaseModel): class SleepTool(BaseTool): """Tool that adds the capability to sleep.""" - name = "sleep" + name: str = "sleep" args_schema: Type[BaseModel] = SleepInput - description = "Make agent sleep for a specified number of seconds." + description: str = "Make agent sleep for a specified number of seconds." def _run( self, diff --git a/libs/langchain/langchain/tools/spark_sql/tool.py b/libs/langchain/langchain/tools/spark_sql/tool.py index 4b23d5305b..d3a1a383e7 100644 --- a/libs/langchain/langchain/tools/spark_sql/tool.py +++ b/libs/langchain/langchain/tools/spark_sql/tool.py @@ -33,8 +33,8 @@ class BaseSparkSQLTool(BaseModel): class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool): """Tool for querying a Spark SQL.""" - name = "query_sql_db" - description = """ + name: str = "query_sql_db" + description: str = """ Input to this tool is a detailed and correct SQL query, output is a result from the Spark SQL. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. @@ -52,8 +52,8 @@ class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool): class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool): """Tool for getting metadata about a Spark SQL.""" - name = "schema_sql_db" - description = """ + name: str = "schema_sql_db" + description: str = """ Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling list_tables_sql_db first! @@ -72,8 +72,8 @@ class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool): class ListSparkSQLTool(BaseSparkSQLTool, BaseTool): """Tool for getting tables names.""" - name = "list_tables_sql_db" - description = "Input is an empty string, output is a comma separated list of tables in the Spark SQL." + name: str = "list_tables_sql_db" + description: str = "Input is an empty string, output is a comma separated list of tables in the Spark SQL." def _run( self, @@ -91,8 +91,8 @@ class QueryCheckerTool(BaseSparkSQLTool, BaseTool): template: str = QUERY_CHECKER llm: BaseLanguageModel llm_chain: LLMChain = Field(init=False) - name = "query_checker_sql_db" - description = """ + name: str = "query_checker_sql_db" + description: str = """ Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with query_sql_db! """ diff --git a/libs/langchain/langchain/tools/sql_database/tool.py b/libs/langchain/langchain/tools/sql_database/tool.py index f9597df962..3041400d91 100644 --- a/libs/langchain/langchain/tools/sql_database/tool.py +++ b/libs/langchain/langchain/tools/sql_database/tool.py @@ -33,8 +33,8 @@ class BaseSQLDatabaseTool(BaseModel): class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): """Tool for querying a SQL database.""" - name = "sql_db_query" - description = """ + name: str = "sql_db_query" + description: str = """ Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. @@ -52,8 +52,8 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): """Tool for getting metadata about a SQL database.""" - name = "sql_db_schema" - description = """ + name: str = "sql_db_schema" + description: str = """ Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Example Input: "table1, table2, table3" @@ -71,8 +71,8 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): """Tool for getting tables names.""" - name = "sql_db_list_tables" - description = "Input is an empty string, output is a comma separated list of tables in the database." + name: str = "sql_db_list_tables" + description: str = "Input is an empty string, output is a comma separated list of tables in the database." def _run( self, @@ -90,8 +90,8 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool): template: str = QUERY_CHECKER llm: BaseLanguageModel llm_chain: LLMChain = Field(init=False) - name = "sql_db_query_checker" - description = """ + name: str = "sql_db_query_checker" + description: str = """ Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with query_sql_db! """ diff --git a/libs/langchain/langchain/utilities/arxiv.py b/libs/langchain/langchain/utilities/arxiv.py index 4c9ea6300e..ff89badfe4 100644 --- a/libs/langchain/langchain/utilities/arxiv.py +++ b/libs/langchain/langchain/utilities/arxiv.py @@ -50,7 +50,7 @@ class ArxivAPIWrapper(BaseModel): arxiv_search: Any #: :meta private: arxiv_exceptions: Any # :meta private: top_k_results: int = 3 - ARXIV_MAX_QUERY_LENGTH = 300 + ARXIV_MAX_QUERY_LENGTH: int = 300 load_max_docs: int = 100 load_all_available_meta: bool = False doc_content_chars_max: Optional[int] = 4000 diff --git a/libs/langchain/langchain/utilities/brave_search.py b/libs/langchain/langchain/utilities/brave_search.py index 3932dbc08d..0eff1cbffc 100644 --- a/libs/langchain/langchain/utilities/brave_search.py +++ b/libs/langchain/langchain/utilities/brave_search.py @@ -14,7 +14,7 @@ class BraveSearchWrapper(BaseModel): """The API key to use for the Brave search engine.""" search_kwargs: dict = Field(default_factory=dict) """Additional keyword arguments to pass to the search request.""" - base_url = "https://api.search.brave.com/res/v1/web/search" + base_url: str = "https://api.search.brave.com/res/v1/web/search" """The base URL for the Brave search engine.""" def run(self, query: str) -> str: diff --git a/libs/langchain/langchain/utilities/pubmed.py b/libs/langchain/langchain/utilities/pubmed.py index 4c099e4b2e..5b4cf7db8b 100644 --- a/libs/langchain/langchain/utilities/pubmed.py +++ b/libs/langchain/langchain/utilities/pubmed.py @@ -36,14 +36,16 @@ class PubMedAPIWrapper(BaseModel): parse: Any #: :meta private: - base_url_esearch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?" - base_url_efetch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?" - max_retry = 5 - sleep_time = 0.2 + base_url_esearch: str = ( + "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?" + ) + base_url_efetch: str = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?" + max_retry: int = 5 + sleep_time: float = 0.2 # Default values for the parameters top_k_results: int = 3 - MAX_QUERY_LENGTH = 300 + MAX_QUERY_LENGTH: int = 300 doc_content_chars_max: int = 2000 email: str = "your_email@example.com" diff --git a/libs/langchain/langchain/utilities/searx_search.py b/libs/langchain/langchain/utilities/searx_search.py index 3a025e6207..8dba5942ca 100644 --- a/libs/langchain/langchain/utilities/searx_search.py +++ b/libs/langchain/langchain/utilities/searx_search.py @@ -144,7 +144,7 @@ def _get_default_params() -> dict: class SearxResults(dict): """Dict like wrapper around search api results.""" - _data = "" + _data: str = "" def __init__(self, data: str): """Take a raw result from Searx and make it into a dict like object."""