From 044bc22acc8af78eddf97dd6f5fbee4fe7a2e358 Mon Sep 17 00:00:00 2001
From: Taqi Jaffri
Date: Mon, 18 Mar 2024 21:10:42 -0700
Subject: [PATCH] Community: Add mistral oss model support to azureml
 endpoints, plus configurable timeout (#19123)

- **Description:** There was no formatter for mistral models for Azure ML
  endpoints. Adding that, plus a configurable timeout (it was hard coded
  before)
- **Dependencies:** none
- **Twitter handle:** @tjaffri @docugami
---
 .../chat_models/azureml_endpoint.py           | 38 +++++++++++++++++++
 .../llms/azureml_endpoint.py                  | 26 ++++++++++++-
 2 files changed, 62 insertions(+), 2 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/azureml_endpoint.py b/libs/community/langchain_community/chat_models/azureml_endpoint.py
index dc1385411f..7041759859 100644
--- a/libs/community/langchain_community/chat_models/azureml_endpoint.py
+++ b/libs/community/langchain_community/chat_models/azureml_endpoint.py
@@ -143,6 +143,44 @@ class LlamaChatContentFormatter(ContentFormatterBase):
         raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
 
 
+class MistralChatContentFormatter(LlamaChatContentFormatter):
+    """Content formatter for `Mistral`."""
+
+    def format_messages_request_payload(
+        self,
+        messages: List[BaseMessage],
+        model_kwargs: Dict,
+        api_type: AzureMLEndpointApiType,
+    ) -> bytes:
+        """Formats the request according to the chosen api"""
+        chat_messages = [self._convert_message_to_dict(message) for message in messages]
+
+        if chat_messages and chat_messages[0]["role"] == "system":
+            # Mistral OSS models do not explicitly support system prompts, so we have to
+            # stash in the first user prompt
+            chat_messages[1]["content"] = (
+                chat_messages[0]["content"] + "\n\n" + chat_messages[1]["content"]
+            )
+            del chat_messages[0]
+
+        if api_type == AzureMLEndpointApiType.realtime:
+            request_payload = json.dumps(
+                {
+                    "input_data": {
+                        "input_string": chat_messages,
+                        "parameters": model_kwargs,
+                    }
+                }
+            )
+        elif api_type == AzureMLEndpointApiType.serverless:
+            request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
+        else:
+            raise ValueError(
+                f"`api_type` {api_type} is not supported by this formatter"
+            )
+        return str.encode(request_payload)
+
+
 class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
     """Azure ML Online Endpoint chat models.
 
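A minimal usage sketch of the new formatter, for reviewers. The endpoint URL and key are placeholders, and it assumes a Mistral OSS model deployed behind a serverless (pay-as-you-go) endpoint:

```python
from langchain_community.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    AzureMLEndpointApiType,
    MistralChatContentFormatter,
)
from langchain_core.messages import HumanMessage, SystemMessage

# Placeholder endpoint URL and key -- substitute your own deployment's values.
chat = AzureMLChatOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your-region>.inference.ml.azure.com/v1/chat/completions",
    endpoint_api_type=AzureMLEndpointApiType.serverless,
    endpoint_api_key="my-api-key",
    content_formatter=MistralChatContentFormatter(),
)

# Because Mistral OSS models have no system role, the formatter folds this
# SystemMessage into the first HumanMessage before building the payload.
response = chat.invoke(
    [
        SystemMessage(content="You answer in exactly one sentence."),
        HumanMessage(content="What is LangChain?"),
    ]
)
print(response.content)
```

Note that the system-prompt folding assumes at least one more message follows the leading `SystemMessage`.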
diff --git a/libs/community/langchain_community/llms/azureml_endpoint.py b/libs/community/langchain_community/llms/azureml_endpoint.py
index d30612ce8b..da8ef5a053 100644
--- a/libs/community/langchain_community/llms/azureml_endpoint.py
+++ b/libs/community/langchain_community/llms/azureml_endpoint.py
@@ -11,12 +11,18 @@ from langchain_core.outputs import Generation, LLMResult
 from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator, validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 
+DEFAULT_TIMEOUT = 50
+
 
 class AzureMLEndpointClient(object):
     """AzureML Managed Endpoint client."""
 
     def __init__(
-        self, endpoint_url: str, endpoint_api_key: str, deployment_name: str = ""
+        self,
+        endpoint_url: str,
+        endpoint_api_key: str,
+        deployment_name: str = "",
+        timeout: int = DEFAULT_TIMEOUT,
     ) -> None:
         """Initialize the class."""
         if not endpoint_api_key or not endpoint_url:
@@ -27,6 +33,7 @@ class AzureMLEndpointClient(object):
         self.endpoint_url = endpoint_url
         self.endpoint_api_key = endpoint_api_key
         self.deployment_name = deployment_name
+        self.timeout = timeout
 
     def call(
         self,
@@ -47,7 +54,9 @@ class AzureMLEndpointClient(object):
             headers["azureml-model-deployment"] = self.deployment_name
 
         req = urllib.request.Request(self.endpoint_url, body, headers)
-        response = urllib.request.urlopen(req, timeout=kwargs.get("timeout", 50))
+        response = urllib.request.urlopen(
+            req, timeout=kwargs.get("timeout", self.timeout)
+        )
         result = response.read()
         return result
 
@@ -334,6 +343,9 @@ class AzureMLBaseEndpoint(BaseModel):
     """Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
     to constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`."""
 
+    timeout: int = DEFAULT_TIMEOUT
+    """Request timeout for calls to the endpoint"""
+
     http_client: Any = None  #: :meta private:
 
     content_formatter: Any = None
@@ -361,6 +373,12 @@ class AzureMLBaseEndpoint(BaseModel):
             "AZUREML_ENDPOINT_API_TYPE",
             AzureMLEndpointApiType.realtime,
         )
+        values["timeout"] = get_from_dict_or_env(
+            values,
+            "timeout",
+            "AZUREML_TIMEOUT",
+            str(DEFAULT_TIMEOUT),
+        )
 
         return values
 
@@ -424,12 +442,15 @@ class AzureMLBaseEndpoint(BaseModel):
         endpoint_url = values.get("endpoint_url")
         endpoint_key = values.get("endpoint_api_key")
         deployment_name = values.get("deployment_name")
+        timeout = values.get("timeout", DEFAULT_TIMEOUT)
 
         http_client = AzureMLEndpointClient(
             endpoint_url,  # type: ignore
             endpoint_key.get_secret_value(),  # type: ignore
             deployment_name,  # type: ignore
+            timeout,  # type: ignore
        )
+
         return http_client
 
 
@@ -442,6 +463,7 @@ class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
                endpoint_url="https://<your-endpoint>.<your-region>.inference.ml.azure.com/score",
                endpoint_api_type=AzureMLApiType.realtime,
                endpoint_api_key="my-api-key",
+               timeout=120,
                content_formatter=content_formatter,
            )
    """  # noqa: E501
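And a corresponding sketch for the configurable timeout. Again a sketch only: the endpoint values are hypothetical, and the content formatter is elided with `...` since any existing `ContentFormatterBase` subclass works unchanged:

```python
import os

from langchain_community.llms.azureml_endpoint import AzureMLOnlineEndpoint

# Option 1: pass the timeout (in seconds) to the constructor.
llm = AzureMLOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your-region>.inference.ml.azure.com/score",
    endpoint_api_key="my-api-key",
    timeout=120,  # was hard coded to 50 before this change
    content_formatter=...,  # any ContentFormatterBase subclass, unchanged
)

# Option 2: set the new environment variable instead; the string value is
# coerced to int by pydantic, and DEFAULT_TIMEOUT (50 seconds) applies when
# neither the constructor argument nor the env var is given.
os.environ["AZUREML_TIMEOUT"] = "120"
llm = AzureMLOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your-region>.inference.ml.azure.com/score",
    endpoint_api_key="my-api-key",
    content_formatter=...,
)
```

A per-call override also still works, since `AzureMLEndpointClient.call` resolves `kwargs.get("timeout", self.timeout)`, with the instance-level value now serving as the fallback instead of the old hard-coded 50.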