core[patch]: Update remaining root_validators (#22829)

This PR updates the remaining root_validators in core to either be explicit pre-init or post-init validators.
erick/core-loosen-packaging-lib-version
Eugene Yurtsev 4 months ago committed by GitHub
parent 265e650e64
commit 5dbbdcbf8e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -120,7 +120,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED] Callback manager to add to the run trace.""" """[DEPRECATED] Callback manager to add to the run trace."""
@root_validator() @root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict: def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used.""" """Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None: if values.get("callback_manager") is not None:

@ -232,7 +232,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
arbitrary_types_allowed = True arbitrary_types_allowed = True
@root_validator() @root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict: def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used.""" """Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None: if values.get("callback_manager") is not None:

@ -149,7 +149,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
"invalid_tool_calls": self.invalid_tool_calls, "invalid_tool_calls": self.invalid_tool_calls,
} }
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def init_tool_calls(cls, values: dict) -> dict: def init_tool_calls(cls, values: dict) -> dict:
if not values["tool_call_chunks"]: if not values["tool_call_chunks"]:
values["tool_calls"] = [] values["tool_calls"] = []

@ -121,7 +121,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
template_format: Literal["f-string", "jinja2"] = "f-string" template_format: Literal["f-string", "jinja2"] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'.""" """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def template_is_valid(cls, values: Dict) -> Dict: def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix, and input variables are consistent.""" """Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]: if values["validate_template"]:

@ -64,7 +64,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
return values return values
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def template_is_valid(cls, values: Dict) -> Dict: def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix, and input variables are consistent.""" """Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]: if values["validate_template"]:

@ -47,7 +47,7 @@ class PromptTemplate(StringPromptTemplate):
prompt.format(foo="bar") prompt.format(foo="bar")
# Instantiation using initializer # Instantiation using initializer
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}") prompt = PromptTemplate(template="Say {foo}")
""" """
@property @property
@ -74,6 +74,43 @@ class PromptTemplate(StringPromptTemplate):
validate_template: bool = False validate_template: bool = False
"""Whether or not to try validating the template.""" """Whether or not to try validating the template."""
@root_validator(pre=True)
def pre_init_validation(cls, values: Dict) -> Dict:
"""Check that template and input variables are consistent."""
if values.get("template") is None:
# Will let pydantic fail with a ValidationError if template
# is not provided.
return values
# Set some default values based on the field defaults
values.setdefault("template_format", "f-string")
values.setdefault("partial_variables", {})
if values.get("validate_template"):
if values["template_format"] == "mustache":
raise ValueError("Mustache templates cannot be validated.")
if "input_variables" not in values:
raise ValueError(
"Input variables must be provided to validate the template."
)
all_inputs = values["input_variables"] + list(values["partial_variables"])
check_valid_template(
values["template"], values["template_format"], all_inputs
)
if values["template_format"]:
values["input_variables"] = [
var
for var in get_template_variables(
values["template"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
if self.template_format != "mustache": if self.template_format != "mustache":
return super().get_input_schema(config) return super().get_input_schema(config)
@ -126,26 +163,6 @@ class PromptTemplate(StringPromptTemplate):
kwargs = self._merge_partial_and_user_variables(**kwargs) kwargs = self._merge_partial_and_user_variables(**kwargs)
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that template and input variables are consistent."""
if values["validate_template"]:
if values["template_format"] == "mustache":
raise ValueError("Mustache templates cannot be validated.")
all_inputs = values["input_variables"] + list(values["partial_variables"])
check_valid_template(
values["template"], values["template_format"], all_inputs
)
elif values.get("template_format"):
values["input_variables"] = [
var
for var in get_template_variables(
values["template"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values
@classmethod @classmethod
def from_examples( def from_examples(
cls, cls,

Loading…
Cancel
Save