|
|
@ -28,7 +28,7 @@ class AI21Base(BaseModel):
|
|
|
|
num_retries: Optional[int] = None
|
|
|
|
num_retries: Optional[int] = None
|
|
|
|
"""Maximum number of retries for API requests before giving up."""
|
|
|
|
"""Maximum number of retries for API requests before giving up."""
|
|
|
|
|
|
|
|
|
|
|
|
@root_validator()
|
|
|
|
@root_validator(pre=True)
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
api_key = convert_to_secret_str(
|
|
|
|
api_key = convert_to_secret_str(
|
|
|
|
values.get("api_key") or os.getenv("AI21_API_KEY") or ""
|
|
|
|
values.get("api_key") or os.getenv("AI21_API_KEY") or ""
|
|
|
@ -46,7 +46,13 @@ class AI21Base(BaseModel):
|
|
|
|
os.getenv("AI21_TIMEOUT_SEC", _DEFAULT_TIMEOUT_SEC)
|
|
|
|
os.getenv("AI21_TIMEOUT_SEC", _DEFAULT_TIMEOUT_SEC)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
values["timeout_sec"] = timeout_sec
|
|
|
|
values["timeout_sec"] = timeout_sec
|
|
|
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@root_validator(pre=False, skip_on_failure=True)
|
|
|
|
|
|
|
|
def post_init(cls, values: Dict) -> Dict:
|
|
|
|
|
|
|
|
api_key = values["api_key"]
|
|
|
|
|
|
|
|
api_host = values["api_host"]
|
|
|
|
|
|
|
|
timeout_sec = values["timeout_sec"]
|
|
|
|
if values.get("client") is None:
|
|
|
|
if values.get("client") is None:
|
|
|
|
values["client"] = AI21Client(
|
|
|
|
values["client"] = AI21Client(
|
|
|
|
api_key=api_key.get_secret_value(),
|
|
|
|
api_key=api_key.get_secret_value(),
|
|
|
|