langchain/libs/partners/ai21/langchain_ai21/ai21_base.py
Asaf Joseph Gardin 21c45475c5
ai21[patch]: AI21 Labs bump SDK version (#19114)
Description: Added support for AI21 SDK version 2.1.2
Twitter handle: https://github.com/AI21Labs

---------

Co-authored-by: Asaf Gardin <asafg@ai21.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
2024-03-18 19:47:08 -07:00

49 lines
1.4 KiB
Python

import os
from typing import Any, Dict, Optional
from ai21 import AI21Client
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
# Fallback client timeout, in seconds, used when neither the ``timeout_sec``
# field nor the ``AI21_TIMEOUT_SEC`` environment variable provides a value.
_DEFAULT_TIMEOUT_SEC = 300
class AI21Base(BaseModel):
    """Shared configuration and client construction for AI21 integrations.

    Each connection setting is resolved in this order: the explicit field
    value, then the matching environment variable, then a built-in default.
    An ``AI21Client`` is built from the resolved settings unless the caller
    already supplied ``client``.
    """

    class Config:
        # NOTE(review): presumably needed because ``client`` holds a
        # non-pydantic SDK object -- confirm.
        arbitrary_types_allowed = True

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    # API key; falls back to the ``AI21_API_KEY`` environment variable, then "".
    api_key: Optional[SecretStr] = None
    # API base URL; falls back to ``AI21_API_URL``, then "https://api.ai21.com".
    api_host: Optional[str] = None
    # Client timeout in seconds; falls back to ``AI21_TIMEOUT_SEC``, then
    # ``_DEFAULT_TIMEOUT_SEC``.
    timeout_sec: Optional[float] = None
    # Retry count forwarded to ``AI21Client`` when set; when left as ``None``
    # the argument is omitted and the SDK applies its own default.
    num_retries: Optional[int] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve connection settings and build the AI21 client if missing.

        Args:
            values: Field values collected by pydantic for this model.

        Returns:
            The same mapping with ``api_key``, ``api_host`` and
            ``timeout_sec`` filled in, and ``client`` set to a new
            ``AI21Client`` when none was provided.

        Raises:
            ValueError: If ``AI21_TIMEOUT_SEC`` is set to a non-empty value
                that cannot be parsed as a float.
        """
        api_key = convert_to_secret_str(
            values.get("api_key") or os.getenv("AI21_API_KEY") or ""
        )
        values["api_key"] = api_key

        api_host = (
            values.get("api_host")
            or os.getenv("AI21_API_URL")
            or "https://api.ai21.com"
        )
        values["api_host"] = api_host

        # An empty AI21_TIMEOUT_SEC is treated as unset (like the other
        # environment fallbacks above) instead of failing in float("").
        timeout_sec = values.get("timeout_sec") or float(
            os.getenv("AI21_TIMEOUT_SEC") or _DEFAULT_TIMEOUT_SEC
        )
        values["timeout_sec"] = timeout_sec

        if values.get("client") is None:
            client_kwargs: Dict[str, Any] = {
                "api_key": api_key.get_secret_value(),
                "api_host": api_host,
                # timeout_sec is always resolved to a number at this point.
                "timeout_sec": float(timeout_sec),
                "via": "langchain",
            }
            # Forward the retry setting only when the user configured it, so
            # the call is unchanged for models that leave it unset.
            num_retries = values.get("num_retries")
            if num_retries is not None:
                client_kwargs["num_retries"] = num_retries
            values["client"] = AI21Client(**client_kwargs)

        return values