|
|
|
@ -116,11 +116,38 @@ def _convert_mistral_chat_message_to_message(
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _raise_on_error(response: httpx.Response) -> None:
|
|
|
|
|
"""Raise an error if the response is an error."""
|
|
|
|
|
if httpx.codes.is_error(response.status_code):
|
|
|
|
|
error_message = response.read().decode("utf-8")
|
|
|
|
|
raise httpx.HTTPStatusError(
|
|
|
|
|
f"Error response {response.status_code} "
|
|
|
|
|
f"while fetching {response.url}: {error_message}",
|
|
|
|
|
request=response.request,
|
|
|
|
|
response=response,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _araise_on_error(response: httpx.Response) -> None:
|
|
|
|
|
"""Raise an error if the response is an error."""
|
|
|
|
|
if httpx.codes.is_error(response.status_code):
|
|
|
|
|
error_message = (await response.aread()).decode("utf-8")
|
|
|
|
|
raise httpx.HTTPStatusError(
|
|
|
|
|
f"Error response {response.status_code} "
|
|
|
|
|
f"while fetching {response.url}: {error_message}",
|
|
|
|
|
request=response.request,
|
|
|
|
|
response=response,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _aiter_sse(
|
|
|
|
|
event_source_mgr: AsyncContextManager[EventSource],
|
|
|
|
|
) -> AsyncIterator[Dict]:
|
|
|
|
|
"""Iterate over the server-sent events."""
|
|
|
|
|
async with event_source_mgr as event_source:
|
|
|
|
|
# TODO(Team): Remove after this is fixed in httpx dependency
|
|
|
|
|
# https://github.com/florimondmanca/httpx-sse/pull/25/files
|
|
|
|
|
await _araise_on_error(event_source._response)
|
|
|
|
|
async for event in event_source.aiter_sse():
|
|
|
|
|
if event.data == "[DONE]":
|
|
|
|
|
return
|
|
|
|
@ -144,10 +171,10 @@ async def acompletion_with_retry(
|
|
|
|
|
event_source = aconnect_sse(
|
|
|
|
|
llm.async_client, "POST", "/chat/completions", json=kwargs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return _aiter_sse(event_source)
|
|
|
|
|
else:
|
|
|
|
|
response = await llm.async_client.post(url="/chat/completions", json=kwargs)
|
|
|
|
|
await _araise_on_error(response)
|
|
|
|
|
return response.json()
|
|
|
|
|
|
|
|
|
|
return await _completion_with_retry(**kwargs)
|
|
|
|
@ -298,6 +325,9 @@ class ChatMistralAI(BaseChatModel):
|
|
|
|
|
with connect_sse(
|
|
|
|
|
self.client, "POST", "/chat/completions", json=kwargs
|
|
|
|
|
) as event_source:
|
|
|
|
|
# TODO(Team): Remove after this is fixed in httpx dependency
|
|
|
|
|
# https://github.com/florimondmanca/httpx-sse/pull/25/files
|
|
|
|
|
_raise_on_error(event_source._response)
|
|
|
|
|
for event in event_source.iter_sse():
|
|
|
|
|
if event.data == "[DONE]":
|
|
|
|
|
return
|
|
|
|
@ -305,7 +335,9 @@ class ChatMistralAI(BaseChatModel):
|
|
|
|
|
|
|
|
|
|
return iter_sse()
|
|
|
|
|
else:
|
|
|
|
|
return self.client.post(url="/chat/completions", json=kwargs).json()
|
|
|
|
|
response = self.client.post(url="/chat/completions", json=kwargs)
|
|
|
|
|
_raise_on_error(response)
|
|
|
|
|
return response.json()
|
|
|
|
|
|
|
|
|
|
rtn = _completion_with_retry(**kwargs)
|
|
|
|
|
return rtn
|
|
|
|
|