Add openai.api_base to support openapi proxy (#2823)

I need access openai api through a proxy, so to add openai.api_base to
support this method.

Co-authored-by: bijia <bijia1@xiaomi.com>
This commit is contained in:
st01cs 2023-04-13 23:35:36 +08:00 committed by GitHub
parent 414dc803b6
commit 4f231b46ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -151,6 +151,7 @@ class BaseOpenAI(BaseLLM):
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[str] = None openai_api_key: Optional[str] = None
openai_api_base: Optional[str] = None
openai_organization: Optional[str] = None openai_organization: Optional[str] = None
batch_size: int = 20 batch_size: int = 20
"""Batch size to use when passing multiple documents to generate.""" """Batch size to use when passing multiple documents to generate."""
@ -205,6 +206,12 @@ class BaseOpenAI(BaseLLM):
openai_api_key = get_from_dict_or_env( openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY" values, "openai_api_key", "OPENAI_API_KEY"
) )
openai_api_base = get_from_dict_or_env(
values,
"openai_api_base",
"OPENAI_API_BASE",
default="",
)
openai_organization = get_from_dict_or_env( openai_organization = get_from_dict_or_env(
values, values,
"openai_organization", "openai_organization",
@ -215,6 +222,10 @@ class BaseOpenAI(BaseLLM):
import openai import openai
openai.api_key = openai_api_key openai.api_key = openai_api_key
if openai_api_base:
print("USING API_BASE: ")
print(openai_api_base)
openai.api_base = openai_api_base
if openai_organization: if openai_organization:
print("USING ORGANIZATION: ") print("USING ORGANIZATION: ")
print(openai_organization) print(openai_organization)
@ -567,6 +578,7 @@ class OpenAIChat(BaseLLM):
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[str] = None openai_api_key: Optional[str] = None
openai_api_base: Optional[str] = None
max_retries: int = 6 max_retries: int = 6
"""Maximum number of retries to make when generating.""" """Maximum number of retries to make when generating."""
prefix_messages: List = Field(default_factory=list) prefix_messages: List = Field(default_factory=list)
@ -599,6 +611,12 @@ class OpenAIChat(BaseLLM):
openai_api_key = get_from_dict_or_env( openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY" values, "openai_api_key", "OPENAI_API_KEY"
) )
openai_api_base = get_from_dict_or_env(
values,
"openai_api_base",
"OPENAI_API_BASE",
default="",
)
openai_organization = get_from_dict_or_env( openai_organization = get_from_dict_or_env(
values, "openai_organization", "OPENAI_ORGANIZATION", default="" values, "openai_organization", "OPENAI_ORGANIZATION", default=""
) )
@ -606,6 +624,8 @@ class OpenAIChat(BaseLLM):
import openai import openai
openai.api_key = openai_api_key openai.api_key = openai_api_key
if openai_api_base:
openai.api_base = openai_api_base
if openai_organization: if openai_organization:
openai.organization = openai_organization openai.organization = openai_organization
except ImportError: except ImportError: