From a114255b822f5dd41dd1f9fb48aa72e1969c89e9 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 15 Aug 2024 11:26:44 -0400 Subject: [PATCH] ai21[patch]: Update @root_validators for pydantic2 migration (#25401) Update @root_validators for pydantic 2 migration. --- libs/partners/ai21/langchain_ai21/ai21_base.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/libs/partners/ai21/langchain_ai21/ai21_base.py b/libs/partners/ai21/langchain_ai21/ai21_base.py index 5b8fcca0f1..c681b88721 100644 --- a/libs/partners/ai21/langchain_ai21/ai21_base.py +++ b/libs/partners/ai21/langchain_ai21/ai21_base.py @@ -28,7 +28,7 @@ class AI21Base(BaseModel): num_retries: Optional[int] = None """Maximum number of retries for API requests before giving up.""" - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: api_key = convert_to_secret_str( values.get("api_key") or os.getenv("AI21_API_KEY") or "" @@ -46,7 +46,13 @@ class AI21Base(BaseModel): os.getenv("AI21_TIMEOUT_SEC", _DEFAULT_TIMEOUT_SEC) ) values["timeout_sec"] = timeout_sec + return values + @root_validator(pre=False, skip_on_failure=True) + def post_init(cls, values: Dict) -> Dict: + api_key = values["api_key"] + api_host = values["api_host"] + timeout_sec = values["timeout_sec"] if values.get("client") is None: values["client"] = AI21Client( api_key=api_key.get_secret_value(),