diff --git a/docs/docs/integrations/chat/azureml_chat_endpoint.ipynb b/docs/docs/integrations/chat/azureml_chat_endpoint.ipynb
index 0bca033c6a..0542ba0ef8 100644
--- a/docs/docs/integrations/chat/azureml_chat_endpoint.ipynb
+++ b/docs/docs/integrations/chat/azureml_chat_endpoint.ipynb
@@ -40,7 +40,7 @@
     "You must [deploy a model on Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) or [to Azure AI studio](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-open) and obtain the following parameters:\n",
     "\n",
     "* `endpoint_url`: The REST endpoint url provided by the endpoint.\n",
-    "* `endpoint_api_type`: Use `endpoint_type='realtime'` when deploying models to **Realtime endpoints** (hosted managed infrastructure). Use `endpoint_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
+    "* `endpoint_api_type`: Use `endpoint_type='dedicated'` when deploying models to **Dedicated endpoints** (hosted managed infrastructure). Use `endpoint_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
     "* `endpoint_api_key`: The API key provided by the endpoint"
    ]
   },
@@ -52,9 +52,9 @@
     "\n",
     "The `content_formatter` parameter is a handler class for transforming the request and response of an AzureML endpoint to match with required schema. Since there are a wide range of models in the model catalog, each of which may process data differently from one another, a `ContentFormatterBase` class is provided to allow users to transform data to their liking. The following content formatters are provided:\n",
     "\n",
-    "* `LLamaChatContentFormatter`: Formats request and response data for LLaMa2-chat\n",
+    "* `CustomOpenAIChatContentFormatter`: Formats request and response data for models like LLaMa2-chat that follow the OpenAI API spec.\n",
     "\n",
-    "*Note: `langchain.chat_models.azureml_endpoint.LLamaContentFormatter` is being deprecated and replaced with `langchain.chat_models.azureml_endpoint.LLamaChatContentFormatter`.*\n",
+    "*Note: `langchain.chat_models.azureml_endpoint.LlamaChatContentFormatter` is being deprecated and replaced with `langchain.chat_models.azureml_endpoint.CustomOpenAIChatContentFormatter`.*\n",
     "\n",
     "You can implement custom content formatters specific for your model deriving from the class `langchain_community.llms.azureml_endpoint.ContentFormatterBase`."
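For illustration only (not part of this patch): a minimal sketch of such a custom formatter. The endpoint payload shapes (`{"prompt": ...}` in, `{"text": ...}` out) and the class name are hypothetical; only `ContentFormatterBase`, `escape_special_characters`, and the `dedicated` API type come from the code under review.

```python
import json
from typing import Dict, List

from langchain_community.llms.azureml_endpoint import (
    AzureMLEndpointApiType,
    ContentFormatterBase,
)
from langchain_core.outputs import Generation


class EchoContentFormatter(ContentFormatterBase):
    """Hypothetical formatter for an endpoint that accepts
    {"prompt": ...} and returns {"text": ...}."""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        # Escape quotes/newlines so the prompt survives JSON embedding.
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        return json.dumps({"prompt": prompt, **model_kwargs}).encode("utf-8")

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        # Pull the generated text out of the assumed JSON response.
        return Generation(text=json.loads(output)["text"])
```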
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "## Examples\n",
     "\n",
-    "The following section cotain examples about how to use this class:"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from langchain_community.chat_models.azureml_endpoint import (\n",
-    "    AzureMLEndpointApiType,\n",
-    "    LlamaChatContentFormatter,\n",
-    ")\n",
-    "from langchain_core.messages import HumanMessage"
+    "The following section contains examples of how to use this class:"
    ]
   },
   {
@@ -105,14 +92,17 @@
    }
   ],
   "source": [
-    "from langchain_community.chat_models.azureml_endpoint import LlamaContentFormatter\n",
+    "from langchain_community.chat_models.azureml_endpoint import (\n",
+    "    AzureMLEndpointApiType,\n",
+    "    CustomOpenAIChatContentFormatter,\n",
+    ")\n",
     "from langchain_core.messages import HumanMessage\n",
     "\n",
     "chat = AzureMLChatOnlineEndpoint(\n",
     "    endpoint_url=\"https://..inference.ml.azure.com/score\",\n",
-    "    endpoint_api_type=AzureMLEndpointApiType.realtime,\n",
+    "    endpoint_api_type=AzureMLEndpointApiType.dedicated,\n",
     "    endpoint_api_key=\"my-api-key\",\n",
-    "    content_formatter=LlamaChatContentFormatter(),\n",
+    "    content_formatter=CustomOpenAIChatContentFormatter(),\n",
     ")\n",
     "response = chat.invoke(\n",
     "    [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
@@ -137,7 +127,7 @@
     "    endpoint_url=\"https://..inference.ml.azure.com/v1/chat/completions\",\n",
     "    endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
     "    endpoint_api_key=\"my-api-key\",\n",
-    "    content_formatter=LlamaChatContentFormatter,\n",
+    "    content_formatter=CustomOpenAIChatContentFormatter,\n",
     ")\n",
     "response = chat.invoke(\n",
     "    [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
@@ -149,7 +139,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "If you need to pass additional parameters to the model, use `model_kwards` argument:"
+    "If you need to pass additional parameters to the model, use the `model_kwargs` argument:"
    ]
   },
   {
@@ -162,7 +152,7 @@
     "    endpoint_url=\"https://..inference.ml.azure.com/v1/chat/completions\",\n",
     "    endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
     "    endpoint_api_key=\"my-api-key\",\n",
-    "    content_formatter=LlamaChatContentFormatter,\n",
+    "    content_formatter=CustomOpenAIChatContentFormatter,\n",
     "    model_kwargs={\"temperature\": 0.8},\n",
     ")"
    ]
@@ -204,7 +194,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.12"
+   "version": "3.9.1"
   }
  },
 "nbformat": 4,
diff --git a/docs/docs/integrations/llms/azure_ml.ipynb b/docs/docs/integrations/llms/azure_ml.ipynb
index bfee9ed3cb..7407560109 100644
--- a/docs/docs/integrations/llms/azure_ml.ipynb
+++ b/docs/docs/integrations/llms/azure_ml.ipynb
@@ -29,7 +29,7 @@
     "You must [deploy a model on Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) or [to Azure AI studio](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-open) and obtain the following parameters:\n",
     "\n",
     "* `endpoint_url`: The REST endpoint url provided by the endpoint.\n",
-    "* `endpoint_api_type`: Use `endpoint_type='realtime'` when deploying models to **Realtime endpoints** (hosted managed infrastructure). Use `endpoint_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
+    "* `endpoint_api_type`: Use `endpoint_type='dedicated'` when deploying models to **Dedicated endpoints** (hosted managed infrastructure). Use `endpoint_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
     "* `endpoint_api_key`: The API key provided by the endpoint.\n",
     "* `deployment_name`: (Optional) The deployment name of the model using the endpoint."
    ]
@@ -45,7 +45,7 @@
     "* `GPT2ContentFormatter`: Formats request and response data for GPT2\n",
     "* `DollyContentFormatter`: Formats request and response data for the Dolly-v2\n",
     "* `HFContentFormatter`: Formats request and response data for text-generation Hugging Face models\n",
-    "* `LLamaContentFormatter`: Formats request and response data for LLaMa2\n",
+    "* `CustomOpenAIContentFormatter`: Formats request and response data for models like LLaMa2 that follow an OpenAI-compatible API scheme.\n",
     "\n",
     "*Note: `OSSContentFormatter` is being deprecated and replaced with `GPT2ContentFormatter`. The logic is the same but `GPT2ContentFormatter` is a more suitable name. You can still continue to use `OSSContentFormatter` as the changes are backwards compatible.*"
    ]
@@ -72,15 +72,15 @@
    "source": [
     "from langchain_community.llms.azureml_endpoint import (\n",
     "    AzureMLEndpointApiType,\n",
-    "    LlamaContentFormatter,\n",
+    "    CustomOpenAIContentFormatter,\n",
     ")\n",
     "from langchain_core.messages import HumanMessage\n",
     "\n",
     "llm = AzureMLOnlineEndpoint(\n",
     "    endpoint_url=\"https://..inference.ml.azure.com/score\",\n",
-    "    endpoint_api_type=AzureMLEndpointApiType.realtime,\n",
+    "    endpoint_api_type=AzureMLEndpointApiType.dedicated,\n",
     "    endpoint_api_key=\"my-api-key\",\n",
-    "    content_formatter=LlamaContentFormatter(),\n",
+    "    content_formatter=CustomOpenAIContentFormatter(),\n",
     "    model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
     ")\n",
     "response = llm.invoke(\"Write me a song about sparkling water:\")\n",
@@ -119,7 +119,7 @@
    "source": [
     "from langchain_community.llms.azureml_endpoint import (\n",
     "    AzureMLEndpointApiType,\n",
-    "    LlamaContentFormatter,\n",
+    "    CustomOpenAIContentFormatter,\n",
     ")\n",
     "from langchain_core.messages import HumanMessage\n",
     "\n",
@@ -127,7 +127,7 @@
     "    endpoint_url=\"https://..inference.ml.azure.com/v1/completions\",\n",
     "    endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
     "    endpoint_api_key=\"my-api-key\",\n",
-    "    content_formatter=LlamaContentFormatter(),\n",
+    "    content_formatter=CustomOpenAIContentFormatter(),\n",
     "    model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
     ")\n",
     "response = llm.invoke(\"Write me a song about sparkling water:\")\n",
@@ -181,7 +181,7 @@
     "content_formatter = CustomFormatter()\n",
     "\n",
     "llm = AzureMLOnlineEndpoint(\n",
-    "    endpoint_api_type=\"realtime\",\n",
+    "    endpoint_api_type=\"dedicated\",\n",
     "    endpoint_api_key=os.getenv(\"BART_ENDPOINT_API_KEY\"),\n",
     "    endpoint_url=os.getenv(\"BART_ENDPOINT_URL\"),\n",
     "    model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
diff --git a/libs/community/langchain_community/chat_models/azureml_endpoint.py b/libs/community/langchain_community/chat_models/azureml_endpoint.py
index 7041759859..e2ea9d775c 100644
--- a/libs/community/langchain_community/chat_models/azureml_endpoint.py
+++ b/libs/community/langchain_community/chat_models/azureml_endpoint.py
@@ -1,16 +1,37 @@
 import json
-from typing import Any, Dict, List, Optional, cast
+import warnings
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Type,
+    cast,
+)
 
-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
+    BaseMessageChunk,
     ChatMessage,
+    ChatMessageChunk,
+    FunctionMessageChunk,
     HumanMessage,
+    HumanMessageChunk,
     SystemMessage,
+    SystemMessageChunk,
+    ToolMessageChunk,
 )
-from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 
 from langchain_community.llms.azureml_endpoint import (
     AzureMLBaseEndpoint,
@@ -25,12 +46,12 @@ class LlamaContentFormatter(ContentFormatterBase):
     def __init__(self) -> None:
         raise TypeError(
             "`LlamaContentFormatter` is deprecated for chat models. Use "
-            "`LlamaChatContentFormatter` instead."
+            "`CustomOpenAIChatContentFormatter` instead."
         )
 
 
-class LlamaChatContentFormatter(ContentFormatterBase):
-    """Content formatter for `LLaMA`."""
+class CustomOpenAIChatContentFormatter(ContentFormatterBase):
+    """Chat content formatter for models with an OpenAI-compatible API scheme."""
 
     SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
 
@@ -55,7 +76,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
             }
         elif (
             isinstance(message, ChatMessage)
-            and message.role in LlamaChatContentFormatter.SUPPORTED_ROLES
+            and message.role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES
         ):
             return {
                 "role": message.role,
@@ -63,7 +84,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
             }
         else:
             supported = ",".join(
-                [role for role in LlamaChatContentFormatter.SUPPORTED_ROLES]
+                [role for role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES]
            )
             raise ValueError(
                 f"""Received unsupported role.
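A quick sketch of what the renamed formatter's message mapping produces, for reviewers. The `ChatMessage` branch is visible in the hunk above; the assumption that the unchanged `HumanMessage`/`AIMessage`/`SystemMessage` branches map to `user`/`assistant`/`system` follows the existing module.

```python
from langchain_community.chat_models.azureml_endpoint import (
    CustomOpenAIChatContentFormatter,
)
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You are terse."),
    HumanMessage(content="Hi!"),
    AIMessage(content="Hello."),
]
# Each message is reduced to an OpenAI-style {"role", "content"} dict;
# roles outside SUPPORTED_ROLES raise a ValueError, per the hunk above.
print(
    [
        CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
        for m in messages
    ]
)
# [{'role': 'system', 'content': 'You are terse.'},
#  {'role': 'user', 'content': 'Hi!'},
#  {'role': 'assistant', 'content': 'Hello.'}]
```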
@@ -72,7 +93,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
 
     @property
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
-        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+        return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]
 
     def format_messages_request_payload(
         self,
@@ -82,10 +103,13 @@ class LlamaChatContentFormatter(ContentFormatterBase):
     ) -> bytes:
         """Formats the request according to the chosen api"""
         chat_messages = [
-            LlamaChatContentFormatter._convert_message_to_dict(message)
+            CustomOpenAIChatContentFormatter._convert_message_to_dict(message)
             for message in messages
         ]
-        if api_type == AzureMLEndpointApiType.realtime:
+        if api_type in [
+            AzureMLEndpointApiType.dedicated,
+            AzureMLEndpointApiType.realtime,
+        ]:
             request_payload = json.dumps(
                 {
                     "input_data": {
@@ -105,10 +129,13 @@ class LlamaChatContentFormatter(ContentFormatterBase):
     def format_response_payload(
         self,
         output: bytes,
-        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
     ) -> ChatGeneration:
         """Formats response"""
-        if api_type == AzureMLEndpointApiType.realtime:
+        if api_type in [
+            AzureMLEndpointApiType.dedicated,
+            AzureMLEndpointApiType.realtime,
+        ]:
             try:
                 choice = json.loads(output)["output"]
             except (KeyError, IndexError, TypeError) as e:
@@ -143,6 +170,20 @@ class LlamaChatContentFormatter(ContentFormatterBase):
         raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
 
 
+class LlamaChatContentFormatter(CustomOpenAIChatContentFormatter):
+    """Deprecated: Kept for backwards compatibility
+
+    Chat Content formatter for Llama."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        warnings.warn(
+            """`LlamaChatContentFormatter` will be deprecated in the future.
+                Please use `CustomOpenAIChatContentFormatter` instead.
+            """
+        )
+
+
 class MistralChatContentFormatter(LlamaChatContentFormatter):
     """Content formatter for `Mistral`."""
 
@@ -187,8 +228,8 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
     Example:
         .. code-block:: python
             azure_llm = AzureMLOnlineEndpoint(
-                endpoint_url="https://..inference.ml.azure.com/score",
-                endpoint_api_type=AzureMLApiType.realtime,
+                endpoint_url="https://..inference.ml.azure.com/v1/chat/completions",
+                endpoint_api_type=AzureMLApiType.serverless,
                 endpoint_api_key="my-api-key",
                 content_formatter=chat_content_formatter,
             )
@@ -239,3 +280,143 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
             response_payload, self.endpoint_api_type
         )
         return ChatResult(generations=[generations])
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        self.endpoint_url = self.endpoint_url.replace("/chat/completions", "")
+        timeout = None if "timeout" not in kwargs else kwargs["timeout"]
+
+        import openai
+
+        params = {}
+        client_params = {
+            "api_key": self.endpoint_api_key.get_secret_value(),
+            "base_url": self.endpoint_url,
+            "timeout": timeout,
+            "default_headers": None,
+            "default_query": None,
+            "http_client": None,
+        }
+
+        client = openai.OpenAI(**client_params)
+        message_dicts = [
+            CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
+            for m in messages
+        ]
+        params = {"stream": True, "stop": stop, "model": None, **kwargs}
+
+        default_chunk_class = AIMessageChunk
+        for chunk in client.chat.completions.create(messages=message_dicts, **params):
+            if not isinstance(chunk, dict):
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
+                continue
+            choice = chunk["choices"][0]
+            chunk = _convert_delta_to_message_chunk(
+                choice["delta"], default_chunk_class
+            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
+            default_chunk_class = chunk.__class__
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
+            yield chunk
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        self.endpoint_url = self.endpoint_url.replace("/chat/completions", "")
+        timeout = None if "timeout" not in kwargs else kwargs["timeout"]
+
+        import openai
+
+        params = {}
+        client_params = {
+            "api_key": self.endpoint_api_key.get_secret_value(),
+            "base_url": self.endpoint_url,
+            "timeout": timeout,
+            "default_headers": None,
+            "default_query": None,
+            "http_client": None,
+        }
+
+        async_client = openai.AsyncOpenAI(**client_params)
+        message_dicts = [
+            CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
+            for m in messages
+        ]
+        params = {"stream": True, "stop": stop, "model": None, **kwargs}
+
+        default_chunk_class = AIMessageChunk
+        async for chunk in await async_client.chat.completions.create(
+            messages=message_dicts, **params
+        ):
+            if not isinstance(chunk, dict):
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
+                continue
+            choice = chunk["choices"][0]
+            chunk = _convert_delta_to_message_chunk(
+                choice["delta"], default_chunk_class
+            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
+            default_chunk_class = chunk.__class__
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    token=chunk.text, chunk=chunk, logprobs=logprobs
+                )
+            yield chunk
+
+
+def _convert_delta_to_message_chunk(
+    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    role = cast(str, _dict.get("role"))
+    content = cast(str, _dict.get("content") or "")
+    additional_kwargs: Dict = {}
+    if _dict.get("function_call"):
+        function_call = dict(_dict["function_call"])
+        if "name" in function_call and function_call["name"] is None:
+            function_call["name"] = ""
+        additional_kwargs["function_call"] = function_call
+    if _dict.get("tool_calls"):
+        additional_kwargs["tool_calls"] = _dict["tool_calls"]
+
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+    elif role == "system" or default_class == SystemMessageChunk:
+        return SystemMessageChunk(content=content)
+    elif role == "function" or default_class == FunctionMessageChunk:
+        return FunctionMessageChunk(content=content, name=_dict["name"])
+    elif role == "tool" or default_class == ToolMessageChunk:
+        return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role)
+    else:
+        return default_class(content=content)
diff --git a/libs/community/langchain_community/llms/azureml_endpoint.py b/libs/community/langchain_community/llms/azureml_endpoint.py
index da8ef5a053..8b9dcc43fe 100644
--- a/libs/community/langchain_community/llms/azureml_endpoint.py
+++ b/libs/community/langchain_community/llms/azureml_endpoint.py
@@ -62,12 +62,14 @@ class AzureMLEndpointClient(object):
 
 
 class AzureMLEndpointApiType(str, Enum):
-    """Azure ML endpoints API types. Use `realtime` for models deployed in hosted
-    infrastructure, or `serverless` for models deployed as a service with a
+    """Azure ML endpoints API types. Use `dedicated` for models deployed in hosted
+    infrastructure (also known as Online Endpoints in Azure Machine Learning),
+    or `serverless` for models deployed as a service with a
     pay-as-you-go billing or PTU.
     """
 
-    realtime = "realtime"
+    dedicated = "dedicated"
+    realtime = "realtime"  #: Deprecated
     serverless = "serverless"
 
 
@@ -141,13 +143,13 @@ class ContentFormatterBase:
         deploying models using different hosting methods. Each method may
         have a different API structure."""
 
-        return [AzureMLEndpointApiType.realtime]
+        return [AzureMLEndpointApiType.dedicated]
 
     def format_request_payload(
         self,
         prompt: str,
         model_kwargs: Dict,
-        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
     ) -> Any:
         """Formats the request body according to the input schema of
         the model. Returns bytes or seekable file like object in the
@@ -159,7 +161,7 @@ class ContentFormatterBase:
     def format_response_payload(
         self,
         output: bytes,
-        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
     ) -> Generation:
         """Formats the response body according to the output schema of
         the model. Returns the data type that is
@@ -172,7 +174,7 @@ class GPT2ContentFormatter(ContentFormatterBase):
 
     @property
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
-        return [AzureMLEndpointApiType.realtime]
+        return [AzureMLEndpointApiType.dedicated]
 
     def format_request_payload(  # type: ignore[override]
         self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
@@ -214,7 +216,7 @@ class HFContentFormatter(ContentFormatterBase):
 
     @property
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
-        return [AzureMLEndpointApiType.realtime]
+        return [AzureMLEndpointApiType.dedicated]
 
     def format_request_payload(  # type: ignore[override]
         self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
@@ -240,7 +242,7 @@ class DollyContentFormatter(ContentFormatterBase):
 
     @property
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
-        return [AzureMLEndpointApiType.realtime]
+        return [AzureMLEndpointApiType.dedicated]
 
     def format_request_payload(  # type: ignore[override]
         self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
@@ -264,19 +266,22 @@ class DollyContentFormatter(ContentFormatterBase):
         return Generation(text=choice)
 
 
-class LlamaContentFormatter(ContentFormatterBase):
-    """Content formatter for LLaMa"""
+class CustomOpenAIContentFormatter(ContentFormatterBase):
+    """Content formatter for models that use an OpenAI-like API scheme."""
 
     @property
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
-        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+        return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]
 
     def format_request_payload(  # type: ignore[override]
         self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
     ) -> bytes:
         """Formats the request according to the chosen api"""
         prompt = ContentFormatterBase.escape_special_characters(prompt)
-        if api_type == AzureMLEndpointApiType.realtime:
+        if api_type in [
+            AzureMLEndpointApiType.dedicated,
+            AzureMLEndpointApiType.realtime,
+        ]:
             request_payload = json.dumps(
                 {
                     "input_data": {
@@ -297,7 +302,10 @@ class LlamaContentFormatter(ContentFormatterBase):
         self, output: bytes, api_type: AzureMLEndpointApiType
     ) -> Generation:
         """Formats response"""
-        if api_type == AzureMLEndpointApiType.realtime:
+        if api_type in [
+            AzureMLEndpointApiType.dedicated,
+            AzureMLEndpointApiType.realtime,
+        ]:
             try:
                 choice = json.loads(output)[0]["0"]
             except (KeyError, IndexError, TypeError) as e:
@@ -324,6 +332,22 @@ class LlamaContentFormatter(ContentFormatterBase):
         raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
 
 
+class LlamaContentFormatter(CustomOpenAIContentFormatter):
+    """Deprecated: Kept for backwards compatibility
+
+    Content formatter for Llama."""
+
+    content_formatter: Any = None
+
+    def __init__(self) -> None:
+        super().__init__()
+        warnings.warn(
+            """`LlamaContentFormatter` will be deprecated in the future.
+                Please use `CustomOpenAIContentFormatter` instead.
+            """
+        )
+
+
 class AzureMLBaseEndpoint(BaseModel):
     """Azure ML Online Endpoint models."""
 
@@ -331,9 +355,9 @@ class AzureMLBaseEndpoint(BaseModel):
     """URL of pre-existing Endpoint. Should be passed to constructor or specified as
         env var `AZUREML_ENDPOINT_URL`."""
 
-    endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime
+    endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated
     """Type of the endpoint being consumed. Possible values are `serverless` for
-        pay-as-you-go and `realtime` for real-time endpoints. """
+        pay-as-you-go and `dedicated` for dedicated endpoints. """
 
     endpoint_api_key: SecretStr = convert_to_secret_str("")
     """Authentication Key for Endpoint. Should be passed to constructor or specified as
@@ -348,6 +372,8 @@ class AzureMLBaseEndpoint(BaseModel):
 
     http_client: Any = None  #: :meta private:
 
+    max_retries: int = 1
+
     content_formatter: Any = None
     """The content formatter that provides an input and output
     transform function to handle formats between the LLM and
@@ -371,7 +397,7 @@ class AzureMLBaseEndpoint(BaseModel):
             values,
             "endpoint_api_type",
             "AZUREML_ENDPOINT_API_TYPE",
-            AzureMLEndpointApiType.realtime,
+            AzureMLEndpointApiType.dedicated,
         )
         values["timeout"] = get_from_dict_or_env(
             values,
@@ -404,7 +430,7 @@ class AzureMLBaseEndpoint(BaseModel):
         if field_value.endswith("inference.ml.azure.com"):
             raise ValueError(
                 "`endpoint_url` should contain the full invocation URL including "
-                "`/score` for `endpoint_api_type='realtime'` or `/v1/completions` "
+                "`/score` for `endpoint_api_type='dedicated'` or `/v1/completions` "
                 "or `/v1/chat/completions` for `endpoint_api_type='serverless'`"
             )
         return field_value
@@ -415,11 +441,15 @@ class AzureMLBaseEndpoint(BaseModel):
     ) -> AzureMLEndpointApiType:
         """Validate that endpoint api type is compatible with the URL format."""
         endpoint_url = values.get("endpoint_url")
-        if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith(  # type: ignore[union-attr]
-            "/score"
+        if (
+            (
+                field_value == AzureMLEndpointApiType.dedicated
+                or field_value == AzureMLEndpointApiType.realtime
+            )
+            and not endpoint_url.endswith("/score")  # type: ignore[union-attr]
         ):
             raise ValueError(
-                "Endpoints of type `realtime` should follow the format "
+                "Endpoints of type `dedicated` should follow the format "
                 "`https://..inference.ml.azure.com/score`."
                 " If your endpoint URL ends with `/v1/completions` or"
                 "`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
@@ -461,7 +491,7 @@ class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
         .. code-block:: python
             azure_llm = AzureMLOnlineEndpoint(
                 endpoint_url="https://..inference.ml.azure.com/score",
-                endpoint_api_type=AzureMLApiType.realtime,
+                endpoint_api_type=AzureMLApiType.dedicated,
                 endpoint_api_key="my-api-key",
                 timeout=120,
                 content_formatter=content_formatter,
diff --git a/libs/community/tests/integration_tests/chat_models/test_azureml_endpoint.py b/libs/community/tests/integration_tests/chat_models/test_azureml_endpoint.py
index 31092d625b..6840f6dda2 100644
--- a/libs/community/tests/integration_tests/chat_models/test_azureml_endpoint.py
+++ b/libs/community/tests/integration_tests/chat_models/test_azureml_endpoint.py
@@ -5,13 +5,15 @@ from langchain_core.outputs import ChatGeneration, LLMResult
 
 from langchain_community.chat_models.azureml_endpoint import (
     AzureMLChatOnlineEndpoint,
-    LlamaChatContentFormatter,
+    CustomOpenAIChatContentFormatter,
 )
 
 
 def test_llama_call() -> None:
     """Test valid call to Open Source Foundation Model."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    chat = AzureMLChatOnlineEndpoint(
+        content_formatter=CustomOpenAIChatContentFormatter()
+    )
     response = chat.invoke([HumanMessage(content="Foo")])
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)
@@ -19,7 +21,9 @@ def test_llama_call() -> None:
 
 def test_temperature_kwargs() -> None:
     """Test that timeout kwarg works."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    chat = AzureMLChatOnlineEndpoint(
+        content_formatter=CustomOpenAIChatContentFormatter()
+    )
     response = chat.invoke([HumanMessage(content="FOO")], temperature=0.8)
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)
@@ -27,7 +31,9 @@ def test_temperature_kwargs() -> None:
 
 def test_message_history() -> None:
     """Test that multiple messages works."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    chat = AzureMLChatOnlineEndpoint(
+        content_formatter=CustomOpenAIChatContentFormatter()
+    )
     response = chat.invoke(
         [
             HumanMessage(content="Hello."),
@@ -40,7 +46,9 @@ def test_message_history() -> None:
 
 def test_multiple_messages() -> None:
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    chat = AzureMLChatOnlineEndpoint(
+        content_formatter=CustomOpenAIChatContentFormatter()
+    )
     message = HumanMessage(content="Hi!")
     response = chat.generate([[message], [message]])
diff --git a/libs/langchain/langchain/llms/azureml_endpoint.py b/libs/langchain/langchain/llms/azureml_endpoint.py
index 931ea2d96b..ce9cd07cf6 100644
--- a/libs/langchain/langchain/llms/azureml_endpoint.py
+++ b/libs/langchain/langchain/llms/azureml_endpoint.py
@@ -2,10 +2,10 @@ from langchain_community.llms.azureml_endpoint import (
     AzureMLEndpointClient,
     AzureMLOnlineEndpoint,
     ContentFormatterBase,
+    CustomOpenAIContentFormatter,
     DollyContentFormatter,
     GPT2ContentFormatter,
     HFContentFormatter,
-    LlamaContentFormatter,
     OSSContentFormatter,
 )
 
@@ -16,6 +16,6 @@ __all__ = [
     "OSSContentFormatter",
     "HFContentFormatter",
     "DollyContentFormatter",
-    "LlamaContentFormatter",
+    "CustomOpenAIContentFormatter",
     "AzureMLOnlineEndpoint",
 ]
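A hedged migration sketch for downstream code, assuming the renames above land as shown. Per this diff, `LlamaContentFormatter` remains importable from `langchain_community.llms.azureml_endpoint` as a warning shim, but it is dropped from the `langchain.llms.azureml_endpoint` re-export, so code importing it from there must switch names. The endpoint URL and key below are placeholders.

```python
# Old-style usage keeps working through the community shim, but now warns:
#   from langchain_community.llms.azureml_endpoint import LlamaContentFormatter
#   formatter = LlamaContentFormatter()  # emits a deprecation warning

from langchain_community.llms.azureml_endpoint import (
    AzureMLEndpointApiType,
    AzureMLOnlineEndpoint,
    CustomOpenAIContentFormatter,
)

llm = AzureMLOnlineEndpoint(
    # Placeholder URL; dedicated endpoints must end in /score per the validator.
    endpoint_url="https://<endpoint>.<region>.inference.ml.azure.com/score",
    endpoint_api_type=AzureMLEndpointApiType.dedicated,
    endpoint_api_key="my-api-key",  # placeholder key
    content_formatter=CustomOpenAIContentFormatter(),
)
```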