Add LLaMa Formatter and AzureML Chat Endpoint (#8382)

## Description

Microsoft and Meta recently [announced their
collaboration](https://blogs.microsoft.com/blog/2023/07/18/microsoft-and-meta-expand-their-ai-partnership-with-llama-2-on-azure-and-windows/)
on LLaMa2. This PR extends the current LLM wrapper and introduces a new
Chat Model wrapper for AzureML to support LLaMa2.

## Dependencies

No dependencies added :)

## Twitter Handles

[@matthew_d13](https://twitter.com/matthew_d13)
[@prakhar_in](https://twitter.com/prakhar_in)

maintainers - @hwchase17, @baskaryan
This commit is contained in:
Matthew DeGuzman 2023-07-31 16:26:25 -07:00 committed by GitHub
parent 1ab773c742
commit 844eca98d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 418 additions and 44 deletions

View File

@ -0,0 +1,95 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AzureML Chat Online Endpoint\n",
"\n",
"[AzureML](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides Azure Foundation Models and OpenAI Models. Azure Foundation Models include various open-source models and popular Hugging Face models. Users can also import models of their liking into AzureML.\n",
"\n",
"This notebook goes over how to use a chat model hosted on an `AzureML online endpoint`"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models.azureml_endpoint import AzureMLChatOnlineEndpoint"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set up\n",
"\n",
"To use the wrapper, you must [deploy a model on AzureML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) and obtain the following parameters:\n",
"\n",
"* `endpoint_api_key`: The API key provided by the endpoint\n",
"* `endpoint_url`: The REST endpoint url provided by the endpoint"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Content Formatter\n",
"\n",
"The `content_formatter` parameter is a handler class for transforming the request and response of an AzureML endpoint to match with required schema. Since there are a wide range of models in the model catalog, each of which may process data differently from one another, a `ContentFormatterBase` class is provided to allow users to transform data to their liking. The following content formatters are provided:\n",
"\n",
"* `LLamaContentFormatter`: Formats request and response data for LLaMa2-chat"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=' The Collatz Conjecture is one of the most famous unsolved problems in mathematics, and it has been the subject of much study and research for many years. While it is impossible to predict with certainty whether the conjecture will ever be solved, there are several reasons why it is considered a challenging and important problem:\\n\\n1. Simple yet elusive: The Collatz Conjecture is a deceptively simple statement that has proven to be extraordinarily difficult to prove or disprove. Despite its simplicity, the conjecture has eluded some of the brightest minds in mathematics, and it remains one of the most famous open problems in the field.\\n2. Wide-ranging implications: The Collatz Conjecture has far-reaching implications for many areas of mathematics, including number theory, algebra, and analysis. A solution to the conjecture could have significant impacts on these fields and potentially lead to new insights and discoveries.\\n3. Computational evidence: While the conjecture remains unproven, extensive computational evidence supports its validity. In fact, no counterexample to the conjecture has been found for any starting value up to 2^64 (a number', additional_kwargs={}, example=False)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.chat_models.azureml_endpoint import LlamaContentFormatter\n",
"from langchain.schema import HumanMessage\n",
"\n",
"chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())\n",
"response = chat(messages=[\n",
" HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")\n",
"])\n",
"response"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -28,9 +28,9 @@
"\n", "\n",
"To use the wrapper, you must [deploy a model on AzureML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) and obtain the following parameters:\n", "To use the wrapper, you must [deploy a model on AzureML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) and obtain the following parameters:\n",
"\n", "\n",
"* `endpoint_api_key`: The API key provided by the endpoint\n", "* `endpoint_api_key`: Required - The API key provided by the endpoint\n",
"* `endpoint_url`: The REST endpoint url provided by the endpoint\n", "* `endpoint_url`: Required - The REST endpoint url provided by the endpoint\n",
"* `deployment_name`: The deployment name of the endpoint" "* `deployment_name`: Not required - The deployment name of the model using the endpoint"
] ]
}, },
{ {
@ -39,11 +39,14 @@
"source": [ "source": [
"## Content Formatter\n", "## Content Formatter\n",
"\n", "\n",
"The `content_formatter` parameter is a handler class for transforming the request and response of an AzureML endpoint to match with required schema. Since there are a wide range of models in the model catalog, each of which may process data differently from one another, a `ContentFormatterBase` class is provided to allow users to transform data to their liking. Additionally, there are three content formatters already provided:\n", "The `content_formatter` parameter is a handler class for transforming the request and response of an AzureML endpoint to match with required schema. Since there are a wide range of models in the model catalog, each of which may process data differently from one another, a `ContentFormatterBase` class is provided to allow users to transform data to their liking. The following content formatters are provided:\n",
"\n", "\n",
"* `OSSContentFormatter`: Formats request and response data for models from the Open Source category in the Model Catalog. Note, that not all models in the Open Source category may follow the same schema\n", "* `GPT2ContentFormatter`: Formats request and response data for GPT2\n",
"* `DollyContentFormatter`: Formats request and response data for the `dolly-v2-12b` model\n", "* `DollyContentFormatter`: Formats request and response data for the Dolly-v2\n",
"* `HFContentFormatter`: Formats request and response data for text-generation Hugging Face models\n", "* `HFContentFormatter`: Formats request and response data for text-generation Hugging Face models\n",
"* `LLamaContentFormatter`: Formats request and response data for LLaMa2\n",
"\n",
"*Note: `OSSContentFormatter` is being deprecated and replaced with `GPT2ContentFormatter`. The logic is the same but `GPT2ContentFormatter` is a more suitable name. You can still continue to use `OSSContentFormatter` as the changes are backwards compatibile.*\n",
"\n", "\n",
"Below is an example using a summarization model from Hugging Face." "Below is an example using a summarization model from Hugging Face."
] ]
@ -100,7 +103,6 @@
"llm = AzureMLOnlineEndpoint(\n", "llm = AzureMLOnlineEndpoint(\n",
" endpoint_api_key=os.getenv(\"BART_ENDPOINT_API_KEY\"),\n", " endpoint_api_key=os.getenv(\"BART_ENDPOINT_API_KEY\"),\n",
" endpoint_url=os.getenv(\"BART_ENDPOINT_URL\"),\n", " endpoint_url=os.getenv(\"BART_ENDPOINT_URL\"),\n",
" deployment_name=\"linydub-bart-large-samsum-3\",\n",
" model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n", " model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
" content_formatter=content_formatter,\n", " content_formatter=content_formatter,\n",
")\n", ")\n",
@ -167,7 +169,6 @@
"llm = AzureMLOnlineEndpoint(\n", "llm = AzureMLOnlineEndpoint(\n",
" endpoint_api_key=os.getenv(\"DOLLY_ENDPOINT_API_KEY\"),\n", " endpoint_api_key=os.getenv(\"DOLLY_ENDPOINT_API_KEY\"),\n",
" endpoint_url=os.getenv(\"DOLLY_ENDPOINT_URL\"),\n", " endpoint_url=os.getenv(\"DOLLY_ENDPOINT_URL\"),\n",
" deployment_name=\"databricks-dolly-v2-12b-4\",\n",
" model_kwargs={\"temperature\": 0.8, \"max_tokens\": 300},\n", " model_kwargs={\"temperature\": 0.8, \"max_tokens\": 300},\n",
" content_formatter=content_formatter,\n", " content_formatter=content_formatter,\n",
")\n", ")\n",

View File

@ -0,0 +1,151 @@
import json
from typing import Any, Dict, List, Optional
from pydantic import validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.llms.azureml_endpoint import AzureMLEndpointClient, ContentFormatterBase
from langchain.schema.messages import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain.utils import get_from_dict_or_env
class LlamaContentFormatter(ContentFormatterBase):
"""Content formatter for LLaMa"""
SUPPORTED_ROLES = ["user", "assistant", "system"]
@staticmethod
def _convert_message_to_dict(message: BaseMessage) -> Dict:
"""Converts message to a dict according to role"""
if isinstance(message, HumanMessage):
return {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
return {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
return {"role": "system", "content": message.content}
elif (
isinstance(message, ChatMessage)
and message.role in LlamaContentFormatter.SUPPORTED_ROLES
):
return {"role": message.role, "content": message.content}
else:
supported = ",".join(
[role for role in LlamaContentFormatter.SUPPORTED_ROLES]
)
raise ValueError(
f"""Received unsupported role.
Supported roles for the LLaMa Foundation Model: {supported}"""
)
def _format_request_payload(
self, messages: List[BaseMessage], model_kwargs: Dict
) -> bytes:
chat_messages = [
LlamaContentFormatter._convert_message_to_dict(message)
for message in messages
]
prompt = json.dumps(
{"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
)
return self.format_request_payload(prompt=prompt, model_kwargs=model_kwargs)
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
"""Formats the request according the the chosen api"""
return str.encode(prompt)
def format_response_payload(self, output: bytes) -> str:
"""Formats response"""
return json.loads(output)["output"]
class AzureMLChatOnlineEndpoint(SimpleChatModel):
"""Azure ML Chat Online Endpoint models.
Example:
.. code-block:: python
azure_chat = AzureMLChatOnlineEndpoint(
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
endpoint_api_key="my-api-key",
content_formatter=content_formatter,
)
"""
endpoint_url: str = ""
"""URL of pre-existing Endpoint. Should be passed to constructor or specified as
env var `AZUREML_ENDPOINT_URL`."""
endpoint_api_key: str = ""
"""Authentication Key for Endpoint. Should be passed to constructor or specified as
env var `AZUREML_ENDPOINT_API_KEY`."""
http_client: Any = None #: :meta private:
content_formatter: Any = None
"""The content formatter that provides an input and output
transform function to handle formats between the LLM and
the endpoint"""
model_kwargs: Optional[dict] = None
"""Key word arguments to pass to the model."""
@validator("http_client", always=True, allow_reuse=True)
@classmethod
def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
"""Validate that api key and python package exists in environment."""
endpoint_key = get_from_dict_or_env(
values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY"
)
endpoint_url = get_from_dict_or_env(
values, "endpoint_url", "AZUREML_ENDPOINT_URL"
)
http_client = AzureMLEndpointClient(endpoint_url, endpoint_key)
return http_client
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
**{"model_kwargs": _model_kwargs},
}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "azureml_chat_endpoint"
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to an AzureML Managed Online endpoint.
Args:
messages: The messages in the conversation with the chat model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = azureml_model("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
request_payload = self.content_formatter._format_request_payload(
messages, _model_kwargs
)
response_payload = self.http_client.call(request_payload, **kwargs)
generated_text = self.content_formatter.format_response_payload(
response_payload
)
return generated_text

View File

@ -1,5 +1,6 @@
import json import json
import urllib.request import urllib.request
import warnings
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
@ -14,16 +15,19 @@ class AzureMLEndpointClient(object):
"""AzureML Managed Endpoint client.""" """AzureML Managed Endpoint client."""
def __init__( def __init__(
self, endpoint_url: str, endpoint_api_key: str, deployment_name: str self, endpoint_url: str, endpoint_api_key: str, deployment_name: str = ""
) -> None: ) -> None:
"""Initialize the class.""" """Initialize the class."""
if not endpoint_api_key: if not endpoint_api_key or not endpoint_url:
raise ValueError("A key should be provided to invoke the endpoint") raise ValueError(
"""A key/token and REST endpoint should
be provided to invoke the endpoint"""
)
self.endpoint_url = endpoint_url self.endpoint_url = endpoint_url
self.endpoint_api_key = endpoint_api_key self.endpoint_api_key = endpoint_api_key
self.deployment_name = deployment_name self.deployment_name = deployment_name
def call(self, body: bytes) -> bytes: def call(self, body: bytes, **kwargs: Any) -> bytes:
"""call.""" """call."""
# The azureml-model-deployment header will force the request to go to a # The azureml-model-deployment header will force the request to go to a
@ -32,11 +36,12 @@ class AzureMLEndpointClient(object):
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": ("Bearer " + self.endpoint_api_key), "Authorization": ("Bearer " + self.endpoint_api_key),
"azureml-model-deployment": self.deployment_name,
} }
if self.deployment_name != "":
headers["azureml-model-deployment"] = self.deployment_name
req = urllib.request.Request(self.endpoint_url, body, headers) req = urllib.request.Request(self.endpoint_url, body, headers)
response = urllib.request.urlopen(req, timeout=50) response = urllib.request.urlopen(req, timeout=kwargs.get("timeout", 50))
result = response.read() result = response.read()
return result return result
@ -75,7 +80,26 @@ class ContentFormatterBase:
"""The MIME type of the input data passed to the endpoint""" """The MIME type of the input data passed to the endpoint"""
accepts: Optional[str] = "application/json" accepts: Optional[str] = "application/json"
"""The MIME type of the response data returned form the endpoint""" """The MIME type of the response data returned from the endpoint"""
@staticmethod
def escape_special_characters(prompt: str) -> str:
"""Escapes any special characters in `prompt`"""
escape_map = {
"\\": "\\\\",
'"': '\\"',
"\b": "\\b",
"\f": "\\f",
"\n": "\\n",
"\r": "\\r",
"\t": "\\t",
}
# Replace each occurrence of the specified characters with escaped versions
for escape_sequence, escaped_sequence in escape_map.items():
prompt = prompt.replace(escape_sequence, escaped_sequence)
return prompt
@abstractmethod @abstractmethod
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
@ -92,44 +116,86 @@ class ContentFormatterBase:
""" """
class OSSContentFormatter(ContentFormatterBase): class GPT2ContentFormatter(ContentFormatterBase):
"""Content handler for LLMs from the OSS catalog.""" """Content handler for GPT2"""
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
input_str = json.dumps( prompt = ContentFormatterBase.escape_special_characters(prompt)
{"inputs": {"input_string": [prompt]}, "parameters": model_kwargs} request_payload = json.dumps(
{"inputs": {"input_string": [f'"{prompt}"']}, "parameters": model_kwargs}
) )
return str.encode(input_str) return str.encode(request_payload)
def format_response_payload(self, output: bytes) -> str: def format_response_payload(self, output: bytes) -> str:
response_json = json.loads(output) return json.loads(output)[0]["0"]
return response_json[0]["0"]
class OSSContentFormatter(GPT2ContentFormatter):
"""Deprecated: Kept for backwards compatibility
Content handler for LLMs from the OSS catalog."""
content_formatter: Any = None
def __init__(self) -> None:
super().__init__()
warnings.warn(
"""`OSSContentFormatter` will be deprecated in the future.
Please use `GPT2ContentFormatter` instead.
"""
)
class HFContentFormatter(ContentFormatterBase): class HFContentFormatter(ContentFormatterBase):
"""Content handler for LLMs from the HuggingFace catalog.""" """Content handler for LLMs from the HuggingFace catalog."""
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
input_str = json.dumps({"inputs": [prompt], "parameters": model_kwargs}) ContentFormatterBase.escape_special_characters(prompt)
return str.encode(input_str) request_payload = json.dumps(
{"inputs": [f'"{prompt}"'], "parameters": model_kwargs}
)
return str.encode(request_payload)
def format_response_payload(self, output: bytes) -> str: def format_response_payload(self, output: bytes) -> str:
response_json = json.loads(output) return json.loads(output)[0]["generated_text"]
return response_json[0][0]["generated_text"]
class DollyContentFormatter(ContentFormatterBase): class DollyContentFormatter(ContentFormatterBase):
"""Content handler for the Dolly-v2-12b model""" """Content handler for the Dolly-v2-12b model"""
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
input_str = json.dumps( prompt = ContentFormatterBase.escape_special_characters(prompt)
{"input_data": {"input_string": [prompt]}, "parameters": model_kwargs} request_payload = json.dumps(
{
"input_data": {"input_string": [f'"{prompt}"']},
"parameters": model_kwargs,
}
) )
return str.encode(input_str) return str.encode(request_payload)
def format_response_payload(self, output: bytes) -> str: def format_response_payload(self, output: bytes) -> str:
response_json = json.loads(output) return json.loads(output)[0]
return response_json[0]
class LlamaContentFormatter(ContentFormatterBase):
"""Content formatter for LLaMa"""
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
"""Formats the request according the the chosen api"""
prompt = ContentFormatterBase.escape_special_characters(prompt)
request_payload = json.dumps(
{
"input_data": {
"input_string": [f'"{prompt}"'],
"parameters": model_kwargs,
}
}
)
return str.encode(request_payload)
def format_response_payload(self, output: bytes) -> str:
"""Formats response"""
return json.loads(output)[0]["0"]
class AzureMLOnlineEndpoint(LLM, BaseModel): class AzureMLOnlineEndpoint(LLM, BaseModel):
@ -138,10 +204,9 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
Example: Example:
.. code-block:: python .. code-block:: python
azure_llm = AzureMLModel( azure_llm = AzureMLOnlineEndpoint(
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score", endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
endpoint_api_key="my-api-key", endpoint_api_key="my-api-key",
deployment_name="my-deployment-name",
content_formatter=content_formatter, content_formatter=content_formatter,
) )
""" # noqa: E501 """ # noqa: E501
@ -155,8 +220,8 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
env var `AZUREML_ENDPOINT_API_KEY`.""" env var `AZUREML_ENDPOINT_API_KEY`."""
deployment_name: str = "" deployment_name: str = ""
"""Deployment Name for Endpoint. Should be passed to constructor or specified as """Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
env var `AZUREML_DEPLOYMENT_NAME`.""" to constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`."""
http_client: Any = None #: :meta private: http_client: Any = None #: :meta private:
@ -179,7 +244,7 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
values, "endpoint_url", "AZUREML_ENDPOINT_URL" values, "endpoint_url", "AZUREML_ENDPOINT_URL"
) )
deployment_name = get_from_dict_or_env( deployment_name = get_from_dict_or_env(
values, "deployment_name", "AZUREML_DEPLOYMENT_NAME" values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", ""
) )
http_client = AzureMLEndpointClient(endpoint_url, endpoint_key, deployment_name) http_client = AzureMLEndpointClient(endpoint_url, endpoint_key, deployment_name)
return http_client return http_client
@ -203,7 +268,7 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any **kwargs: Any,
) -> str: ) -> str:
"""Call out to an AzureML Managed Online endpoint. """Call out to an AzureML Managed Online endpoint.
Args: Args:
@ -217,7 +282,11 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
""" """
_model_kwargs = self.model_kwargs or {} _model_kwargs = self.model_kwargs or {}
body = self.content_formatter.format_request_payload(prompt, _model_kwargs) request_payload = self.content_formatter.format_request_payload(
endpoint_response = self.http_client.call(body) prompt, _model_kwargs
response = self.content_formatter.format_response_payload(endpoint_response) )
return response response_payload = self.http_client.call(request_payload, **kwargs)
generated_text = self.content_formatter.format_response_payload(
response_payload
)
return generated_text

View File

@ -0,0 +1,58 @@
"""Test AzureML Chat Endpoint wrapper."""
from langchain.chat_models.azureml_endpoint import (
AzureMLChatOnlineEndpoint,
LlamaContentFormatter,
)
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
HumanMessage,
LLMResult,
)
def test_llama_call() -> None:
"""Test valid call to Open Source Foundation Model."""
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
response = chat(messages=[HumanMessage(content="Foo")])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_timeout_kwargs() -> None:
"""Test that timeout kwarg works."""
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
response = chat(messages=[HumanMessage(content="FOO")], timeout=60)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_message_history() -> None:
"""Test that multiple messages works."""
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
response = chat(
messages=[
HumanMessage(content="Hello."),
AIMessage(content="Hello!"),
HumanMessage(content="How are you doing?"),
]
)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_multiple_messages() -> None:
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
message = HumanMessage(content="Hi!")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content

View File

@ -18,8 +18,8 @@ from langchain.llms.azureml_endpoint import (
from langchain.llms.loading import load_llm from langchain.llms.loading import load_llm
def test_oss_call() -> None: def test_gpt2_call() -> None:
"""Test valid call to Open Source Foundation Model.""" """Test valid call to GPT2."""
llm = AzureMLOnlineEndpoint( llm = AzureMLOnlineEndpoint(
endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"), endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
endpoint_url=os.getenv("OSS_ENDPOINT_URL"), endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
@ -43,7 +43,7 @@ def test_hf_call() -> None:
def test_dolly_call() -> None: def test_dolly_call() -> None:
"""Test valid call to dolly-v2-12b.""" """Test valid call to dolly-v2."""
llm = AzureMLOnlineEndpoint( llm = AzureMLOnlineEndpoint(
endpoint_api_key=os.getenv("DOLLY_ENDPOINT_API_KEY"), endpoint_api_key=os.getenv("DOLLY_ENDPOINT_API_KEY"),
endpoint_url=os.getenv("DOLLY_ENDPOINT_URL"), endpoint_url=os.getenv("DOLLY_ENDPOINT_URL"),