Add AzureML endpoint LLM wrapper (#6580)

### Description

We have added a new LLM integration, `azureml_endpoint`, that allows users
to leverage models from the AzureML platform. Microsoft recently announced
the release of [Azure Foundation Models](https://learn.microsoft.com/en-us/azure/machine-learning/concept-foundation-models?view=azureml-api-2),
which users can find in the AzureML Model Catalog. The Model Catalog
contains a variety of open-source and Hugging Face models that users can
deploy on AzureML. The `azureml_endpoint` integration lets LangChain users
call these deployed Azure Foundation Models.
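
A minimal usage sketch, with placeholder credentials and `OSSContentFormatter` standing in for whichever formatter matches your deployed model:

```python
import os

from langchain.llms.azureml_endpoint import AzureMLOnlineEndpoint, OSSContentFormatter

# Placeholder environment variables for an endpoint you have already deployed.
llm = AzureMLOnlineEndpoint(
    endpoint_api_key=os.getenv("AZUREML_ENDPOINT_API_KEY"),
    endpoint_url=os.getenv("AZUREML_ENDPOINT_URL"),
    deployment_name=os.getenv("AZUREML_DEPLOYMENT_NAME"),
    content_formatter=OSSContentFormatter(),
)
print(llm("Tell me a joke."))
```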

### Dependencies

No new dependencies were required for this change.

### Tests

Integration tests were added in
`tests/integration_tests/llms/test_azureml_endpoint.py`.

### Notebook

A Jupyter notebook demonstrating how to use `azureml_endpoint` was added
to `docs/modules/llms/integrations/azureml_endpoint_example.ipynb`.

### Twitter

[Prakhar Gupta](https://twitter.com/prakhar_in)
[Matthew DeGuzman](https://twitter.com/matthew_d13)

---------

Co-authored-by: Matthew DeGuzman <91019033+matthewdeguzman@users.noreply.github.com>
Co-authored-by: prakharg-msft <75808410+prakharg-msft@users.noreply.github.com>


@@ -0,0 +1,243 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AzureML Online Endpoint\n",
"\n",
"[AzureML](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides Azure Foundation Models and OpenAI Models. Azure Foundation Models include various open-source models and popular Hugging Face models. Users can also import their own models into AzureML.\n",
"\n",
"This notebook goes over how to use an LLM hosted on an `AzureML online endpoint`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms.azureml_endpoint import AzureMLOnlineEndpoint"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set up\n",
"\n",
"To use the wrapper, you must [deploy a model on AzureML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) and obtain the following parameters:\n",
"\n",
"* `endpoint_api_key`: The API key provided by the endpoint\n",
"* `endpoint_url`: The REST endpoint URL provided by the endpoint\n",
"* `deployment_name`: The deployment name of the endpoint"
]
},
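{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal setup sketch: the values below are placeholders for your own\n",
"# deployed endpoint. The wrapper can also read them from the environment\n",
"# variables AZUREML_ENDPOINT_API_KEY, AZUREML_ENDPOINT_URL, and\n",
"# AZUREML_DEPLOYMENT_NAME.\n",
"import os\n",
"\n",
"os.environ[\"AZUREML_ENDPOINT_API_KEY\"] = \"my-api-key\"\n",
"os.environ[\"AZUREML_ENDPOINT_URL\"] = \"https://<your-endpoint>.<region>.inference.ml.azure.com/score\"\n",
"os.environ[\"AZUREML_DEPLOYMENT_NAME\"] = \"my-deployment-name\""
]
},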
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Content Formatter\n",
"\n",
"The `content_formatter` parameter is a handler class that transforms the request and response of an AzureML endpoint to match the required schema. Since the Model Catalog contains a wide range of models, each of which may process data differently, a `ContentFormatterBase` class is provided so that users can transform data to their liking. Additionally, three content formatters are already provided:\n",
"\n",
"* `OSSContentFormatter`: Formats request and response data for models from the Open Source category in the Model Catalog. Note that not all models in the Open Source category follow the same schema\n",
"* `DollyContentFormatter`: Formats request and response data for the `dolly-v2-12b` model\n",
"* `HFContentFormatter`: Formats request and response data for text-generation Hugging Face models\n",
"\n",
"Below is an example using a summarization model from Hugging Face."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Custom Content Formatter"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"HaSeul won her first music show trophy with \"So What\" on Mnet's M Countdown. Loona released their second EP titled [#] (read as hash] on February 5, 2020. HaSeul did not take part in the promotion of the album because of mental health issues. On October 19, 2020, they released their third EP called [12:00]. It was their first album to enter the Billboard 200, debuting at number 112. On June 2, 2021, the group released their fourth EP called Yummy-Yummy. On August 27, it was announced that they are making their Japanese debut on September 15 under Universal Music Japan sublabel EMI Records.\n"
]
}
],
"source": [
"from typing import Dict\n",
"\n",
"from langchain.llms.azureml_endpoint import AzureMLOnlineEndpoint, ContentFormatterBase\n",
"import os\n",
"import json\n",
"\n",
"\n",
"class CustomFormatter(ContentFormatterBase):\n",
" content_type = \"application/json\"\n",
" accepts = \"application/json\"\n",
"\n",
" def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
" input_str = json.dumps(\n",
" {\n",
" \"inputs\": [prompt],\n",
" \"parameters\": model_kwargs,\n",
" \"options\": {\"use_cache\": False, \"wait_for_model\": True},\n",
" }\n",
" )\n",
" return str.encode(input_str)\n",
"\n",
" def format_response_payload(self, output: bytes) -> str:\n",
" response_json = json.loads(output)\n",
" return response_json[0][\"summary_text\"]\n",
"\n",
"\n",
"content_formatter = CustomFormatter()\n",
"\n",
"llm = AzureMLOnlineEndpoint(\n",
" endpoint_api_key=os.getenv(\"BART_ENDPOINT_API_KEY\"),\n",
" endpoint_url=os.getenv(\"BART_ENDPOINT_URL\"),\n",
" deployment_name=\"linydub-bart-large-samsum-3\",\n",
" model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
" content_formatter=content_formatter,\n",
")\n",
"large_text = \"\"\"On January 7, 2020, Blockberry Creative announced that HaSeul would not participate in the promotion for Loona's \n",
"next album because of mental health concerns. She was said to be diagnosed with \"intermittent anxiety symptoms\" and would be \n",
"taking time to focus on her health.[39] On February 5, 2020, Loona released their second EP titled [#] (read as hash), along \n",
"with the title track \"So What\".[40] Although HaSeul did not appear in the title track, her vocals are featured on three other \n",
"songs on the album, including \"365\". Once peaked at number 1 on the daily Gaon Retail Album Chart,[41] the EP then debuted at \n",
"number 2 on the weekly Gaon Album Chart. On March 12, 2020, Loona won their first music show trophy with \"So What\" on Mnet's \n",
"M Countdown.[42]\n",
"\n",
"On October 19, 2020, Loona released their third EP titled [12:00] (read as midnight),[43] accompanied by its first single \n",
"\"Why Not?\". HaSeul was again not involved in the album, out of her own decision to focus on the recovery of her health.[44] \n",
"The EP then became their first album to enter the Billboard 200, debuting at number 112.[45] On November 18, Loona released \n",
"the music video for \"Star\", another song on [12:00].[46] Peaking at number 40, \"Star\" is Loona's first entry on the Billboard \n",
"Mainstream Top 40, making them the second K-pop girl group to enter the chart.[47]\n",
"\n",
"On June 1, 2021, Loona announced that they would be having a comeback on June 28, with their fourth EP, [&] (read as and).\n",
"[48] The following day, on June 2, a teaser was posted to Loona's official social media accounts showing twelve sets of eyes, \n",
"confirming the return of member HaSeul who had been on hiatus since early 2020.[49] On June 12, group members YeoJin, Kim Lip, \n",
"Choerry, and Go Won released the song \"Yum-Yum\" as a collaboration with Cocomong.[50] On September 8, they released another \n",
"collaboration song named \"Yummy-Yummy\".[51] On June 27, 2021, Loona announced at the end of their special clip that they are \n",
"making their Japanese debut on September 15 under Universal Music Japan sublabel EMI Records.[52] On August 27, it was announced \n",
"that Loona will release the double A-side single, \"Hula Hoop / Star Seed\" on September 15, with a physical CD release on October \n",
"20.[53] In December, Chuu filed an injunction to suspend her exclusive contract with Blockberry Creative.[54][55]\n",
"\"\"\"\n",
"summarized_text = llm(large_text)\n",
"print(summarized_text)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dolly with LLMChain"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Many people are willing to talk about themselves; it's others who seem to be stuck up. Try to understand others where they're coming from. Like minded people can build a tribe together.\n"
]
}
],
"source": [
"from langchain import PromptTemplate\n",
"from langchain.llms.azureml_endpoint import DollyContentFormatter\n",
"from langchain.chains import LLMChain\n",
"\n",
"formatter_template = \"Write a {word_count} word essay about {topic}.\"\n",
"\n",
"prompt = PromptTemplate(\n",
" input_variables=[\"word_count\", \"topic\"], template=formatter_template\n",
")\n",
"\n",
"content_formatter = DollyContentFormatter()\n",
"\n",
"llm = AzureMLOnlineEndpoint(\n",
" endpoint_api_key=os.getenv(\"DOLLY_ENDPOINT_API_KEY\"),\n",
" endpoint_url=os.getenv(\"DOLLY_ENDPOINT_URL\"),\n",
" deployment_name=\"databricks-dolly-v2-12b-4\",\n",
" model_kwargs={\"temperature\": 0.8, \"max_tokens\": 300},\n",
" content_formatter=content_formatter,\n",
")\n",
"\n",
"chain = LLMChain(llm=llm, prompt=prompt)\n",
"print(chain.run({\"word_count\": 100, \"topic\": \"how to make friends\"}))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Serializing an LLM\n",
"You can also save and load LLM configurations."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1mAzureMLOnlineEndpoint\u001b[0m\n",
"Params: {'deployment_name': 'databricks-dolly-v2-12b-4', 'model_kwargs': {'temperature': 0.2, 'max_tokens': 150, 'top_p': 0.8, 'frequency_penalty': 0.32, 'presence_penalty': 0.072}}\n"
]
}
],
"source": [
"from langchain.llms.loading import load_llm\n",
"from langchain.llms.azureml_endpoint import AzureMLOnlineEndpoint\n",
"\n",
"save_llm = AzureMLOnlineEndpoint(\n",
" deployment_name=\"databricks-dolly-v2-12b-4\",\n",
" model_kwargs={\n",
" \"temperature\": 0.2,\n",
" \"max_tokens\": 150,\n",
" \"top_p\": 0.8,\n",
" \"frequency_penalty\": 0.32,\n",
" \"presence_penalty\": 0.072,\n",
" },\n",
")\n",
"save_llm.save(\"azureml.json\")\n",
"loaded_llm = load_llm(\"azureml.json\")\n",
"\n",
"print(loaded_llm)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -6,6 +6,7 @@ from langchain.llms.aleph_alpha import AlephAlpha
from langchain.llms.anthropic import Anthropic
from langchain.llms.anyscale import Anyscale
from langchain.llms.aviary import Aviary
from langchain.llms.azureml_endpoint import AzureMLOnlineEndpoint
from langchain.llms.bananadev import Banana
from langchain.llms.base import BaseLLM
from langchain.llms.baseten import Baseten
@@ -54,6 +55,7 @@ __all__ = [
    "Anthropic",
    "Anyscale",
    "Aviary",
    "AzureMLOnlineEndpoint",
    "AzureOpenAI",
    "Banana",
    "Baseten",
@@ -106,6 +108,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
    "anyscale": Anyscale,
    "aviary": Aviary,
    "azure": AzureOpenAI,
    "azureml_endpoint": AzureMLOnlineEndpoint,
    "bananadev": Banana,
    "baseten": Baseten,
    "beam": Beam,

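For context, the `azureml_endpoint` entry in `type_to_cls_dict` is what lets `load_llm` resolve a saved config back to `AzureMLOnlineEndpoint`. A minimal sketch of that round trip (the config values here are hypothetical):

```python
import json

from langchain.llms.loading import load_llm

# A hypothetical saved config; "_type" is the key looked up in type_to_cls_dict.
config = {
    "_type": "azureml_endpoint",
    "deployment_name": "databricks-dolly-v2-12b-4",
    "model_kwargs": {"temperature": 0.2, "max_tokens": 150},
}
with open("azureml.json", "w") as f:
    json.dump(config, f)

llm = load_llm("azureml.json")  # an AzureMLOnlineEndpoint instance
```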

@@ -0,0 +1,224 @@
"""Wrapper around AzureML Managed Online Endpoint API."""
import json
import urllib.request
from abc import abstractmethod
from typing import Any, Dict, List, Mapping, Optional

from pydantic import BaseModel, validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env


class AzureMLEndpointClient(object):
    """Wrapper around AzureML Managed Online Endpoint Client."""

    def __init__(
        self, endpoint_url: str, endpoint_api_key: str, deployment_name: str
    ) -> None:
        """Initialize the class."""
        if not endpoint_api_key:
            raise ValueError("A key should be provided to invoke the endpoint")
        self.endpoint_url = endpoint_url
        self.endpoint_api_key = endpoint_api_key
        self.deployment_name = deployment_name

    def call(self, body: bytes) -> bytes:
        """Call the endpoint with the given body and return the raw response."""
        # The azureml-model-deployment header will force the request to go to a
        # specific deployment. Remove this header to have the request observe the
        # endpoint traffic rules.
        headers = {
            "Content-Type": "application/json",
            "Authorization": ("Bearer " + self.endpoint_api_key),
            "azureml-model-deployment": self.deployment_name,
        }

        req = urllib.request.Request(self.endpoint_url, body, headers)
        response = urllib.request.urlopen(req, timeout=50)
        result = response.read()
        return result


class ContentFormatterBase:
    """A handler class to transform the request and response of an
    AzureML endpoint to match the required schema.

    Example:
        .. code-block:: python

            class ContentFormatter(ContentFormatterBase):
                content_type = "application/json"
                accepts = "application/json"

                def format_request_payload(
                    self,
                    prompt: str,
                    model_kwargs: Dict
                ) -> bytes:
                    input_str = json.dumps(
                        {
                            "inputs": {"input_string": [prompt]},
                            "parameters": model_kwargs,
                        }
                    )
                    return str.encode(input_str)

                def format_response_payload(self, output: bytes) -> str:
                    response_json = json.loads(output)
                    return response_json[0]["0"]
    """

    content_type: Optional[str] = "application/json"
    """The MIME type of the input data passed to the endpoint"""

    accepts: Optional[str] = "application/json"
    """The MIME type of the response data returned from the endpoint"""

    @abstractmethod
    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Formats the request body according to the input schema of
        the model. Returns bytes or a seekable file-like object in the
        format specified in the content_type request header.
        """

    @abstractmethod
    def format_response_payload(self, output: bytes) -> str:
        """Formats the response body according to the output
        schema of the model. Returns the data type that is
        received from the response.
        """


class OSSContentFormatter(ContentFormatterBase):
    """Content handler for LLMs from the OSS catalog."""

    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps(
            {"inputs": {"input_string": [prompt]}, "parameters": model_kwargs}
        )
        return str.encode(input_str)

    def format_response_payload(self, output: bytes) -> str:
        response_json = json.loads(output)
        return response_json[0]["0"]


class HFContentFormatter(ContentFormatterBase):
    """Content handler for LLMs from the HuggingFace catalog."""

    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": [prompt], "parameters": model_kwargs})
        return str.encode(input_str)

    def format_response_payload(self, output: bytes) -> str:
        response_json = json.loads(output)
        return response_json[0][0]["generated_text"]


class DollyContentFormatter(ContentFormatterBase):
    """Content handler for the Dolly-v2-12b model."""

    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps(
            {"input_data": {"input_string": [prompt]}, "parameters": model_kwargs}
        )
        return str.encode(input_str)

    def format_response_payload(self, output: bytes) -> str:
        response_json = json.loads(output)
        return response_json[0]


class AzureMLOnlineEndpoint(LLM, BaseModel):
    """Wrapper around Azure ML Hosted models using Managed Online Endpoints.

    Example:
        .. code-block:: python

            azure_llm = AzureMLOnlineEndpoint(
                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
                endpoint_api_key="my-api-key",
                deployment_name="my-deployment-name",
                content_formatter=content_formatter,
            )
    """  # noqa: E501

    endpoint_url: str = ""
    """URL of pre-existing Endpoint. Should be passed to constructor or specified as
    env var `AZUREML_ENDPOINT_URL`."""

    endpoint_api_key: str = ""
    """Authentication Key for Endpoint. Should be passed to constructor or specified as
    env var `AZUREML_ENDPOINT_API_KEY`."""

    deployment_name: str = ""
    """Deployment Name for Endpoint. Should be passed to constructor or specified as
    env var `AZUREML_DEPLOYMENT_NAME`."""

    http_client: Any = None  #: :meta private:

    content_formatter: Any = None
    """The content formatter that provides an input and output
    transform function to handle formats between the LLM and
    the endpoint"""

    model_kwargs: Optional[dict] = None
    """Key word arguments to pass to the model."""

    @validator("http_client", always=True, allow_reuse=True)
    @classmethod
    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
        """Validate that the API key and endpoint settings exist in the environment."""
        endpoint_key = get_from_dict_or_env(
            values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY"
        )
        endpoint_url = get_from_dict_or_env(
            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
        )
        deployment_name = get_from_dict_or_env(
            values, "deployment_name", "AZUREML_DEPLOYMENT_NAME"
        )
        http_client = AzureMLEndpointClient(endpoint_url, endpoint_key, deployment_name)
        return http_client

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"deployment_name": self.deployment_name},
            **{"model_kwargs": _model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "azureml_endpoint"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any
    ) -> str:
        """Call out to an AzureML Managed Online endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = azureml_model("Tell me a joke.")
        """
        _model_kwargs = self.model_kwargs or {}

        body = self.content_formatter.format_request_payload(prompt, _model_kwargs)
        endpoint_response = self.http_client.call(body)
        response = self.content_formatter.format_response_payload(endpoint_response)
        return response


@@ -0,0 +1,151 @@
"""Test AzureML Endpoint wrapper."""
import json
import os
from pathlib import Path
from typing import Dict
from urllib.error import HTTPError

import pytest

from langchain.llms.azureml_endpoint import (
    AzureMLOnlineEndpoint,
    ContentFormatterBase,
    DollyContentFormatter,
    HFContentFormatter,
    OSSContentFormatter,
)
from langchain.llms.loading import load_llm


def test_oss_call() -> None:
    """Test valid call to Open Source Foundation Model."""
    llm = AzureMLOnlineEndpoint(
        endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
        endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
        deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
        content_formatter=OSSContentFormatter(),
    )
    output = llm("Foo")
    assert isinstance(output, str)


def test_hf_call() -> None:
    """Test valid call to HuggingFace Foundation Model."""
    llm = AzureMLOnlineEndpoint(
        endpoint_api_key=os.getenv("HF_ENDPOINT_API_KEY"),
        endpoint_url=os.getenv("HF_ENDPOINT_URL"),
        deployment_name=os.getenv("HF_DEPLOYMENT_NAME"),
        content_formatter=HFContentFormatter(),
    )
    output = llm("Foo")
    assert isinstance(output, str)


def test_dolly_call() -> None:
    """Test valid call to dolly-v2-12b."""
    llm = AzureMLOnlineEndpoint(
        endpoint_api_key=os.getenv("DOLLY_ENDPOINT_API_KEY"),
        endpoint_url=os.getenv("DOLLY_ENDPOINT_URL"),
        deployment_name=os.getenv("DOLLY_DEPLOYMENT_NAME"),
        content_formatter=DollyContentFormatter(),
    )
    output = llm("Foo")
    assert isinstance(output, str)


def test_custom_formatter() -> None:
    """Test ability to create a custom content formatter."""

    class CustomFormatter(ContentFormatterBase):
        content_type = "application/json"
        accepts = "application/json"

        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
            input_str = json.dumps(
                {
                    "inputs": [prompt],
                    "parameters": model_kwargs,
                    "options": {"use_cache": False, "wait_for_model": True},
                }
            )
            return input_str.encode("utf-8")

        def format_response_payload(self, output: bytes) -> str:
            response_json = json.loads(output)
            return response_json[0]["summary_text"]

    llm = AzureMLOnlineEndpoint(
        endpoint_api_key=os.getenv("BART_ENDPOINT_API_KEY"),
        endpoint_url=os.getenv("BART_ENDPOINT_URL"),
        deployment_name=os.getenv("BART_DEPLOYMENT_NAME"),
        content_formatter=CustomFormatter(),
    )
    output = llm("Foo")
    assert isinstance(output, str)


def test_missing_content_formatter() -> None:
    """Test the AzureML LLM without a content_formatter attribute."""
    with pytest.raises(AttributeError):
        llm = AzureMLOnlineEndpoint(
            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
        )
        llm("Foo")


def test_invalid_request_format() -> None:
    """Test an invalid request format."""

    class CustomContentFormatter(ContentFormatterBase):
        content_type = "application/json"
        accepts = "application/json"

        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
            input_str = json.dumps(
                {
                    "incorrect_input": {"input_string": [prompt]},
                    "parameters": model_kwargs,
                }
            )
            return str.encode(input_str)

        def format_response_payload(self, output: bytes) -> str:
            response_json = json.loads(output)
            return response_json[0]["0"]

    with pytest.raises(HTTPError):
        llm = AzureMLOnlineEndpoint(
            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
            content_formatter=CustomContentFormatter(),
        )
        llm("Foo")


def test_incorrect_key() -> None:
    """Test the AzureML endpoint with an incorrect API key."""
    with pytest.raises(HTTPError):
        llm = AzureMLOnlineEndpoint(
            endpoint_api_key="incorrect-key",
            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
            content_formatter=OSSContentFormatter(),
        )
        llm("Foo")


def test_saving_loading_llm(tmp_path: Path) -> None:
    """Test saving/loading an AzureML Foundation Model LLM."""
    save_llm = AzureMLOnlineEndpoint(
        deployment_name="databricks-dolly-v2-12b-4",
        model_kwargs={"temperature": 0.03, "top_p": 0.4, "max_tokens": 200},
    )
    save_llm.save(file_path=tmp_path / "azureml.yaml")
    loaded_llm = load_llm(tmp_path / "azureml.yaml")

    assert loaded_llm == save_llm