mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Add LLaMa Formatter and AzureML Chat Endpoint (#8382)
## Description Microsoft and Meta recently [announced their collaboration](https://blogs.microsoft.com/blog/2023/07/18/microsoft-and-meta-expand-their-ai-partnership-with-llama-2-on-azure-and-windows/) on LLaMa2. This PR extends the current LLM wrapper and introduces a new Chat Model wrapper for AzureML to support LLaMa2. ## Dependencies No dependencies added :) ## Twitter Handles [@matthew_d13](https://twitter.com/matthew_d13) [@prakhar_in](https://twitter.com/prakhar_in) maintainers - @hwchase17, @baskaryan
This commit is contained in:
parent
1ab773c742
commit
844eca98d5
95
docs/extras/integrations/chat/azureml_chat_endpoint.ipynb
Normal file
95
docs/extras/integrations/chat/azureml_chat_endpoint.ipynb
Normal file
@ -0,0 +1,95 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# AzureML Chat Online Endpoint\n",
|
||||
"\n",
|
||||
"[AzureML](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides Azure Foundation Models and OpenAI Models. Azure Foundation Models include various open-source models and popular Hugging Face models. Users can also import models of their liking into AzureML.\n",
|
||||
"\n",
|
||||
"This notebook goes over how to use a chat model hosted on an `AzureML online endpoint`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models.azureml_endpoint import AzureMLChatOnlineEndpoint"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Set up\n",
|
||||
"\n",
|
||||
"To use the wrapper, you must [deploy a model on AzureML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) and obtain the following parameters:\n",
|
||||
"\n",
|
||||
"* `endpoint_api_key`: The API key provided by the endpoint\n",
|
||||
"* `endpoint_url`: The REST endpoint url provided by the endpoint"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Content Formatter\n",
|
||||
"\n",
|
||||
"The `content_formatter` parameter is a handler class for transforming the request and response of an AzureML endpoint to match with required schema. Since there are a wide range of models in the model catalog, each of which may process data differently from one another, a `ContentFormatterBase` class is provided to allow users to transform data to their liking. The following content formatters are provided:\n",
|
||||
"\n",
|
||||
"* `LLamaContentFormatter`: Formats request and response data for LLaMa2-chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=' The Collatz Conjecture is one of the most famous unsolved problems in mathematics, and it has been the subject of much study and research for many years. While it is impossible to predict with certainty whether the conjecture will ever be solved, there are several reasons why it is considered a challenging and important problem:\\n\\n1. Simple yet elusive: The Collatz Conjecture is a deceptively simple statement that has proven to be extraordinarily difficult to prove or disprove. Despite its simplicity, the conjecture has eluded some of the brightest minds in mathematics, and it remains one of the most famous open problems in the field.\\n2. Wide-ranging implications: The Collatz Conjecture has far-reaching implications for many areas of mathematics, including number theory, algebra, and analysis. A solution to the conjecture could have significant impacts on these fields and potentially lead to new insights and discoveries.\\n3. Computational evidence: While the conjecture remains unproven, extensive computational evidence supports its validity. In fact, no counterexample to the conjecture has been found for any starting value up to 2^64 (a number', additional_kwargs={}, example=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chat_models.azureml_endpoint import LlamaContentFormatter\n",
|
||||
"from langchain.schema import HumanMessage\n",
|
||||
"\n",
|
||||
"chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())\n",
|
||||
"response = chat(messages=[\n",
|
||||
" HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")\n",
|
||||
"])\n",
|
||||
"response"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -28,9 +28,9 @@
|
||||
"\n",
|
||||
"To use the wrapper, you must [deploy a model on AzureML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) and obtain the following parameters:\n",
|
||||
"\n",
|
||||
"* `endpoint_api_key`: The API key provided by the endpoint\n",
|
||||
"* `endpoint_url`: The REST endpoint url provided by the endpoint\n",
|
||||
"* `deployment_name`: The deployment name of the endpoint"
|
||||
"* `endpoint_api_key`: Required - The API key provided by the endpoint\n",
|
||||
"* `endpoint_url`: Required - The REST endpoint url provided by the endpoint\n",
|
||||
"* `deployment_name`: Not required - The deployment name of the model using the endpoint"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -39,11 +39,14 @@
|
||||
"source": [
|
||||
"## Content Formatter\n",
|
||||
"\n",
|
||||
"The `content_formatter` parameter is a handler class for transforming the request and response of an AzureML endpoint to match with required schema. Since there are a wide range of models in the model catalog, each of which may process data differently from one another, a `ContentFormatterBase` class is provided to allow users to transform data to their liking. Additionally, there are three content formatters already provided:\n",
|
||||
"The `content_formatter` parameter is a handler class for transforming the request and response of an AzureML endpoint to match with required schema. Since there are a wide range of models in the model catalog, each of which may process data differently from one another, a `ContentFormatterBase` class is provided to allow users to transform data to their liking. The following content formatters are provided:\n",
|
||||
"\n",
|
||||
"* `OSSContentFormatter`: Formats request and response data for models from the Open Source category in the Model Catalog. Note, that not all models in the Open Source category may follow the same schema\n",
|
||||
"* `DollyContentFormatter`: Formats request and response data for the `dolly-v2-12b` model\n",
|
||||
"* `GPT2ContentFormatter`: Formats request and response data for GPT2\n",
|
||||
"* `DollyContentFormatter`: Formats request and response data for the Dolly-v2\n",
|
||||
"* `HFContentFormatter`: Formats request and response data for text-generation Hugging Face models\n",
|
||||
"* `LLamaContentFormatter`: Formats request and response data for LLaMa2\n",
|
||||
"\n",
|
||||
"*Note: `OSSContentFormatter` is being deprecated and replaced with `GPT2ContentFormatter`. The logic is the same but `GPT2ContentFormatter` is a more suitable name. You can still continue to use `OSSContentFormatter` as the changes are backwards compatibile.*\n",
|
||||
"\n",
|
||||
"Below is an example using a summarization model from Hugging Face."
|
||||
]
|
||||
@ -100,7 +103,6 @@
|
||||
"llm = AzureMLOnlineEndpoint(\n",
|
||||
" endpoint_api_key=os.getenv(\"BART_ENDPOINT_API_KEY\"),\n",
|
||||
" endpoint_url=os.getenv(\"BART_ENDPOINT_URL\"),\n",
|
||||
" deployment_name=\"linydub-bart-large-samsum-3\",\n",
|
||||
" model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
|
||||
" content_formatter=content_formatter,\n",
|
||||
")\n",
|
||||
@ -167,7 +169,6 @@
|
||||
"llm = AzureMLOnlineEndpoint(\n",
|
||||
" endpoint_api_key=os.getenv(\"DOLLY_ENDPOINT_API_KEY\"),\n",
|
||||
" endpoint_url=os.getenv(\"DOLLY_ENDPOINT_URL\"),\n",
|
||||
" deployment_name=\"databricks-dolly-v2-12b-4\",\n",
|
||||
" model_kwargs={\"temperature\": 0.8, \"max_tokens\": 300},\n",
|
||||
" content_formatter=content_formatter,\n",
|
||||
")\n",
|
||||
|
151
libs/langchain/langchain/chat_models/azureml_endpoint.py
Normal file
151
libs/langchain/langchain/chat_models/azureml_endpoint.py
Normal file
@ -0,0 +1,151 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.chat_models.base import SimpleChatModel
|
||||
from langchain.llms.azureml_endpoint import AzureMLEndpointClient, ContentFormatterBase
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class LlamaContentFormatter(ContentFormatterBase):
|
||||
"""Content formatter for LLaMa"""
|
||||
|
||||
SUPPORTED_ROLES = ["user", "assistant", "system"]
|
||||
|
||||
@staticmethod
|
||||
def _convert_message_to_dict(message: BaseMessage) -> Dict:
|
||||
"""Converts message to a dict according to role"""
|
||||
if isinstance(message, HumanMessage):
|
||||
return {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
return {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
return {"role": "system", "content": message.content}
|
||||
elif (
|
||||
isinstance(message, ChatMessage)
|
||||
and message.role in LlamaContentFormatter.SUPPORTED_ROLES
|
||||
):
|
||||
return {"role": message.role, "content": message.content}
|
||||
else:
|
||||
supported = ",".join(
|
||||
[role for role in LlamaContentFormatter.SUPPORTED_ROLES]
|
||||
)
|
||||
raise ValueError(
|
||||
f"""Received unsupported role.
|
||||
Supported roles for the LLaMa Foundation Model: {supported}"""
|
||||
)
|
||||
|
||||
def _format_request_payload(
|
||||
self, messages: List[BaseMessage], model_kwargs: Dict
|
||||
) -> bytes:
|
||||
chat_messages = [
|
||||
LlamaContentFormatter._convert_message_to_dict(message)
|
||||
for message in messages
|
||||
]
|
||||
prompt = json.dumps(
|
||||
{"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
|
||||
)
|
||||
return self.format_request_payload(prompt=prompt, model_kwargs=model_kwargs)
|
||||
|
||||
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
|
||||
"""Formats the request according the the chosen api"""
|
||||
return str.encode(prompt)
|
||||
|
||||
def format_response_payload(self, output: bytes) -> str:
|
||||
"""Formats response"""
|
||||
return json.loads(output)["output"]
|
||||
|
||||
|
||||
class AzureMLChatOnlineEndpoint(SimpleChatModel):
|
||||
"""Azure ML Chat Online Endpoint models.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
azure_chat = AzureMLChatOnlineEndpoint(
|
||||
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
|
||||
endpoint_api_key="my-api-key",
|
||||
content_formatter=content_formatter,
|
||||
)
|
||||
"""
|
||||
|
||||
endpoint_url: str = ""
|
||||
"""URL of pre-existing Endpoint. Should be passed to constructor or specified as
|
||||
env var `AZUREML_ENDPOINT_URL`."""
|
||||
|
||||
endpoint_api_key: str = ""
|
||||
"""Authentication Key for Endpoint. Should be passed to constructor or specified as
|
||||
env var `AZUREML_ENDPOINT_API_KEY`."""
|
||||
|
||||
http_client: Any = None #: :meta private:
|
||||
|
||||
content_formatter: Any = None
|
||||
"""The content formatter that provides an input and output
|
||||
transform function to handle formats between the LLM and
|
||||
the endpoint"""
|
||||
|
||||
model_kwargs: Optional[dict] = None
|
||||
"""Key word arguments to pass to the model."""
|
||||
|
||||
@validator("http_client", always=True, allow_reuse=True)
|
||||
@classmethod
|
||||
def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
endpoint_key = get_from_dict_or_env(
|
||||
values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY"
|
||||
)
|
||||
endpoint_url = get_from_dict_or_env(
|
||||
values, "endpoint_url", "AZUREML_ENDPOINT_URL"
|
||||
)
|
||||
http_client = AzureMLEndpointClient(endpoint_url, endpoint_key)
|
||||
return http_client
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
return {
|
||||
**{"model_kwargs": _model_kwargs},
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "azureml_chat_endpoint"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to an AzureML Managed Online endpoint.
|
||||
Args:
|
||||
messages: The messages in the conversation with the chat model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
Example:
|
||||
.. code-block:: python
|
||||
response = azureml_model("Tell me a joke.")
|
||||
"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
|
||||
request_payload = self.content_formatter._format_request_payload(
|
||||
messages, _model_kwargs
|
||||
)
|
||||
response_payload = self.http_client.call(request_payload, **kwargs)
|
||||
generated_text = self.content_formatter.format_response_payload(
|
||||
response_payload
|
||||
)
|
||||
return generated_text
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import urllib.request
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
@ -14,16 +15,19 @@ class AzureMLEndpointClient(object):
|
||||
"""AzureML Managed Endpoint client."""
|
||||
|
||||
def __init__(
|
||||
self, endpoint_url: str, endpoint_api_key: str, deployment_name: str
|
||||
self, endpoint_url: str, endpoint_api_key: str, deployment_name: str = ""
|
||||
) -> None:
|
||||
"""Initialize the class."""
|
||||
if not endpoint_api_key:
|
||||
raise ValueError("A key should be provided to invoke the endpoint")
|
||||
if not endpoint_api_key or not endpoint_url:
|
||||
raise ValueError(
|
||||
"""A key/token and REST endpoint should
|
||||
be provided to invoke the endpoint"""
|
||||
)
|
||||
self.endpoint_url = endpoint_url
|
||||
self.endpoint_api_key = endpoint_api_key
|
||||
self.deployment_name = deployment_name
|
||||
|
||||
def call(self, body: bytes) -> bytes:
|
||||
def call(self, body: bytes, **kwargs: Any) -> bytes:
|
||||
"""call."""
|
||||
|
||||
# The azureml-model-deployment header will force the request to go to a
|
||||
@ -32,11 +36,12 @@ class AzureMLEndpointClient(object):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": ("Bearer " + self.endpoint_api_key),
|
||||
"azureml-model-deployment": self.deployment_name,
|
||||
}
|
||||
if self.deployment_name != "":
|
||||
headers["azureml-model-deployment"] = self.deployment_name
|
||||
|
||||
req = urllib.request.Request(self.endpoint_url, body, headers)
|
||||
response = urllib.request.urlopen(req, timeout=50)
|
||||
response = urllib.request.urlopen(req, timeout=kwargs.get("timeout", 50))
|
||||
result = response.read()
|
||||
return result
|
||||
|
||||
@ -75,7 +80,26 @@ class ContentFormatterBase:
|
||||
"""The MIME type of the input data passed to the endpoint"""
|
||||
|
||||
accepts: Optional[str] = "application/json"
|
||||
"""The MIME type of the response data returned form the endpoint"""
|
||||
"""The MIME type of the response data returned from the endpoint"""
|
||||
|
||||
@staticmethod
|
||||
def escape_special_characters(prompt: str) -> str:
|
||||
"""Escapes any special characters in `prompt`"""
|
||||
escape_map = {
|
||||
"\\": "\\\\",
|
||||
'"': '\\"',
|
||||
"\b": "\\b",
|
||||
"\f": "\\f",
|
||||
"\n": "\\n",
|
||||
"\r": "\\r",
|
||||
"\t": "\\t",
|
||||
}
|
||||
|
||||
# Replace each occurrence of the specified characters with escaped versions
|
||||
for escape_sequence, escaped_sequence in escape_map.items():
|
||||
prompt = prompt.replace(escape_sequence, escaped_sequence)
|
||||
|
||||
return prompt
|
||||
|
||||
@abstractmethod
|
||||
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
|
||||
@ -92,44 +116,86 @@ class ContentFormatterBase:
|
||||
"""
|
||||
|
||||
|
||||
class OSSContentFormatter(ContentFormatterBase):
|
||||
"""Content handler for LLMs from the OSS catalog."""
|
||||
class GPT2ContentFormatter(ContentFormatterBase):
|
||||
"""Content handler for GPT2"""
|
||||
|
||||
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
|
||||
input_str = json.dumps(
|
||||
{"inputs": {"input_string": [prompt]}, "parameters": model_kwargs}
|
||||
prompt = ContentFormatterBase.escape_special_characters(prompt)
|
||||
request_payload = json.dumps(
|
||||
{"inputs": {"input_string": [f'"{prompt}"']}, "parameters": model_kwargs}
|
||||
)
|
||||
return str.encode(input_str)
|
||||
return str.encode(request_payload)
|
||||
|
||||
def format_response_payload(self, output: bytes) -> str:
|
||||
response_json = json.loads(output)
|
||||
return response_json[0]["0"]
|
||||
return json.loads(output)[0]["0"]
|
||||
|
||||
|
||||
class OSSContentFormatter(GPT2ContentFormatter):
|
||||
"""Deprecated: Kept for backwards compatibility
|
||||
|
||||
Content handler for LLMs from the OSS catalog."""
|
||||
|
||||
content_formatter: Any = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
warnings.warn(
|
||||
"""`OSSContentFormatter` will be deprecated in the future.
|
||||
Please use `GPT2ContentFormatter` instead.
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
class HFContentFormatter(ContentFormatterBase):
|
||||
"""Content handler for LLMs from the HuggingFace catalog."""
|
||||
|
||||
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
|
||||
input_str = json.dumps({"inputs": [prompt], "parameters": model_kwargs})
|
||||
return str.encode(input_str)
|
||||
ContentFormatterBase.escape_special_characters(prompt)
|
||||
request_payload = json.dumps(
|
||||
{"inputs": [f'"{prompt}"'], "parameters": model_kwargs}
|
||||
)
|
||||
return str.encode(request_payload)
|
||||
|
||||
def format_response_payload(self, output: bytes) -> str:
|
||||
response_json = json.loads(output)
|
||||
return response_json[0][0]["generated_text"]
|
||||
return json.loads(output)[0]["generated_text"]
|
||||
|
||||
|
||||
class DollyContentFormatter(ContentFormatterBase):
|
||||
"""Content handler for the Dolly-v2-12b model"""
|
||||
|
||||
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
|
||||
input_str = json.dumps(
|
||||
{"input_data": {"input_string": [prompt]}, "parameters": model_kwargs}
|
||||
prompt = ContentFormatterBase.escape_special_characters(prompt)
|
||||
request_payload = json.dumps(
|
||||
{
|
||||
"input_data": {"input_string": [f'"{prompt}"']},
|
||||
"parameters": model_kwargs,
|
||||
}
|
||||
)
|
||||
return str.encode(input_str)
|
||||
return str.encode(request_payload)
|
||||
|
||||
def format_response_payload(self, output: bytes) -> str:
|
||||
response_json = json.loads(output)
|
||||
return response_json[0]
|
||||
return json.loads(output)[0]
|
||||
|
||||
|
||||
class LlamaContentFormatter(ContentFormatterBase):
|
||||
"""Content formatter for LLaMa"""
|
||||
|
||||
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
|
||||
"""Formats the request according the the chosen api"""
|
||||
prompt = ContentFormatterBase.escape_special_characters(prompt)
|
||||
request_payload = json.dumps(
|
||||
{
|
||||
"input_data": {
|
||||
"input_string": [f'"{prompt}"'],
|
||||
"parameters": model_kwargs,
|
||||
}
|
||||
}
|
||||
)
|
||||
return str.encode(request_payload)
|
||||
|
||||
def format_response_payload(self, output: bytes) -> str:
|
||||
"""Formats response"""
|
||||
return json.loads(output)[0]["0"]
|
||||
|
||||
|
||||
class AzureMLOnlineEndpoint(LLM, BaseModel):
|
||||
@ -138,10 +204,9 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
azure_llm = AzureMLModel(
|
||||
azure_llm = AzureMLOnlineEndpoint(
|
||||
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
|
||||
endpoint_api_key="my-api-key",
|
||||
deployment_name="my-deployment-name",
|
||||
content_formatter=content_formatter,
|
||||
)
|
||||
""" # noqa: E501
|
||||
@ -155,8 +220,8 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
|
||||
env var `AZUREML_ENDPOINT_API_KEY`."""
|
||||
|
||||
deployment_name: str = ""
|
||||
"""Deployment Name for Endpoint. Should be passed to constructor or specified as
|
||||
env var `AZUREML_DEPLOYMENT_NAME`."""
|
||||
"""Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
|
||||
to constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`."""
|
||||
|
||||
http_client: Any = None #: :meta private:
|
||||
|
||||
@ -179,7 +244,7 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
|
||||
values, "endpoint_url", "AZUREML_ENDPOINT_URL"
|
||||
)
|
||||
deployment_name = get_from_dict_or_env(
|
||||
values, "deployment_name", "AZUREML_DEPLOYMENT_NAME"
|
||||
values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", ""
|
||||
)
|
||||
http_client = AzureMLEndpointClient(endpoint_url, endpoint_key, deployment_name)
|
||||
return http_client
|
||||
@ -203,7 +268,7 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to an AzureML Managed Online endpoint.
|
||||
Args:
|
||||
@ -217,7 +282,11 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
|
||||
"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
|
||||
body = self.content_formatter.format_request_payload(prompt, _model_kwargs)
|
||||
endpoint_response = self.http_client.call(body)
|
||||
response = self.content_formatter.format_response_payload(endpoint_response)
|
||||
return response
|
||||
request_payload = self.content_formatter.format_request_payload(
|
||||
prompt, _model_kwargs
|
||||
)
|
||||
response_payload = self.http_client.call(request_payload, **kwargs)
|
||||
generated_text = self.content_formatter.format_response_payload(
|
||||
response_payload
|
||||
)
|
||||
return generated_text
|
||||
|
@ -0,0 +1,58 @@
|
||||
"""Test AzureML Chat Endpoint wrapper."""
|
||||
|
||||
from langchain.chat_models.azureml_endpoint import (
|
||||
AzureMLChatOnlineEndpoint,
|
||||
LlamaContentFormatter,
|
||||
)
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatGeneration,
|
||||
HumanMessage,
|
||||
LLMResult,
|
||||
)
|
||||
|
||||
|
||||
def test_llama_call() -> None:
|
||||
"""Test valid call to Open Source Foundation Model."""
|
||||
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
|
||||
response = chat(messages=[HumanMessage(content="Foo")])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_timeout_kwargs() -> None:
|
||||
"""Test that timeout kwarg works."""
|
||||
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
|
||||
response = chat(messages=[HumanMessage(content="FOO")], timeout=60)
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_message_history() -> None:
|
||||
"""Test that multiple messages works."""
|
||||
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
|
||||
response = chat(
|
||||
messages=[
|
||||
HumanMessage(content="Hello."),
|
||||
AIMessage(content="Hello!"),
|
||||
HumanMessage(content="How are you doing?"),
|
||||
]
|
||||
)
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_multiple_messages() -> None:
|
||||
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
|
||||
message = HumanMessage(content="Hi!")
|
||||
response = chat.generate([[message], [message]])
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 1
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
@ -18,8 +18,8 @@ from langchain.llms.azureml_endpoint import (
|
||||
from langchain.llms.loading import load_llm
|
||||
|
||||
|
||||
def test_oss_call() -> None:
|
||||
"""Test valid call to Open Source Foundation Model."""
|
||||
def test_gpt2_call() -> None:
|
||||
"""Test valid call to GPT2."""
|
||||
llm = AzureMLOnlineEndpoint(
|
||||
endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
|
||||
endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
|
||||
@ -43,7 +43,7 @@ def test_hf_call() -> None:
|
||||
|
||||
|
||||
def test_dolly_call() -> None:
|
||||
"""Test valid call to dolly-v2-12b."""
|
||||
"""Test valid call to dolly-v2."""
|
||||
llm = AzureMLOnlineEndpoint(
|
||||
endpoint_api_key=os.getenv("DOLLY_ENDPOINT_API_KEY"),
|
||||
endpoint_url=os.getenv("DOLLY_ENDPOINT_URL"),
|
||||
|
Loading…
Reference in New Issue
Block a user