From 330002f5a449e435ca90f155fe87001d47ae2555 Mon Sep 17 00:00:00 2001 From: KoKo Mexcelsa Date: Fri, 20 Sep 2024 04:33:15 +0800 Subject: [PATCH] Fix Azure ML serverless API request and response formatting in AzureMLOnlineEndpoint to resolve HTTP 400 errors --- .../llms/azureml_endpoint.py | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/libs/community/langchain_community/llms/azureml_endpoint.py b/libs/community/langchain_community/llms/azureml_endpoint.py index f2b5eed2e1..0230368a79 100644 --- a/libs/community/langchain_community/llms/azureml_endpoint.py +++ b/libs/community/langchain_community/llms/azureml_endpoint.py @@ -291,7 +291,15 @@ class CustomOpenAIContentFormatter(ContentFormatterBase): } ) elif api_type == AzureMLEndpointApiType.serverless: - request_payload = json.dumps({"prompt": prompt, **model_kwargs}) + request_payload = json.dumps({ + "messages": [ + { + "role": "user", + "content": prompt + } + ], + **model_kwargs + }) else: raise ValueError( f"`api_type` {api_type} is not supported by this formatter" @@ -313,17 +321,23 @@ class CustomOpenAIContentFormatter(ContentFormatterBase): return Generation(text=choice) if api_type == AzureMLEndpointApiType.serverless: try: - choice = json.loads(output)["choices"][0] + response_json = json.loads(output) + choice = response_json["choices"][0] if not isinstance(choice, dict): raise TypeError( - "Endpoint response is not well formed for a chat " - "model. Expected `dict` but `{type(choice)}` was " - "received." + f"Endpoint response is not well formed for a chat " + f"model. Expected `dict` but `{type(choice)}` was " + f"received." ) - except (KeyError, IndexError, TypeError) as e: - raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr] + # Extracting the 'message.content' field + message = choice.get("message", {}) + content = message.get("content", "").strip() + if not content: + raise ValueError("No 'content' field found in the response message.") + except (KeyError, IndexError, TypeError, ValueError) as e: + raise ValueError(self.format_error_msg.format(api_type=api_type)) from e return Generation( - text=choice["text"].strip(), + text=content, generation_info=dict( finish_reason=choice.get("finish_reason"), logprobs=choice.get("logprobs"), @@ -331,7 +345,6 @@ class CustomOpenAIContentFormatter(ContentFormatterBase): ) raise ValueError(f"`api_type` {api_type} is not supported by this formatter") - class LlamaContentFormatter(CustomOpenAIContentFormatter): """Deprecated: Kept for backwards compatibility