mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
community[patch]: Add linter to catch @root_validator (#24070)
- Add linter to prevent further usage of vanilla root validator - Udpate remaining root validators
This commit is contained in:
parent
9c6efadec3
commit
c4e149d4f1
@ -63,7 +63,7 @@ class FileManagementToolkit(BaseToolkit):
|
||||
selected_tools: Optional[List[str]] = None
|
||||
"""If provided, only provide the selected tools. Defaults to all."""
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=True)
|
||||
def validate_tools(cls, values: dict) -> dict:
|
||||
selected_tools = values.get("selected_tools") or []
|
||||
for tool_name in selected_tools:
|
||||
|
@ -74,7 +74,7 @@ class PlayWrightBrowserToolkit(BaseToolkit):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=True)
|
||||
def validate_imports_and_browser_provided(cls, values: dict) -> dict:
|
||||
"""Check that the arguments are valid."""
|
||||
lazy_import_playwright_browsers()
|
||||
|
@ -81,7 +81,7 @@ class DocugamiLoader(BaseLoader, BaseModel):
|
||||
include_project_metadata_in_doc_metadata: bool = True
|
||||
"""Set to True if you want to include the project metadata in the doc metadata."""
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=True)
|
||||
def validate_local_or_remote(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate that either local file paths are given, or remote API docset ID.
|
||||
|
||||
|
@ -33,7 +33,7 @@ class DropboxLoader(BaseLoader, BaseModel):
|
||||
recursive: bool = False
|
||||
"""Flag to indicate whether to load files recursively from subfolders."""
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=True)
|
||||
def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate that either folder_path or file_paths is set, but not both."""
|
||||
if (
|
||||
|
@ -53,7 +53,7 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
|
||||
file_loader_kwargs: Dict["str", Any] = {}
|
||||
"""The file loader kwargs to use."""
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=True)
|
||||
def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate that either folder_id or document_ids is set, but not both."""
|
||||
if values.get("folder_id") and (
|
||||
|
@ -47,7 +47,7 @@ class GoogleApiClient:
|
||||
def __post_init__(self) -> None:
|
||||
self.creds = self._load_credentials()
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=True)
|
||||
def validate_channel_or_videoIds_is_set(
|
||||
cls, values: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
@ -388,7 +388,7 @@ class GoogleApiYoutubeLoader(BaseLoader):
|
||||
|
||||
return build("youtube", "v3", credentials=creds)
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=True)
|
||||
def validate_channel_or_videoIds_is_set(
|
||||
cls, values: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
|
@ -54,8 +54,10 @@ class AscendEmbeddings(Embeddings, BaseModel):
|
||||
self.model.half()
|
||||
self.encode([f"warmup {i} times" for i in range(10)])
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
if "model_path" not in values:
|
||||
raise ValueError("model_path is required")
|
||||
if not os.access(values["model_path"], os.F_OK):
|
||||
raise FileNotFoundError(
|
||||
f"Unabled to find valid model path in [{values['model_path']}]"
|
||||
|
@ -65,7 +65,7 @@ class EdenAiTextToSpeechTool(EdenaiTool):
|
||||
)
|
||||
return v
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=True)
|
||||
def check_voice_models_key_is_provider_name(cls, values: dict) -> dict:
|
||||
for key in values.get("voice_models", {}).keys():
|
||||
if key not in values.get("providers", []):
|
||||
|
@ -38,7 +38,7 @@ class BaseBrowserTool(BaseTool):
|
||||
sync_browser: Optional["SyncBrowser"] = None
|
||||
async_browser: Optional["AsyncBrowser"] = None
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=True)
|
||||
def validate_browser_provided(cls, values: dict) -> dict:
|
||||
"""Check that the arguments are valid."""
|
||||
lazy_import_playwright_browsers()
|
||||
|
@ -35,7 +35,7 @@ class ExtractHyperlinksTool(BaseBrowserTool):
|
||||
description: str = "Extract all hyperlinks on the current webpage"
|
||||
args_schema: Type[BaseModel] = ExtractHyperlinksToolInput
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=True)
|
||||
def check_bs_import(cls, values: dict) -> dict:
|
||||
"""Check that the arguments are valid."""
|
||||
try:
|
||||
|
@ -22,7 +22,7 @@ class ExtractTextTool(BaseBrowserTool):
|
||||
description: str = "Extract all the text on the current webpage"
|
||||
args_schema: Type[BaseModel] = BaseModel
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=True)
|
||||
def check_acheck_bs_importrgs(cls, values: dict) -> dict:
|
||||
"""Check that the arguments are valid."""
|
||||
try:
|
||||
|
@ -21,7 +21,7 @@ class ShellInput(BaseModel):
|
||||
)
|
||||
"""List of shell commands to run."""
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=True)
|
||||
def _validate_commands(cls, values: dict) -> dict:
|
||||
"""Validate commands."""
|
||||
# TODO: Add real validators
|
||||
|
@ -75,8 +75,9 @@ from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import pre_init
|
||||
|
||||
from langchain_community.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT
|
||||
from langchain_community.utilities.zapier import ZapierNLAWrapper
|
||||
@ -105,7 +106,7 @@ class ZapierNLARunAction(BaseTool):
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
|
||||
@root_validator
|
||||
@pre_init
|
||||
def set_name_description(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
zapier_description = values["zapier_description"]
|
||||
params_schema = values["params_schema"]
|
||||
|
@ -39,7 +39,7 @@ class SteamWebAPIWrapper(BaseModel):
|
||||
"""Return a list of operations."""
|
||||
return self.operations
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: dict) -> dict:
|
||||
"""Validate api key and python package has been configured."""
|
||||
|
||||
|
@ -114,7 +114,7 @@ class YouSearchAPIWrapper(BaseModel):
|
||||
|
||||
return values
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def warn_if_set_fields_have_no_effect(cls, values: Dict) -> Dict:
|
||||
if values["endpoint_type"] != "news":
|
||||
news_api_fields = ("search_lang", "ui_lang", "spellcheck")
|
||||
@ -139,7 +139,7 @@ class YouSearchAPIWrapper(BaseModel):
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def warn_if_deprecated_endpoints_are_used(cls, values: Dict) -> Dict:
|
||||
if values["endpoint_type"] == "snippets":
|
||||
warnings.warn(
|
||||
|
@ -14,7 +14,7 @@ fi
|
||||
repository_path="$1"
|
||||
|
||||
# Search for lines matching the pattern within the specified repository
|
||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
||||
result=$(git -C "$repository_path" grep -En '^import pydantic|^from pydantic')
|
||||
|
||||
# Check if any matching lines were found
|
||||
if [ -n "$result" ]; then
|
||||
@ -25,3 +25,20 @@ if [ -n "$result" ]; then
|
||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Forbid vanilla usage of @root_validator
|
||||
# This prevents the code from using either @root_validator or @root_validator()
|
||||
# Search for lines matching the pattern within the specified repository
|
||||
result=$(git -C "$repository_path" grep -En '(@root_validator\s*$)|(@root_validator\(\))' -- '*.py')
|
||||
|
||||
# Check if any matching lines were found
|
||||
if [ -n "$result" ]; then
|
||||
echo "ERROR: The following lines need to be updated:"
|
||||
echo
|
||||
echo "$result"
|
||||
echo
|
||||
echo "Please replace @root_validator or @root_validator() with either:"
|
||||
echo
|
||||
echo "@root_validator(pre=True) or @root_validator(pre=False, skip_on_failure=True)"
|
||||
exit 1
|
||||
fi
|
||||
|
Loading…
Reference in New Issue
Block a user