fix: handle none connectionstr (#110)

This commit is contained in:
Laurel Orr 2023-07-25 17:16:42 -07:00 committed by GitHub
parent d94101964f
commit a6a774bfb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 24 deletions

View File

@ -42,6 +42,8 @@ class AzureClient(OpenAIClient):
connection_str: connection string.
client_args: client arguments.
"""
self.api_key, self.host = None, None
if connection_str:
connection_parts = connection_str.split("::")
if len(connection_parts) == 1:
self.api_key = connection_parts[0]
@ -60,7 +62,6 @@ class AzureClient(OpenAIClient):
"variable or pass through `client_connection`."
)
self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
self.host = self.host.rstrip("/")
if self.host is None:
raise ValueError(
"Azure Service URL not set "
@ -68,6 +69,7 @@ class AzureClient(OpenAIClient):
" Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`."
" as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.host = self.host.rstrip("/")
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in OPENAI_ENGINES:

View File

@ -44,6 +44,8 @@ class AzureChatClient(OpenAIChatClient):
connection_str: connection string.
client_args: client arguments.
"""
self.api_key, self.host = None, None
if connection_str:
connection_parts = connection_str.split("::")
if len(connection_parts) == 1:
self.api_key = connection_parts[0]
@ -62,7 +64,6 @@ class AzureChatClient(OpenAIChatClient):
"variable or pass through `client_connection`."
)
self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
self.host = self.host.rstrip("/")
if self.host is None:
raise ValueError(
"Azure Service URL not set "
@ -70,6 +71,7 @@ class AzureChatClient(OpenAIChatClient):
" Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`."
" as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.host = self.host.rstrip("/")
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in OPENAICHAT_ENGINES: