|
|
|
@ -136,38 +136,38 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
@root_validator()
|
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
|
"""Validate that api key and python package exists in environment."""
|
|
|
|
|
openai_api_key = get_from_dict_or_env(
|
|
|
|
|
values["openai_api_key"] = get_from_dict_or_env(
|
|
|
|
|
values, "openai_api_key", "OPENAI_API_KEY"
|
|
|
|
|
)
|
|
|
|
|
openai_api_base = get_from_dict_or_env(
|
|
|
|
|
values["openai_api_base"] = get_from_dict_or_env(
|
|
|
|
|
values,
|
|
|
|
|
"openai_api_base",
|
|
|
|
|
"OPENAI_API_BASE",
|
|
|
|
|
default="",
|
|
|
|
|
)
|
|
|
|
|
openai_api_type = get_from_dict_or_env(
|
|
|
|
|
values["openai_api_type"] = get_from_dict_or_env(
|
|
|
|
|
values,
|
|
|
|
|
"openai_api_type",
|
|
|
|
|
"OPENAI_API_TYPE",
|
|
|
|
|
default="",
|
|
|
|
|
)
|
|
|
|
|
openai_proxy = get_from_dict_or_env(
|
|
|
|
|
values["openai_proxy"] = get_from_dict_or_env(
|
|
|
|
|
values,
|
|
|
|
|
"openai_proxy",
|
|
|
|
|
"OPENAI_PROXY",
|
|
|
|
|
default="",
|
|
|
|
|
)
|
|
|
|
|
if openai_api_type in ("azure", "azure_ad", "azuread"):
|
|
|
|
|
if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
|
|
|
|
|
default_api_version = "2022-12-01"
|
|
|
|
|
else:
|
|
|
|
|
default_api_version = ""
|
|
|
|
|
openai_api_version = get_from_dict_or_env(
|
|
|
|
|
values["openai_api_version"] = get_from_dict_or_env(
|
|
|
|
|
values,
|
|
|
|
|
"openai_api_version",
|
|
|
|
|
"OPENAI_API_VERSION",
|
|
|
|
|
default=default_api_version,
|
|
|
|
|
)
|
|
|
|
|
openai_organization = get_from_dict_or_env(
|
|
|
|
|
values["openai_organization"] = get_from_dict_or_env(
|
|
|
|
|
values,
|
|
|
|
|
"openai_organization",
|
|
|
|
|
"OPENAI_ORGANIZATION",
|
|
|
|
@ -176,17 +176,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
try:
|
|
|
|
|
import openai
|
|
|
|
|
|
|
|
|
|
openai.api_key = openai_api_key
|
|
|
|
|
if openai_organization:
|
|
|
|
|
openai.organization = openai_organization
|
|
|
|
|
if openai_api_base:
|
|
|
|
|
openai.api_base = openai_api_base
|
|
|
|
|
if openai_api_type:
|
|
|
|
|
openai.api_version = openai_api_version
|
|
|
|
|
if openai_api_type:
|
|
|
|
|
openai.api_type = openai_api_type
|
|
|
|
|
if openai_proxy:
|
|
|
|
|
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
|
|
|
|
|
values["client"] = openai.Embedding
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise ImportError(
|
|
|
|
@ -195,6 +184,25 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
)
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _invocation_params(self) -> Dict:
|
|
|
|
|
openai_args = {
|
|
|
|
|
"engine": self.deployment,
|
|
|
|
|
"request_timeout": self.request_timeout,
|
|
|
|
|
"headers": self.headers,
|
|
|
|
|
"api_key": self.openai_api_key,
|
|
|
|
|
"organization": self.openai_organization,
|
|
|
|
|
"api_base": self.openai_api_base,
|
|
|
|
|
"api_type": self.openai_api_type,
|
|
|
|
|
"api_version": self.openai_api_version,
|
|
|
|
|
}
|
|
|
|
|
if self.openai_proxy:
|
|
|
|
|
openai_args["proxy"] = {
|
|
|
|
|
"http": self.openai_proxy,
|
|
|
|
|
"https": self.openai_proxy,
|
|
|
|
|
}
|
|
|
|
|
return openai_args
|
|
|
|
|
|
|
|
|
|
# please refer to
|
|
|
|
|
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
|
|
|
|
|
def _get_len_safe_embeddings(
|
|
|
|
@ -233,9 +241,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
response = embed_with_retry(
|
|
|
|
|
self,
|
|
|
|
|
input=tokens[i : i + _chunk_size],
|
|
|
|
|
engine=self.deployment,
|
|
|
|
|
request_timeout=self.request_timeout,
|
|
|
|
|
headers=self.headers,
|
|
|
|
|
**self._invocation_params,
|
|
|
|
|
)
|
|
|
|
|
batched_embeddings += [r["embedding"] for r in response["data"]]
|
|
|
|
|
|
|
|
|
|