core[patch]: update some root_validators (#22787)

Update some of the @root_validators to be explicit pre=True or
pre=False, skip_on_failure=True for pydantic 2 compatibility.
pull/22831/head
Eugene Yurtsev 4 months ago committed by GitHub
parent 3d6e8547f9
commit 74e705250f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -68,7 +68,7 @@ class AIMessage(BaseMessage):
"invalid_tool_calls": self.invalid_tool_calls,
}
@root_validator()
@root_validator(pre=True)
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
tool_calls = (

@ -59,6 +59,29 @@ class BasePromptTemplate(
tags: Optional[List[str]] = None
"""Tags to be used for tracing."""
@root_validator(pre=False, skip_on_failure=True)
def validate_variable_names(cls, values: Dict) -> Dict:
"""Validate variable names do not include restricted names."""
if "stop" in values["input_variables"]:
raise ValueError(
"Cannot have an input variable named 'stop', as it is used internally,"
" please rename."
)
if "stop" in values["partial_variables"]:
raise ValueError(
"Cannot have an partial variable named 'stop', as it is used "
"internally, please rename."
)
overall = set(values["input_variables"]).intersection(
values["partial_variables"]
)
if overall:
raise ValueError(
f"Found overlapping input and partial variables: {overall}"
)
return values
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
@ -155,29 +178,6 @@ class BasePromptTemplate(
"""Create Prompt Value."""
return self.format_prompt(**kwargs)
@root_validator()
def validate_variable_names(cls, values: Dict) -> Dict:
"""Validate variable names do not include restricted names."""
if "stop" in values["input_variables"]:
raise ValueError(
"Cannot have an input variable named 'stop', as it is used internally,"
" please rename."
)
if "stop" in values["partial_variables"]:
raise ValueError(
"Cannot have an partial variable named 'stop', as it is used "
"internally, please rename."
)
overall = set(values["input_variables"]).intersection(
values["partial_variables"]
)
if overall:
raise ValueError(
f"Found overlapping input and partial variables: {overall}"
)
return values
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
"""Return a partial of the prompt template."""
prompt_dict = self.__dict__.copy()

@ -309,7 +309,7 @@ class ChildTool(BaseTool):
}
return tool_input
@root_validator()
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:

@ -772,17 +772,17 @@ class VectorStoreRetriever(BaseRetriever):
arbitrary_types_allowed = True
@root_validator()
@root_validator(pre=True)
def validate_search_type(cls, values: Dict) -> Dict:
"""Validate search type."""
search_type = values["search_type"]
search_type = values.get("search_type", "similarity")
if search_type not in cls.allowed_search_types:
raise ValueError(
f"search_type of {search_type} not allowed. Valid values are: "
f"{cls.allowed_search_types}"
)
if search_type == "similarity_score_threshold":
score_threshold = values["search_kwargs"].get("score_threshold")
score_threshold = values.get("search_kwargs", {}).get("score_threshold")
if (score_threshold is None) or (not isinstance(score_threshold, float)):
raise ValueError(
"`score_threshold` is not specified with a float value(0~1) "

@ -23,7 +23,7 @@ class MyRunnable(RunnableSerializable[str, str]):
raise ValueError("Cannot set _my_hidden_property")
return values
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["_my_hidden_property"] = values["my_property"]
return values

Loading…
Cancel
Save