|
|
@ -211,22 +211,22 @@ class BaseOpenAI(BaseLLM):
|
|
|
|
@root_validator()
|
|
|
|
@root_validator()
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
"""Validate that api key and python package exists in environment."""
|
|
|
|
"""Validate that api key and python package exists in environment."""
|
|
|
|
openai_api_key = get_from_dict_or_env(
|
|
|
|
values["openai_api_key"] = get_from_dict_or_env(
|
|
|
|
values, "openai_api_key", "OPENAI_API_KEY"
|
|
|
|
values, "openai_api_key", "OPENAI_API_KEY"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
openai_api_base = get_from_dict_or_env(
|
|
|
|
values["openai_api_base"] = get_from_dict_or_env(
|
|
|
|
values,
|
|
|
|
values,
|
|
|
|
"openai_api_base",
|
|
|
|
"openai_api_base",
|
|
|
|
"OPENAI_API_BASE",
|
|
|
|
"OPENAI_API_BASE",
|
|
|
|
default="",
|
|
|
|
default="",
|
|
|
|
)
|
|
|
|
)
|
|
|
|
openai_proxy = get_from_dict_or_env(
|
|
|
|
values["openai_proxy"] = get_from_dict_or_env(
|
|
|
|
values,
|
|
|
|
values,
|
|
|
|
"openai_proxy",
|
|
|
|
"openai_proxy",
|
|
|
|
"OPENAI_PROXY",
|
|
|
|
"OPENAI_PROXY",
|
|
|
|
default="",
|
|
|
|
default="",
|
|
|
|
)
|
|
|
|
)
|
|
|
|
openai_organization = get_from_dict_or_env(
|
|
|
|
values["openai_organization"] = get_from_dict_or_env(
|
|
|
|
values,
|
|
|
|
values,
|
|
|
|
"openai_organization",
|
|
|
|
"openai_organization",
|
|
|
|
"OPENAI_ORGANIZATION",
|
|
|
|
"OPENAI_ORGANIZATION",
|
|
|
@ -235,13 +235,6 @@ class BaseOpenAI(BaseLLM):
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
import openai
|
|
|
|
import openai
|
|
|
|
|
|
|
|
|
|
|
|
openai.api_key = openai_api_key
|
|
|
|
|
|
|
|
if openai_api_base:
|
|
|
|
|
|
|
|
openai.api_base = openai_api_base
|
|
|
|
|
|
|
|
if openai_organization:
|
|
|
|
|
|
|
|
openai.organization = openai_organization
|
|
|
|
|
|
|
|
if openai_proxy:
|
|
|
|
|
|
|
|
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
|
|
|
|
|
|
|
|
values["client"] = openai.Completion
|
|
|
|
values["client"] = openai.Completion
|
|
|
|
except ImportError:
|
|
|
|
except ImportError:
|
|
|
|
raise ImportError(
|
|
|
|
raise ImportError(
|
|
|
@ -452,7 +445,17 @@ class BaseOpenAI(BaseLLM):
|
|
|
|
@property
|
|
|
|
@property
|
|
|
|
def _invocation_params(self) -> Dict[str, Any]:
|
|
|
|
def _invocation_params(self) -> Dict[str, Any]:
|
|
|
|
"""Get the parameters used to invoke the model."""
|
|
|
|
"""Get the parameters used to invoke the model."""
|
|
|
|
return self._default_params
|
|
|
|
openai_creds: Dict[str, Any] = {
|
|
|
|
|
|
|
|
"api_key": self.openai_api_key,
|
|
|
|
|
|
|
|
"api_base": self.openai_api_base,
|
|
|
|
|
|
|
|
"organization": self.openai_organization,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if self.openai_proxy:
|
|
|
|
|
|
|
|
openai_creds["proxy"] = {
|
|
|
|
|
|
|
|
"http": self.openai_proxy,
|
|
|
|
|
|
|
|
"https": self.openai_proxy,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return {**openai_creds, **self._default_params}
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
@property
|
|
|
|
def _identifying_params(self) -> Mapping[str, Any]:
|
|
|
|
def _identifying_params(self) -> Mapping[str, Any]:
|
|
|
@ -596,6 +599,22 @@ class AzureOpenAI(BaseOpenAI):
|
|
|
|
|
|
|
|
|
|
|
|
deployment_name: str = ""
|
|
|
|
deployment_name: str = ""
|
|
|
|
"""Deployment name to use."""
|
|
|
|
"""Deployment name to use."""
|
|
|
|
|
|
|
|
openai_api_type: str = "azure"
|
|
|
|
|
|
|
|
openai_api_version: str = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@root_validator()
|
|
|
|
|
|
|
|
def validate_azure_settings(cls, values: Dict) -> Dict:
|
|
|
|
|
|
|
|
values["openai_api_version"] = get_from_dict_or_env(
|
|
|
|
|
|
|
|
values,
|
|
|
|
|
|
|
|
"openai_api_version",
|
|
|
|
|
|
|
|
"OPENAI_API_VERSION",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
values["openai_api_type"] = get_from_dict_or_env(
|
|
|
|
|
|
|
|
values,
|
|
|
|
|
|
|
|
"openai_api_type",
|
|
|
|
|
|
|
|
"OPENAI_API_TYPE",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
@property
|
|
|
|
def _identifying_params(self) -> Mapping[str, Any]:
|
|
|
|
def _identifying_params(self) -> Mapping[str, Any]:
|
|
|
@ -606,7 +625,12 @@ class AzureOpenAI(BaseOpenAI):
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
@property
|
|
|
|
def _invocation_params(self) -> Dict[str, Any]:
|
|
|
|
def _invocation_params(self) -> Dict[str, Any]:
|
|
|
|
return {**{"engine": self.deployment_name}, **super()._invocation_params}
|
|
|
|
openai_params = {
|
|
|
|
|
|
|
|
"engine": self.deployment_name,
|
|
|
|
|
|
|
|
"api_type": self.openai_api_type,
|
|
|
|
|
|
|
|
"api_version": self.openai_api_version,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return {**openai_params, **super()._invocation_params}
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
@property
|
|
|
|
def _llm_type(self) -> str:
|
|
|
|
def _llm_type(self) -> str:
|
|
|
|