WIP: openai settings (#5792)

[] need to test more
[] make sure they arent saved when serializing
[] do for embeddings
searx_updates
Harrison Chase 11 months ago committed by GitHub
parent b7999a9bc1
commit 3954bcf396
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -53,33 +53,33 @@ class AzureChatOpenAI(ChatOpenAI):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env(
values["openai_api_key"] = get_from_dict_or_env(
values,
"openai_api_key",
"OPENAI_API_KEY",
)
openai_api_base = get_from_dict_or_env(
values["openai_api_base"] = get_from_dict_or_env(
values,
"openai_api_base",
"OPENAI_API_BASE",
)
openai_api_version = get_from_dict_or_env(
values["openai_api_version"] = get_from_dict_or_env(
values,
"openai_api_version",
"OPENAI_API_VERSION",
)
openai_api_type = get_from_dict_or_env(
values["openai_api_type"] = get_from_dict_or_env(
values,
"openai_api_type",
"OPENAI_API_TYPE",
)
openai_organization = get_from_dict_or_env(
values["openai_organization"] = get_from_dict_or_env(
values,
"openai_organization",
"OPENAI_ORGANIZATION",
default="",
)
openai_proxy = get_from_dict_or_env(
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
@ -88,14 +88,6 @@ class AzureChatOpenAI(ChatOpenAI):
try:
import openai
openai.api_type = openai_api_type
openai.api_base = openai_api_base
openai.api_version = openai_api_version
openai.api_key = openai_api_key
if openai_organization:
openai.organization = openai_organization
if openai_proxy:
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
except ImportError:
raise ImportError(
"Could not import openai python package. "
@ -128,6 +120,14 @@ class AzureChatOpenAI(ChatOpenAI):
"""Get the identifying parameters."""
return {**self._default_params}
@property
def _invocation_params(self) -> Mapping[str, Any]:
openai_creds = {
"api_type": self.openai_api_type,
"api_version": self.openai_api_version,
}
return {**openai_creds, **super()._invocation_params}
@property
def _llm_type(self) -> str:
return "azure-openai-chat"

@ -196,22 +196,22 @@ class ChatOpenAI(BaseChatModel):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env(
values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
openai_organization = get_from_dict_or_env(
values["openai_organization"] = get_from_dict_or_env(
values,
"openai_organization",
"OPENAI_ORGANIZATION",
default="",
)
openai_api_base = get_from_dict_or_env(
values["openai_api_base"] = get_from_dict_or_env(
values,
"openai_api_base",
"OPENAI_API_BASE",
default="",
)
openai_proxy = get_from_dict_or_env(
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
@ -225,13 +225,6 @@ class ChatOpenAI(BaseChatModel):
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
openai.api_key = openai_api_key
if openai_organization:
openai.organization = openai_organization
if openai_api_base:
openai.api_base = openai_api_base
if openai_proxy:
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
try:
values["client"] = openai.ChatCompletion
except AttributeError:
@ -333,7 +326,7 @@ class ChatOpenAI(BaseChatModel):
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
params = dict(self._invocation_params)
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
@ -384,6 +377,21 @@ class ChatOpenAI(BaseChatModel):
"""Get the identifying parameters."""
return {**{"model_name": self.model_name}, **self._default_params}
@property
def _invocation_params(self) -> Mapping[str, Any]:
"""Get the parameters used to invoke the model."""
openai_creds: Dict[str, Any] = {
"api_key": self.openai_api_key,
"api_base": self.openai_api_base,
"organization": self.openai_organization,
"model": self.model_name,
}
if self.openai_proxy:
openai_creds["proxy"] = (
{"http": self.openai_proxy, "https": self.openai_proxy},
)
return {**openai_creds, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""

@ -136,38 +136,38 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env(
values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
openai_api_base = get_from_dict_or_env(
values["openai_api_base"] = get_from_dict_or_env(
values,
"openai_api_base",
"OPENAI_API_BASE",
default="",
)
openai_api_type = get_from_dict_or_env(
values["openai_api_type"] = get_from_dict_or_env(
values,
"openai_api_type",
"OPENAI_API_TYPE",
default="",
)
openai_proxy = get_from_dict_or_env(
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
default="",
)
if openai_api_type in ("azure", "azure_ad", "azuread"):
if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
default_api_version = "2022-12-01"
else:
default_api_version = ""
openai_api_version = get_from_dict_or_env(
values["openai_api_version"] = get_from_dict_or_env(
values,
"openai_api_version",
"OPENAI_API_VERSION",
default=default_api_version,
)
openai_organization = get_from_dict_or_env(
values["openai_organization"] = get_from_dict_or_env(
values,
"openai_organization",
"OPENAI_ORGANIZATION",
@ -176,17 +176,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
try:
import openai
openai.api_key = openai_api_key
if openai_organization:
openai.organization = openai_organization
if openai_api_base:
openai.api_base = openai_api_base
if openai_api_type:
openai.api_version = openai_api_version
if openai_api_type:
openai.api_type = openai_api_type
if openai_proxy:
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
values["client"] = openai.Embedding
except ImportError:
raise ImportError(
@ -195,6 +184,25 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
)
return values
@property
def _invocation_params(self) -> Dict:
openai_args = {
"engine": self.deployment,
"request_timeout": self.request_timeout,
"headers": self.headers,
"api_key": self.openai_api_key,
"organization": self.openai_organization,
"api_base": self.openai_api_base,
"api_type": self.openai_api_type,
"api_version": self.openai_api_version,
}
if self.openai_proxy:
openai_args["proxy"] = {
"http": self.openai_proxy,
"https": self.openai_proxy,
}
return openai_args
# please refer to
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
def _get_len_safe_embeddings(
@ -233,9 +241,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
response = embed_with_retry(
self,
input=tokens[i : i + _chunk_size],
engine=self.deployment,
request_timeout=self.request_timeout,
headers=self.headers,
**self._invocation_params,
)
batched_embeddings += [r["embedding"] for r in response["data"]]

@ -211,22 +211,22 @@ class BaseOpenAI(BaseLLM):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env(
values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
openai_api_base = get_from_dict_or_env(
values["openai_api_base"] = get_from_dict_or_env(
values,
"openai_api_base",
"OPENAI_API_BASE",
default="",
)
openai_proxy = get_from_dict_or_env(
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
default="",
)
openai_organization = get_from_dict_or_env(
values["openai_organization"] = get_from_dict_or_env(
values,
"openai_organization",
"OPENAI_ORGANIZATION",
@ -235,13 +235,6 @@ class BaseOpenAI(BaseLLM):
try:
import openai
openai.api_key = openai_api_key
if openai_api_base:
openai.api_base = openai_api_base
if openai_organization:
openai.organization = openai_organization
if openai_proxy:
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
values["client"] = openai.Completion
except ImportError:
raise ImportError(
@ -452,7 +445,17 @@ class BaseOpenAI(BaseLLM):
@property
def _invocation_params(self) -> Dict[str, Any]:
"""Get the parameters used to invoke the model."""
return self._default_params
openai_creds: Dict[str, Any] = {
"api_key": self.openai_api_key,
"api_base": self.openai_api_base,
"organization": self.openai_organization,
}
if self.openai_proxy:
openai_creds["proxy"] = {
"http": self.openai_proxy,
"https": self.openai_proxy,
}
return {**openai_creds, **self._default_params}
@property
def _identifying_params(self) -> Mapping[str, Any]:
@ -596,6 +599,22 @@ class AzureOpenAI(BaseOpenAI):
deployment_name: str = ""
"""Deployment name to use."""
openai_api_type: str = "azure"
openai_api_version: str = ""
@root_validator()
def validate_azure_settings(cls, values: Dict) -> Dict:
values["openai_api_version"] = get_from_dict_or_env(
values,
"openai_api_version",
"OPENAI_API_VERSION",
)
values["openai_api_type"] = get_from_dict_or_env(
values,
"openai_api_type",
"OPENAI_API_TYPE",
)
return values
@property
def _identifying_params(self) -> Mapping[str, Any]:
@ -606,7 +625,12 @@ class AzureOpenAI(BaseOpenAI):
@property
def _invocation_params(self) -> Dict[str, Any]:
return {**{"engine": self.deployment_name}, **super()._invocation_params}
openai_params = {
"engine": self.deployment_name,
"api_type": self.openai_api_type,
"api_version": self.openai_api_version,
}
return {**openai_params, **super()._invocation_params}
@property
def _llm_type(self) -> str:

Loading…
Cancel
Save