diff --git a/libs/partners/ai21/langchain_ai21/ai21_base.py b/libs/partners/ai21/langchain_ai21/ai21_base.py index 5b8fcca0f1..c681b88721 100644 --- a/libs/partners/ai21/langchain_ai21/ai21_base.py +++ b/libs/partners/ai21/langchain_ai21/ai21_base.py @@ -28,7 +28,7 @@ class AI21Base(BaseModel): num_retries: Optional[int] = None """Maximum number of retries for API requests before giving up.""" - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: api_key = convert_to_secret_str( values.get("api_key") or os.getenv("AI21_API_KEY") or "" @@ -46,7 +46,13 @@ class AI21Base(BaseModel): os.getenv("AI21_TIMEOUT_SEC", _DEFAULT_TIMEOUT_SEC) ) values["timeout_sec"] = timeout_sec + return values + @root_validator(pre=False, skip_on_failure=True) + def post_init(cls, values: Dict) -> Dict: + api_key = values["api_key"] + api_host = values["api_host"] + timeout_sec = values["timeout_sec"] if values.get("client") is None: values["client"] = AI21Client( api_key=api_key.get_secret_value(),