From a6a774bfb8d89ee448233df7f0d82bbbaf9b04e4 Mon Sep 17 00:00:00 2001 From: Laurel Orr <57237365+lorr1@users.noreply.github.com> Date: Tue, 25 Jul 2023 17:16:42 -0700 Subject: [PATCH] fix: handle none connectionstr (#110) --- manifest/clients/azureopenai.py | 26 ++++++++++++++------------ manifest/clients/azureopenai_chat.py | 26 ++++++++++++++------------ 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/manifest/clients/azureopenai.py b/manifest/clients/azureopenai.py index f1efdc1..2bfb984 100644 --- a/manifest/clients/azureopenai.py +++ b/manifest/clients/azureopenai.py @@ -42,17 +42,19 @@ class AzureClient(OpenAIClient): connection_str: connection string. client_args: client arguments. """ - connection_parts = connection_str.split("::") - if len(connection_parts) == 1: - self.api_key = connection_parts[0] - elif len(connection_parts) == 2: - self.api_key, self.host = connection_parts - else: - raise ValueError( - "Invalid connection string. " - "Must be either AZURE_OPENAI_KEY or " - "AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT" - ) + self.api_key, self.host = None, None + if connection_str: + connection_parts = connection_str.split("::") + if len(connection_parts) == 1: + self.api_key = connection_parts[0] + elif len(connection_parts) == 2: + self.api_key, self.host = connection_parts + else: + raise ValueError( + "Invalid connection string. " + "Must be either AZURE_OPENAI_KEY or " + "AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT" + ) self.api_key = self.api_key or os.environ.get("AZURE_OPENAI_KEY") if self.api_key is None: raise ValueError( @@ -60,7 +62,6 @@ class AzureClient(OpenAIClient): "variable or pass through `client_connection`." ) self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT") - self.host = self.host.rstrip("/") if self.host is None: raise ValueError( "Azure Service URL not set " @@ -68,6 +69,7 @@ class AzureClient(OpenAIClient): " Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`." " as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT" ) + self.host = self.host.rstrip("/") for key in self.PARAMS: setattr(self, key, client_args.pop(key, self.PARAMS[key][1])) if getattr(self, "engine") not in OPENAI_ENGINES: diff --git a/manifest/clients/azureopenai_chat.py b/manifest/clients/azureopenai_chat.py index dba8324..19d8d76 100644 --- a/manifest/clients/azureopenai_chat.py +++ b/manifest/clients/azureopenai_chat.py @@ -44,17 +44,19 @@ class AzureChatClient(OpenAIChatClient): connection_str: connection string. client_args: client arguments. """ - connection_parts = connection_str.split("::") - if len(connection_parts) == 1: - self.api_key = connection_parts[0] - elif len(connection_parts) == 2: - self.api_key, self.host = connection_parts - else: - raise ValueError( - "Invalid connection string. " - "Must be either AZURE_OPENAI_KEY or " - "AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT" - ) + self.api_key, self.host = None, None + if connection_str: + connection_parts = connection_str.split("::") + if len(connection_parts) == 1: + self.api_key = connection_parts[0] + elif len(connection_parts) == 2: + self.api_key, self.host = connection_parts + else: + raise ValueError( + "Invalid connection string. " + "Must be either AZURE_OPENAI_KEY or " + "AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT" + ) self.api_key = self.api_key or os.environ.get("AZURE_OPENAI_KEY") if self.api_key is None: raise ValueError( @@ -62,7 +64,6 @@ class AzureChatClient(OpenAIChatClient): "variable or pass through `client_connection`." ) self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT") - self.host = self.host.rstrip("/") if self.host is None: raise ValueError( "Azure Service URL not set " @@ -70,6 +71,7 @@ class AzureChatClient(OpenAIChatClient): " Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`." " as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT" ) + self.host = self.host.rstrip("/") for key in self.PARAMS: setattr(self, key, client_args.pop(key, self.PARAMS[key][1])) if getattr(self, "engine") not in OPENAICHAT_ENGINES: