mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
Community: Add mistral oss model support to azureml endpoints, plus configurable timeout (#19123)
- **Description:** There was no formatter for Mistral models on Azure ML endpoints. This adds one, plus a configurable timeout (previously hard-coded). - **Dependencies:** none - **Twitter handle:** @tjaffri @docugami
This commit is contained in:
parent
07de4abe70
commit
044bc22acc
@ -143,6 +143,44 @@ class LlamaChatContentFormatter(ContentFormatterBase):
|
|||||||
raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
|
raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
|
||||||
|
|
||||||
|
|
||||||
|
class MistralChatContentFormatter(LlamaChatContentFormatter):
    """Content formatter for `Mistral` models served on Azure ML endpoints.

    Reuses the Llama message-to-dict conversion from the parent class but
    folds any leading system message into the first user message, since
    Mistral OSS models do not accept a separate system role.
    """

    def format_messages_request_payload(
        self,
        messages: List[BaseMessage],
        model_kwargs: Dict,
        api_type: AzureMLEndpointApiType,
    ) -> bytes:
        """Format the request payload according to the chosen API type.

        Args:
            messages: Chat messages to send to the endpoint.
            model_kwargs: Extra model parameters merged into the payload.
            api_type: Which Azure ML endpoint flavor the payload targets.

        Returns:
            The JSON request body encoded as bytes.

        Raises:
            ValueError: If ``api_type`` is not supported by this formatter.
        """
        chat_messages = [self._convert_message_to_dict(message) for message in messages]

        if chat_messages and chat_messages[0]["role"] == "system":
            # Mistral OSS models do not explicitly support system prompts, so we
            # have to stash the system content in the first user prompt.
            if len(chat_messages) > 1:
                chat_messages[1]["content"] = (
                    chat_messages[0]["content"] + "\n\n" + chat_messages[1]["content"]
                )
                del chat_messages[0]
            else:
                # A lone system message would otherwise raise IndexError on
                # chat_messages[1]; re-tag it as a user message so the request
                # is still well-formed.
                chat_messages[0]["role"] = "user"

        if api_type == AzureMLEndpointApiType.realtime:
            request_payload = json.dumps(
                {
                    "input_data": {
                        "input_string": chat_messages,
                        "parameters": model_kwargs,
                    }
                }
            )
        elif api_type == AzureMLEndpointApiType.serverless:
            # Serverless (pay-as-you-go) endpoints expect an OpenAI-style body.
            request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
        else:
            raise ValueError(
                f"`api_type` {api_type} is not supported by this formatter"
            )
        return str.encode(request_payload)
|
||||||
|
|
||||||
|
|
||||||
class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
|
class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
|
||||||
"""Azure ML Online Endpoint chat models.
|
"""Azure ML Online Endpoint chat models.
|
||||||
|
|
||||||
|
@ -11,12 +11,18 @@ from langchain_core.outputs import Generation, LLMResult
|
|||||||
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator, validator
|
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator, validator
|
||||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||||
|
|
||||||
|
DEFAULT_TIMEOUT = 50
|
||||||
|
|
||||||
|
|
||||||
class AzureMLEndpointClient(object):
|
class AzureMLEndpointClient(object):
|
||||||
"""AzureML Managed Endpoint client."""
|
"""AzureML Managed Endpoint client."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, endpoint_url: str, endpoint_api_key: str, deployment_name: str = ""
|
self,
|
||||||
|
endpoint_url: str,
|
||||||
|
endpoint_api_key: str,
|
||||||
|
deployment_name: str = "",
|
||||||
|
timeout: int = DEFAULT_TIMEOUT,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the class."""
|
"""Initialize the class."""
|
||||||
if not endpoint_api_key or not endpoint_url:
|
if not endpoint_api_key or not endpoint_url:
|
||||||
@ -27,6 +33,7 @@ class AzureMLEndpointClient(object):
|
|||||||
self.endpoint_url = endpoint_url
|
self.endpoint_url = endpoint_url
|
||||||
self.endpoint_api_key = endpoint_api_key
|
self.endpoint_api_key = endpoint_api_key
|
||||||
self.deployment_name = deployment_name
|
self.deployment_name = deployment_name
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
@ -47,7 +54,9 @@ class AzureMLEndpointClient(object):
|
|||||||
headers["azureml-model-deployment"] = self.deployment_name
|
headers["azureml-model-deployment"] = self.deployment_name
|
||||||
|
|
||||||
req = urllib.request.Request(self.endpoint_url, body, headers)
|
req = urllib.request.Request(self.endpoint_url, body, headers)
|
||||||
response = urllib.request.urlopen(req, timeout=kwargs.get("timeout", 50))
|
response = urllib.request.urlopen(
|
||||||
|
req, timeout=kwargs.get("timeout", self.timeout)
|
||||||
|
)
|
||||||
result = response.read()
|
result = response.read()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -334,6 +343,9 @@ class AzureMLBaseEndpoint(BaseModel):
|
|||||||
"""Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
|
"""Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
|
||||||
to constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`."""
|
to constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`."""
|
||||||
|
|
||||||
|
timeout: int = DEFAULT_TIMEOUT
|
||||||
|
"""Request timeout for calls to the endpoint"""
|
||||||
|
|
||||||
http_client: Any = None #: :meta private:
|
http_client: Any = None #: :meta private:
|
||||||
|
|
||||||
content_formatter: Any = None
|
content_formatter: Any = None
|
||||||
@ -361,6 +373,12 @@ class AzureMLBaseEndpoint(BaseModel):
|
|||||||
"AZUREML_ENDPOINT_API_TYPE",
|
"AZUREML_ENDPOINT_API_TYPE",
|
||||||
AzureMLEndpointApiType.realtime,
|
AzureMLEndpointApiType.realtime,
|
||||||
)
|
)
|
||||||
|
values["timeout"] = get_from_dict_or_env(
|
||||||
|
values,
|
||||||
|
"timeout",
|
||||||
|
"AZUREML_TIMEOUT",
|
||||||
|
str(DEFAULT_TIMEOUT),
|
||||||
|
)
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@ -424,12 +442,15 @@ class AzureMLBaseEndpoint(BaseModel):
|
|||||||
endpoint_url = values.get("endpoint_url")
|
endpoint_url = values.get("endpoint_url")
|
||||||
endpoint_key = values.get("endpoint_api_key")
|
endpoint_key = values.get("endpoint_api_key")
|
||||||
deployment_name = values.get("deployment_name")
|
deployment_name = values.get("deployment_name")
|
||||||
|
timeout = values.get("timeout", DEFAULT_TIMEOUT)
|
||||||
|
|
||||||
http_client = AzureMLEndpointClient(
|
http_client = AzureMLEndpointClient(
|
||||||
endpoint_url, # type: ignore
|
endpoint_url, # type: ignore
|
||||||
endpoint_key.get_secret_value(), # type: ignore
|
endpoint_key.get_secret_value(), # type: ignore
|
||||||
deployment_name, # type: ignore
|
deployment_name, # type: ignore
|
||||||
|
timeout, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
return http_client
|
return http_client
|
||||||
|
|
||||||
|
|
||||||
@ -442,6 +463,7 @@ class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
|
|||||||
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
|
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
|
||||||
endpoint_api_type=AzureMLApiType.realtime,
|
endpoint_api_type=AzureMLApiType.realtime,
|
||||||
endpoint_api_key="my-api-key",
|
endpoint_api_key="my-api-key",
|
||||||
|
timeout=120,
|
||||||
content_formatter=content_formatter,
|
content_formatter=content_formatter,
|
||||||
)
|
)
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
Loading…
Reference in New Issue
Block a user