Community: Add mistral oss model support to azureml endpoints, plus configurable timeout (#19123)

- **Description:** There was no content formatter for Mistral OSS models on Azure
ML endpoints. This adds one, plus a configurable request timeout (previously hard
coded to 50 seconds). A usage sketch follows below.
- **Dependencies:** none
- **Twitter handle:** @tjaffri @docugami
Taqi Jaffri 2024-03-18 21:10:42 -07:00 committed by GitHub
parent 07de4abe70
commit 044bc22acc
2 changed files with 62 additions and 2 deletions
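
A minimal usage sketch of both changes together (import paths assumed from langchain_community, not shown in this diff; endpoint values are placeholders):

    from langchain_community.chat_models.azureml_endpoint import (
        AzureMLChatOnlineEndpoint,
        AzureMLEndpointApiType,
        MistralChatContentFormatter,
    )

    # Placeholder endpoint details; substitute your own deployment.
    chat = AzureMLChatOnlineEndpoint(
        endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",
        endpoint_api_type=AzureMLEndpointApiType.serverless,
        endpoint_api_key="my-api-key",
        timeout=120,  # new in this commit; previously hard coded to 50 seconds
        content_formatter=MistralChatContentFormatter(),
    )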

libs/community/langchain_community/chat_models/azureml_endpoint.py

@@ -143,6 +143,44 @@ class LlamaChatContentFormatter(ContentFormatterBase):
         raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
 
 
+class MistralChatContentFormatter(LlamaChatContentFormatter):
+    """Content formatter for `Mistral`."""
+
+    def format_messages_request_payload(
+        self,
+        messages: List[BaseMessage],
+        model_kwargs: Dict,
+        api_type: AzureMLEndpointApiType,
+    ) -> bytes:
+        """Formats the request according to the chosen api"""
+        chat_messages = [
+            self._convert_message_to_dict(message) for message in messages
+        ]
+
+        if chat_messages and chat_messages[0]["role"] == "system":
+            # Mistral OSS models do not explicitly support system prompts, so we have to
+            # stash in the first user prompt
+            chat_messages[1]["content"] = (
+                chat_messages[0]["content"] + "\n\n" + chat_messages[1]["content"]
+            )
+            del chat_messages[0]
+
+        if api_type == AzureMLEndpointApiType.realtime:
+            request_payload = json.dumps(
+                {
+                    "input_data": {
+                        "input_string": chat_messages,
+                        "parameters": model_kwargs,
+                    }
+                }
+            )
+        elif api_type == AzureMLEndpointApiType.serverless:
+            request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
+        else:
+            raise ValueError(
+                f"`api_type` {api_type} is not supported by this formatter"
+            )
+        return str.encode(request_payload)
+
+
 class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
     """Azure ML Online Endpoint chat models.

libs/community/langchain_community/llms/azureml_endpoint.py

@@ -11,12 +11,18 @@ from langchain_core.outputs import Generation, LLMResult
 from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator, validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 
+DEFAULT_TIMEOUT = 50
+
 
 class AzureMLEndpointClient(object):
     """AzureML Managed Endpoint client."""
 
     def __init__(
-        self, endpoint_url: str, endpoint_api_key: str, deployment_name: str = ""
+        self,
+        endpoint_url: str,
+        endpoint_api_key: str,
+        deployment_name: str = "",
+        timeout: int = DEFAULT_TIMEOUT,
     ) -> None:
         """Initialize the class."""
         if not endpoint_api_key or not endpoint_url:
@@ -27,6 +33,7 @@ class AzureMLEndpointClient(object):
         self.endpoint_url = endpoint_url
         self.endpoint_api_key = endpoint_api_key
         self.deployment_name = deployment_name
+        self.timeout = timeout
 
     def call(
         self,
@@ -47,7 +54,9 @@ class AzureMLEndpointClient(object):
             headers["azureml-model-deployment"] = self.deployment_name
 
         req = urllib.request.Request(self.endpoint_url, body, headers)
-        response = urllib.request.urlopen(req, timeout=kwargs.get("timeout", 50))
+        response = urllib.request.urlopen(
+            req, timeout=kwargs.get("timeout", self.timeout)
+        )
         result = response.read()
         return result
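
The kwargs.get("timeout", self.timeout) lookup means the instance-level default can still be overridden per request. A sketch, assuming call forwards its keyword arguments as in this hunk (endpoint values are placeholders):

    client = AzureMLEndpointClient(
        endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
        endpoint_api_key="my-api-key",
        timeout=90,  # instance default, replacing the old hard coded 50
    )
    client.call(body=b"...", timeout=10)  # this single request times out after 10 s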
@@ -334,6 +343,9 @@ class AzureMLBaseEndpoint(BaseModel):
     """Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
     to constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`."""
 
+    timeout: int = DEFAULT_TIMEOUT
+    """Request timeout for calls to the endpoint"""
+
     http_client: Any = None  #: :meta private:
     content_formatter: Any = None
@@ -361,6 +373,12 @@ class AzureMLBaseEndpoint(BaseModel):
             "AZUREML_ENDPOINT_API_TYPE",
             AzureMLEndpointApiType.realtime,
         )
+        values["timeout"] = get_from_dict_or_env(
+            values,
+            "timeout",
+            "AZUREML_TIMEOUT",
+            str(DEFAULT_TIMEOUT),
+        )
 
         return values
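
Because the fallback goes through get_from_dict_or_env, the timeout can also be supplied via the AZUREML_TIMEOUT environment variable; the string value read from the environment is coerced to int by pydantic validation. A sketch (endpoint values are placeholders):

    import os

    os.environ["AZUREML_TIMEOUT"] = "90"  # seconds, read at validation time

    llm = AzureMLOnlineEndpoint(
        endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
        endpoint_api_key="my-api-key",
        content_formatter=content_formatter,  # any ContentFormatterBase subclass
    )
    assert llm.timeout == 90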
@@ -424,12 +442,15 @@ class AzureMLBaseEndpoint(BaseModel):
         endpoint_url = values.get("endpoint_url")
         endpoint_key = values.get("endpoint_api_key")
         deployment_name = values.get("deployment_name")
+        timeout = values.get("timeout", DEFAULT_TIMEOUT)
 
         http_client = AzureMLEndpointClient(
             endpoint_url,  # type: ignore
             endpoint_key.get_secret_value(),  # type: ignore
             deployment_name,  # type: ignore
+            timeout,  # type: ignore
         )
 
         return http_client
@@ -442,6 +463,7 @@ class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
             endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
             endpoint_api_type=AzureMLApiType.realtime,
             endpoint_api_key="my-api-key",
+            timeout=120,
             content_formatter=content_formatter,
         )
     """  # noqa: E501