openai[patch]: fix async http client (#19164)

Fix #19116
This commit is contained in:
Bagatur 2024-03-16 17:50:22 -07:00 committed by GitHub
parent 635b3372bd
commit 611d5a1618
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 70 additions and 23 deletions

View File

@ -185,12 +185,17 @@ class AzureChatOpenAI(ChatOpenAI):
"max_retries": values["max_retries"], "max_retries": values["max_retries"],
"default_headers": values["default_headers"], "default_headers": values["default_headers"],
"default_query": values["default_query"], "default_query": values["default_query"],
"http_client": values["http_client"],
} }
values["client"] = openai.AzureOpenAI(**client_params).chat.completions if not values.get("client"):
values["async_client"] = openai.AsyncAzureOpenAI( sync_specific = {"http_client": values["http_client"]}
**client_params values["client"] = openai.AzureOpenAI(
).chat.completions **client_params, **sync_specific
).chat.completions
if not values.get("async_client"):
async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = openai.AsyncAzureOpenAI(
**client_params, **async_specific
).chat.completions
return values return values
@property @property

View File

@ -313,7 +313,12 @@ class ChatOpenAI(BaseChatModel):
# Configure a custom httpx client. See the # Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details. # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: Union[Any, None] = None http_client: Union[Any, None] = None
"""Optional httpx.Client.""" """Optional httpx.Client. Only used for sync invocations. Must specify
http_async_client as well if you'd like a custom client for async invocations.
"""
http_async_client: Union[Any, None] = None
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -369,14 +374,17 @@ class ChatOpenAI(BaseChatModel):
"max_retries": values["max_retries"], "max_retries": values["max_retries"],
"default_headers": values["default_headers"], "default_headers": values["default_headers"],
"default_query": values["default_query"], "default_query": values["default_query"],
"http_client": values["http_client"],
} }
if not values.get("client"): if not values.get("client"):
values["client"] = openai.OpenAI(**client_params).chat.completions sync_specific = {"http_client": values["http_client"]}
values["client"] = openai.OpenAI(
**client_params, **sync_specific
).chat.completions
if not values.get("async_client"): if not values.get("async_client"):
async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = openai.AsyncOpenAI( values["async_client"] = openai.AsyncOpenAI(
**client_params **client_params, **async_specific
).chat.completions ).chat.completions
return values return values

View File

@ -139,10 +139,17 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
"max_retries": values["max_retries"], "max_retries": values["max_retries"],
"default_headers": values["default_headers"], "default_headers": values["default_headers"],
"default_query": values["default_query"], "default_query": values["default_query"],
"http_client": values["http_client"],
} }
values["client"] = openai.AzureOpenAI(**client_params).embeddings if not values.get("client"):
values["async_client"] = openai.AsyncAzureOpenAI(**client_params).embeddings sync_specific = {"http_client": values["http_client"]}
values["client"] = openai.AzureOpenAI(
**client_params, **sync_specific
).embeddings
if not values.get("async_client"):
async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = openai.AsyncAzureOpenAI(
**client_params, **async_specific
).embeddings
return values return values
@property @property

View File

@ -123,7 +123,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
retry_max_seconds: int = 20 retry_max_seconds: int = 20
"""Max number of seconds to wait between retries""" """Max number of seconds to wait between retries"""
http_client: Union[Any, None] = None http_client: Union[Any, None] = None
"""Optional httpx.Client.""" """Optional httpx.Client. Only used for sync invocations. Must specify
http_async_client as well if you'd like a custom client for async invocations.
"""
http_async_client: Union[Any, None] = None
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -218,12 +223,17 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"max_retries": values["max_retries"], "max_retries": values["max_retries"],
"default_headers": values["default_headers"], "default_headers": values["default_headers"],
"default_query": values["default_query"], "default_query": values["default_query"],
"http_client": values["http_client"],
} }
if not values.get("client"): if not values.get("client"):
values["client"] = openai.OpenAI(**client_params).embeddings sync_specific = {"http_client": values["http_client"]}
values["client"] = openai.OpenAI(
**client_params, **sync_specific
).embeddings
if not values.get("async_client"): if not values.get("async_client"):
values["async_client"] = openai.AsyncOpenAI(**client_params).embeddings async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = openai.AsyncOpenAI(
**client_params, **async_specific
).embeddings
return values return values
@property @property

View File

@ -160,10 +160,17 @@ class AzureOpenAI(BaseOpenAI):
"max_retries": values["max_retries"], "max_retries": values["max_retries"],
"default_headers": values["default_headers"], "default_headers": values["default_headers"],
"default_query": values["default_query"], "default_query": values["default_query"],
"http_client": values["http_client"],
} }
values["client"] = openai.AzureOpenAI(**client_params).completions if not values.get("client"):
values["async_client"] = openai.AsyncAzureOpenAI(**client_params).completions sync_specific = {"http_client": values["http_client"]}
values["client"] = openai.AzureOpenAI(
**client_params, **sync_specific
).completions
if not values.get("async_client"):
async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = openai.AsyncAzureOpenAI(
**client_params, **async_specific
).completions
return values return values

View File

@ -149,7 +149,12 @@ class BaseOpenAI(BaseLLM):
# Configure a custom httpx client. See the # Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details. # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: Union[Any, None] = None http_client: Union[Any, None] = None
"""Optional httpx.Client.""" """Optional httpx.Client. Only used for sync invocations. Must specify
http_async_client as well if you'd like a custom client for async invocations.
"""
http_async_client: Union[Any, None] = None
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -209,12 +214,17 @@ class BaseOpenAI(BaseLLM):
"max_retries": values["max_retries"], "max_retries": values["max_retries"],
"default_headers": values["default_headers"], "default_headers": values["default_headers"],
"default_query": values["default_query"], "default_query": values["default_query"],
"http_client": values["http_client"],
} }
if not values.get("client"): if not values.get("client"):
values["client"] = openai.OpenAI(**client_params).completions sync_specific = {"http_client": values["http_client"]}
values["client"] = openai.OpenAI(
**client_params, **sync_specific
).completions
if not values.get("async_client"): if not values.get("async_client"):
values["async_client"] = openai.AsyncOpenAI(**client_params).completions async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = openai.AsyncOpenAI(
**client_params, **async_specific
).completions
return values return values