mirror of
https://github.com/hwchase17/langchain
synced 2024-11-13 19:10:52 +00:00
Fix Azure ML serverless API request and response formatting in AzureMLOnlineEndpoint to resolve HTTP 400 errors
This commit is contained in:
parent
eef18dec44
commit
330002f5a4
@ -291,7 +291,15 @@ class CustomOpenAIContentFormatter(ContentFormatterBase):
|
||||
}
|
||||
)
|
||||
elif api_type == AzureMLEndpointApiType.serverless:
|
||||
request_payload = json.dumps({"prompt": prompt, **model_kwargs})
|
||||
request_payload = json.dumps({
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
],
|
||||
**model_kwargs
|
||||
})
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`api_type` {api_type} is not supported by this formatter"
|
||||
@ -313,17 +321,23 @@ class CustomOpenAIContentFormatter(ContentFormatterBase):
|
||||
return Generation(text=choice)
|
||||
if api_type == AzureMLEndpointApiType.serverless:
|
||||
try:
|
||||
choice = json.loads(output)["choices"][0]
|
||||
response_json = json.loads(output)
|
||||
choice = response_json["choices"][0]
|
||||
if not isinstance(choice, dict):
|
||||
raise TypeError(
|
||||
"Endpoint response is not well formed for a chat "
|
||||
"model. Expected `dict` but `{type(choice)}` was "
|
||||
"received."
|
||||
f"Endpoint response is not well formed for a chat "
|
||||
f"model. Expected `dict` but `{type(choice)}` was "
|
||||
f"received."
|
||||
)
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
|
||||
# Extracting the 'message.content' field
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content", "").strip()
|
||||
if not content:
|
||||
raise ValueError("No 'content' field found in the response message.")
|
||||
except (KeyError, IndexError, TypeError, ValueError) as e:
|
||||
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
|
||||
return Generation(
|
||||
text=choice["text"].strip(),
|
||||
text=content,
|
||||
generation_info=dict(
|
||||
finish_reason=choice.get("finish_reason"),
|
||||
logprobs=choice.get("logprobs"),
|
||||
@ -331,7 +345,6 @@ class CustomOpenAIContentFormatter(ContentFormatterBase):
|
||||
)
|
||||
raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
|
||||
|
||||
|
||||
class LlamaContentFormatter(CustomOpenAIContentFormatter):
|
||||
"""Deprecated: Kept for backwards compatibility
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user