ai21[patch]: Update @root_validators for pydantic2 migration (#25401)

Update @root_validators for pydantic 2 migration.
This commit is contained in:
Eugene Yurtsev 2024-08-15 11:26:44 -04:00 committed by GitHub
parent 6f68c8d6ab
commit a114255b82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -28,7 +28,7 @@ class AI21Base(BaseModel):
num_retries: Optional[int] = None num_retries: Optional[int] = None
"""Maximum number of retries for API requests before giving up.""" """Maximum number of retries for API requests before giving up."""
@root_validator() @root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
api_key = convert_to_secret_str( api_key = convert_to_secret_str(
values.get("api_key") or os.getenv("AI21_API_KEY") or "" values.get("api_key") or os.getenv("AI21_API_KEY") or ""
@ -46,7 +46,13 @@ class AI21Base(BaseModel):
os.getenv("AI21_TIMEOUT_SEC", _DEFAULT_TIMEOUT_SEC) os.getenv("AI21_TIMEOUT_SEC", _DEFAULT_TIMEOUT_SEC)
) )
values["timeout_sec"] = timeout_sec values["timeout_sec"] = timeout_sec
return values
@root_validator(pre=False, skip_on_failure=True)
def post_init(cls, values: Dict) -> Dict:
api_key = values["api_key"]
api_host = values["api_host"]
timeout_sec = values["timeout_sec"]
if values.get("client") is None: if values.get("client") is None:
values["client"] = AI21Client( values["client"] = AI21Client(
api_key=api_key.get_secret_value(), api_key=api_key.get_secret_value(),