core[patch]: Update remaining root_validators (#22829)

This PR updates the remaining root_validators in core to either be explicit pre-init or post-init validators.
erick/core-loosen-packaging-lib-version
Eugene Yurtsev 4 months ago committed by GitHub
parent 265e650e64
commit 5dbbdcbf8e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -120,7 +120,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED] Callback manager to add to the run trace."""
@root_validator()
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:

@ -232,7 +232,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
arbitrary_types_allowed = True
@root_validator()
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:

@ -149,7 +149,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
"invalid_tool_calls": self.invalid_tool_calls,
}
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def init_tool_calls(cls, values: dict) -> dict:
if not values["tool_call_chunks"]:
values["tool_calls"] = []

@ -121,7 +121,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
template_format: Literal["f-string", "jinja2"] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]:

@ -64,7 +64,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
return values
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]:

@ -47,7 +47,7 @@ class PromptTemplate(StringPromptTemplate):
prompt.format(foo="bar")
# Instantiation using initializer
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
prompt = PromptTemplate(template="Say {foo}")
"""
@property
@ -74,6 +74,43 @@ class PromptTemplate(StringPromptTemplate):
validate_template: bool = False
"""Whether or not to try validating the template."""
@root_validator(pre=True)
def pre_init_validation(cls, values: Dict) -> Dict:
"""Check that template and input variables are consistent."""
if values.get("template") is None:
# Will let pydantic fail with a ValidationError if template
# is not provided.
return values
# Set some default values based on the field defaults
values.setdefault("template_format", "f-string")
values.setdefault("partial_variables", {})
if values.get("validate_template"):
if values["template_format"] == "mustache":
raise ValueError("Mustache templates cannot be validated.")
if "input_variables" not in values:
raise ValueError(
"Input variables must be provided to validate the template."
)
all_inputs = values["input_variables"] + list(values["partial_variables"])
check_valid_template(
values["template"], values["template_format"], all_inputs
)
if values["template_format"]:
values["input_variables"] = [
var
for var in get_template_variables(
values["template"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
if self.template_format != "mustache":
return super().get_input_schema(config)
@ -126,26 +163,6 @@ class PromptTemplate(StringPromptTemplate):
kwargs = self._merge_partial_and_user_variables(**kwargs)
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that template and input variables are consistent."""
if values["validate_template"]:
if values["template_format"] == "mustache":
raise ValueError("Mustache templates cannot be validated.")
all_inputs = values["input_variables"] + list(values["partial_variables"])
check_valid_template(
values["template"], values["template_format"], all_inputs
)
elif values.get("template_format"):
values["input_variables"] = [
var
for var in get_template_variables(
values["template"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values
@classmethod
def from_examples(
cls,

Loading…
Cancel
Save