mirror of
https://github.com/HazyResearch/manifest
synced 2024-11-18 09:25:48 +00:00
fix: handle none connectionstr (#110)
This commit is contained in:
parent
d94101964f
commit
a6a774bfb8
@ -42,6 +42,8 @@ class AzureClient(OpenAIClient):
|
|||||||
connection_str: connection string.
|
connection_str: connection string.
|
||||||
client_args: client arguments.
|
client_args: client arguments.
|
||||||
"""
|
"""
|
||||||
|
self.api_key, self.host = None, None
|
||||||
|
if connection_str:
|
||||||
connection_parts = connection_str.split("::")
|
connection_parts = connection_str.split("::")
|
||||||
if len(connection_parts) == 1:
|
if len(connection_parts) == 1:
|
||||||
self.api_key = connection_parts[0]
|
self.api_key = connection_parts[0]
|
||||||
@ -60,7 +62,6 @@ class AzureClient(OpenAIClient):
|
|||||||
"variable or pass through `client_connection`."
|
"variable or pass through `client_connection`."
|
||||||
)
|
)
|
||||||
self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
|
self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||||
self.host = self.host.rstrip("/")
|
|
||||||
if self.host is None:
|
if self.host is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Azure Service URL not set "
|
"Azure Service URL not set "
|
||||||
@ -68,6 +69,7 @@ class AzureClient(OpenAIClient):
|
|||||||
" Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`."
|
" Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`."
|
||||||
" as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
|
" as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
|
||||||
)
|
)
|
||||||
|
self.host = self.host.rstrip("/")
|
||||||
for key in self.PARAMS:
|
for key in self.PARAMS:
|
||||||
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
||||||
if getattr(self, "engine") not in OPENAI_ENGINES:
|
if getattr(self, "engine") not in OPENAI_ENGINES:
|
||||||
|
@ -44,6 +44,8 @@ class AzureChatClient(OpenAIChatClient):
|
|||||||
connection_str: connection string.
|
connection_str: connection string.
|
||||||
client_args: client arguments.
|
client_args: client arguments.
|
||||||
"""
|
"""
|
||||||
|
self.api_key, self.host = None, None
|
||||||
|
if connection_str:
|
||||||
connection_parts = connection_str.split("::")
|
connection_parts = connection_str.split("::")
|
||||||
if len(connection_parts) == 1:
|
if len(connection_parts) == 1:
|
||||||
self.api_key = connection_parts[0]
|
self.api_key = connection_parts[0]
|
||||||
@ -62,7 +64,6 @@ class AzureChatClient(OpenAIChatClient):
|
|||||||
"variable or pass through `client_connection`."
|
"variable or pass through `client_connection`."
|
||||||
)
|
)
|
||||||
self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
|
self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||||
self.host = self.host.rstrip("/")
|
|
||||||
if self.host is None:
|
if self.host is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Azure Service URL not set "
|
"Azure Service URL not set "
|
||||||
@ -70,6 +71,7 @@ class AzureChatClient(OpenAIChatClient):
|
|||||||
" Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`."
|
" Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`."
|
||||||
" as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
|
" as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
|
||||||
)
|
)
|
||||||
|
self.host = self.host.rstrip("/")
|
||||||
for key in self.PARAMS:
|
for key in self.PARAMS:
|
||||||
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
||||||
if getattr(self, "engine") not in OPENAICHAT_ENGINES:
|
if getattr(self, "engine") not in OPENAICHAT_ENGINES:
|
||||||
|
Loading…
Reference in New Issue
Block a user