community[patch]: Add linter to catch @root_validator (#24070)

- Add linter to prevent further usage of vanilla root validator
- Udpate remaining root validators
This commit is contained in:
Eugene Yurtsev 2024-07-10 10:51:03 -04:00 committed by GitHub
parent 9c6efadec3
commit c4e149d4f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 39 additions and 19 deletions

View File

@ -63,7 +63,7 @@ class FileManagementToolkit(BaseToolkit):
selected_tools: Optional[List[str]] = None
"""If provided, only provide the selected tools. Defaults to all."""
@root_validator
@root_validator(pre=True)
def validate_tools(cls, values: dict) -> dict:
selected_tools = values.get("selected_tools") or []
for tool_name in selected_tools:

View File

@ -74,7 +74,7 @@ class PlayWrightBrowserToolkit(BaseToolkit):
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator
@root_validator(pre=True)
def validate_imports_and_browser_provided(cls, values: dict) -> dict:
"""Check that the arguments are valid."""
lazy_import_playwright_browsers()

View File

@ -81,7 +81,7 @@ class DocugamiLoader(BaseLoader, BaseModel):
include_project_metadata_in_doc_metadata: bool = True
"""Set to True if you want to include the project metadata in the doc metadata."""
@root_validator
@root_validator(pre=True)
def validate_local_or_remote(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that either local file paths are given, or remote API docset ID.

View File

@ -33,7 +33,7 @@ class DropboxLoader(BaseLoader, BaseModel):
recursive: bool = False
"""Flag to indicate whether to load files recursively from subfolders."""
@root_validator
@root_validator(pre=True)
def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that either folder_path or file_paths is set, but not both."""
if (

View File

@ -53,7 +53,7 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
file_loader_kwargs: Dict["str", Any] = {}
"""The file loader kwargs to use."""
@root_validator
@root_validator(pre=True)
def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that either folder_id or document_ids is set, but not both."""
if values.get("folder_id") and (

View File

@ -47,7 +47,7 @@ class GoogleApiClient:
def __post_init__(self) -> None:
self.creds = self._load_credentials()
@root_validator
@root_validator(pre=True)
def validate_channel_or_videoIds_is_set(
cls, values: Dict[str, Any]
) -> Dict[str, Any]:
@ -388,7 +388,7 @@ class GoogleApiYoutubeLoader(BaseLoader):
return build("youtube", "v3", credentials=creds)
@root_validator
@root_validator(pre=True)
def validate_channel_or_videoIds_is_set(
cls, values: Dict[str, Any]
) -> Dict[str, Any]:

View File

@ -54,8 +54,10 @@ class AscendEmbeddings(Embeddings, BaseModel):
self.model.half()
self.encode([f"warmup {i} times" for i in range(10)])
@root_validator
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
if "model_path" not in values:
raise ValueError("model_path is required")
if not os.access(values["model_path"], os.F_OK):
raise FileNotFoundError(
f"Unabled to find valid model path in [{values['model_path']}]"

View File

@ -65,7 +65,7 @@ class EdenAiTextToSpeechTool(EdenaiTool):
)
return v
@root_validator
@root_validator(pre=True)
def check_voice_models_key_is_provider_name(cls, values: dict) -> dict:
for key in values.get("voice_models", {}).keys():
if key not in values.get("providers", []):

View File

@ -38,7 +38,7 @@ class BaseBrowserTool(BaseTool):
sync_browser: Optional["SyncBrowser"] = None
async_browser: Optional["AsyncBrowser"] = None
@root_validator
@root_validator(pre=True)
def validate_browser_provided(cls, values: dict) -> dict:
"""Check that the arguments are valid."""
lazy_import_playwright_browsers()

View File

@ -35,7 +35,7 @@ class ExtractHyperlinksTool(BaseBrowserTool):
description: str = "Extract all hyperlinks on the current webpage"
args_schema: Type[BaseModel] = ExtractHyperlinksToolInput
@root_validator
@root_validator(pre=True)
def check_bs_import(cls, values: dict) -> dict:
"""Check that the arguments are valid."""
try:

View File

@ -22,7 +22,7 @@ class ExtractTextTool(BaseBrowserTool):
description: str = "Extract all the text on the current webpage"
args_schema: Type[BaseModel] = BaseModel
@root_validator
@root_validator(pre=True)
def check_acheck_bs_importrgs(cls, values: dict) -> dict:
"""Check that the arguments are valid."""
try:

View File

@ -21,7 +21,7 @@ class ShellInput(BaseModel):
)
"""List of shell commands to run."""
@root_validator
@root_validator(pre=True)
def _validate_commands(cls, values: dict) -> dict:
"""Validate commands."""
# TODO: Add real validators

View File

@ -75,8 +75,9 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool
from langchain_core.utils import pre_init
from langchain_community.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT
from langchain_community.utilities.zapier import ZapierNLAWrapper
@ -105,7 +106,7 @@ class ZapierNLARunAction(BaseTool):
name: str = ""
description: str = ""
@root_validator
@pre_init
def set_name_description(cls, values: Dict[str, Any]) -> Dict[str, Any]:
zapier_description = values["zapier_description"]
params_schema = values["params_schema"]

View File

@ -39,7 +39,7 @@ class SteamWebAPIWrapper(BaseModel):
"""Return a list of operations."""
return self.operations
@root_validator
@root_validator(pre=True)
def validate_environment(cls, values: dict) -> dict:
"""Validate api key and python package has been configured."""

View File

@ -114,7 +114,7 @@ class YouSearchAPIWrapper(BaseModel):
return values
@root_validator
@root_validator(pre=False, skip_on_failure=True)
def warn_if_set_fields_have_no_effect(cls, values: Dict) -> Dict:
if values["endpoint_type"] != "news":
news_api_fields = ("search_lang", "ui_lang", "spellcheck")
@ -139,7 +139,7 @@ class YouSearchAPIWrapper(BaseModel):
)
return values
@root_validator
@root_validator(pre=False, skip_on_failure=True)
def warn_if_deprecated_endpoints_are_used(cls, values: Dict) -> Dict:
if values["endpoint_type"] == "snippets":
warnings.warn(

View File

@ -14,7 +14,7 @@ fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
result=$(git -C "$repository_path" grep -En '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
@ -25,3 +25,20 @@ if [ -n "$result" ]; then
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi
# Forbid vanilla usage of @root_validator
# This prevents the code from using either @root_validator or @root_validator()
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -En '(@root_validator\s*$)|(@root_validator\(\))' -- '*.py')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo
echo "$result"
echo
echo "Please replace @root_validator or @root_validator() with either:"
echo
echo "@root_validator(pre=True) or @root_validator(pre=False, skip_on_failure=True)"
exit 1
fi