mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
50186da0a1
Updating #21137
423 lines
15 KiB
Python
423 lines
15 KiB
Python
import json
|
|
import warnings
|
|
from typing import (
|
|
Any,
|
|
AsyncIterator,
|
|
Dict,
|
|
Iterator,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Type,
|
|
cast,
|
|
)
|
|
|
|
from langchain_core.callbacks import (
|
|
AsyncCallbackManagerForLLMRun,
|
|
CallbackManagerForLLMRun,
|
|
)
|
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
AIMessageChunk,
|
|
BaseMessage,
|
|
BaseMessageChunk,
|
|
ChatMessage,
|
|
ChatMessageChunk,
|
|
FunctionMessageChunk,
|
|
HumanMessage,
|
|
HumanMessageChunk,
|
|
SystemMessage,
|
|
SystemMessageChunk,
|
|
ToolMessageChunk,
|
|
)
|
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
|
|
from langchain_community.llms.azureml_endpoint import (
|
|
AzureMLBaseEndpoint,
|
|
AzureMLEndpointApiType,
|
|
ContentFormatterBase,
|
|
)
|
|
|
|
|
|
class LlamaContentFormatter(ContentFormatterBase):
    """Content formatter for `LLaMA`.

    Deprecated for chat models: constructing an instance always raises
    ``TypeError``. Use `CustomOpenAIChatContentFormatter` with
    `AzureMLChatOnlineEndpoint` instead.
    """

    def __init__(self) -> None:
        # Fail fast and point chat users at the chat formatter defined in this
        # module, which implements the chat request/response payload methods.
        raise TypeError(
            "`LlamaContentFormatter` is deprecated for chat models. Use "
            "`CustomOpenAIChatContentFormatter` instead."
        )
|
|
|
|
|
|
class CustomOpenAIChatContentFormatter(ContentFormatterBase):
    """Chat Content formatter for models with OpenAI like API scheme.

    Converts LangChain messages to ``{"role": ..., "content": ...}`` dicts and
    builds/parses payloads for ``dedicated`` (and the legacy ``realtime``)
    endpoints as well as ``serverless`` endpoints.
    """

    SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> Dict:
        """Convert a message to a dict according to its role.

        Args:
            message: The LangChain message to convert.

        Returns:
            A dict with a ``role`` key and a ``content`` key whose text has been
            passed through ``ContentFormatterBase.escape_special_characters``.

        Raises:
            ValueError: If the message is not a human/AI/system message, or is a
                ``ChatMessage`` whose role is not in ``SUPPORTED_ROLES``.
        """
        content = cast(str, message.content)
        if isinstance(message, HumanMessage):
            role = "user"
        elif isinstance(message, AIMessage):
            role = "assistant"
        elif isinstance(message, SystemMessage):
            role = "system"
        elif (
            isinstance(message, ChatMessage)
            and message.role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES
        ):
            role = message.role
        else:
            supported = ",".join(CustomOpenAIChatContentFormatter.SUPPORTED_ROLES)
            raise ValueError(
                "Received unsupported role.\n"
                f"Supported roles for the LLaMa Foundation Model: {supported}"
            )
        return {
            "role": role,
            "content": ContentFormatterBase.escape_special_characters(content),
        }

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]

    def format_messages_request_payload(
        self,
        messages: List[BaseMessage],
        model_kwargs: Dict,
        api_type: AzureMLEndpointApiType,
    ) -> bytes:
        """Formats the request according to the chosen api.

        Args:
            messages: Conversation to send.
            model_kwargs: Model parameters. They are nested under
                ``input_data.parameters`` for dedicated endpoints and merged
                top-level for serverless ones.
            api_type: The endpoint API type.

        Returns:
            The JSON request body encoded with ``str.encode`` (UTF-8).

        Raises:
            ValueError: If ``api_type`` is not supported or a message role is
                unsupported.
        """
        chat_messages = [
            CustomOpenAIChatContentFormatter._convert_message_to_dict(message)
            for message in messages
        ]
        if api_type in [
            AzureMLEndpointApiType.dedicated,
            AzureMLEndpointApiType.realtime,
        ]:
            request_payload = json.dumps(
                {
                    "input_data": {
                        "input_string": chat_messages,
                        "parameters": model_kwargs,
                    }
                }
            )
        elif api_type == AzureMLEndpointApiType.serverless:
            request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
        else:
            raise ValueError(
                f"`api_type` {api_type} is not supported by this formatter"
            )
        return str.encode(request_payload)

    def format_response_payload(
        self,
        output: bytes,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
    ) -> ChatGeneration:
        """Parse an endpoint response body into a ``ChatGeneration``.

        Args:
            output: Raw response body returned by the endpoint.
            api_type: The endpoint API type that produced ``output``.

        Returns:
            A ``ChatGeneration``. Its message is an ``AIMessage`` for dedicated
            endpoints, and for serverless replies with the ``assistant`` role.
            For any other serverless role it is a ``BaseMessage`` typed with
            that role.

        Raises:
            ValueError: If the response is malformed for ``api_type`` or
                ``api_type`` is not supported.
        """
        if api_type in [
            AzureMLEndpointApiType.dedicated,
            AzureMLEndpointApiType.realtime,
        ]:
            try:
                choice = json.loads(output)["output"]
                if not isinstance(choice, str):
                    # Without this check ``.strip()`` below would escape as an
                    # AttributeError instead of the formatter's ValueError.
                    raise TypeError(
                        "Endpoint response is not well formed for a chat "
                        f"model. Expected `str` but `{type(choice)}` was received."
                    )
            except (KeyError, IndexError, TypeError) as e:
                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
            return ChatGeneration(
                message=AIMessage(content=choice.strip()),
                generation_info=None,
            )

        if api_type == AzureMLEndpointApiType.serverless:
            try:
                choice = json.loads(output)["choices"][0]
                if not isinstance(choice, dict):
                    raise TypeError(
                        "Endpoint response is not well formed for a chat "
                        f"model. Expected `dict` but `{type(choice)}` was received."
                    )
                # Extract inside the try so a missing/null field is reported
                # through the same ValueError as any other malformed response.
                content = choice["message"]["content"].strip()
                role = choice["message"]["role"]
            except (KeyError, IndexError, TypeError, AttributeError) as e:
                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
            message: BaseMessage = (
                AIMessage(content=content)
                if role == "assistant"
                else BaseMessage(content=content, type=role)
            )
            return ChatGeneration(
                message=message,
                generation_info=dict(
                    finish_reason=choice.get("finish_reason"),
                    logprobs=choice.get("logprobs"),
                ),
            )

        raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
|
|
|
|
|
|
class LlamaChatContentFormatter(CustomOpenAIChatContentFormatter):
    """Deprecated: Kept for backwards compatibility

    Chat Content formatter for Llama. It behaves like
    `CustomOpenAIChatContentFormatter` but emits a ``UserWarning`` (the
    ``warnings.warn`` default category) when constructed.
    """

    def __init__(self) -> None:
        super().__init__()
        # stacklevel=2 attributes the warning to the code constructing the
        # formatter rather than to this line inside the library.
        warnings.warn(
            "`LlamaChatContentFormatter` will be deprecated in the future.\n"
            "Please use `CustomOpenAIChatContentFormatter` instead.\n",
            stacklevel=2,
        )
|
|
|
|
|
|
class MistralChatContentFormatter(LlamaChatContentFormatter):
    """Content formatter for `Mistral`.

    Mistral OSS models do not explicitly support system prompts. A leading
    system message is therefore merged into the message that follows it, or
    sent with the ``user`` role when it is the only message.
    """

    def format_messages_request_payload(
        self,
        messages: List[BaseMessage],
        model_kwargs: Dict,
        api_type: AzureMLEndpointApiType,
    ) -> bytes:
        """Formats the request according to the chosen api.

        Args:
            messages: Conversation to send.
            model_kwargs: Model parameters. They are nested under
                ``input_data.parameters`` for dedicated endpoints and merged
                top-level for serverless ones.
            api_type: The endpoint API type. ``dedicated`` and its legacy alias
                ``realtime`` are both accepted, matching the inherited
                ``supported_api_types``.

        Returns:
            The JSON request body encoded with ``str.encode`` (UTF-8).

        Raises:
            ValueError: If ``api_type`` is not supported or a message role is
                unsupported.
        """
        chat_messages = [self._convert_message_to_dict(message) for message in messages]

        if chat_messages and chat_messages[0]["role"] == "system":
            # Mistral OSS models do not explicitly support system prompts, so we
            # have to stash the system prompt in the following message.
            if len(chat_messages) > 1:
                chat_messages[1]["content"] = (
                    chat_messages[0]["content"] + "\n\n" + chat_messages[1]["content"]
                )
                del chat_messages[0]
            else:
                # Nothing to merge into: send the system prompt as a user turn
                # instead of failing with IndexError.
                chat_messages[0]["role"] = "user"

        if api_type in [
            AzureMLEndpointApiType.dedicated,
            AzureMLEndpointApiType.realtime,
        ]:
            request_payload = json.dumps(
                {
                    "input_data": {
                        "input_string": chat_messages,
                        "parameters": model_kwargs,
                    }
                }
            )
        elif api_type == AzureMLEndpointApiType.serverless:
            request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
        else:
            raise ValueError(
                f"`api_type` {api_type} is not supported by this formatter"
            )
        return str.encode(request_payload)
|
|
|
|
|
|
class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
    """Azure ML Online Endpoint chat models.

    Example:
        .. code-block:: python

            azure_llm = AzureMLChatOnlineEndpoint(
                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",
                endpoint_api_type=AzureMLEndpointApiType.serverless,
                endpoint_api_key="my-api-key",
                content_formatter=chat_content_formatter,
            )
    """

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"model_kwargs": self.model_kwargs or {}}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "azureml_chat_endpoint"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call out to an AzureML Managed Online endpoint.

        Args:
            messages: The messages in the conversation with the chat model.
            stop: Optional list of stop words to use when generating.
            run_manager: Callback manager forwarded to the HTTP client.
            **kwargs: Extra model parameters. They override ``model_kwargs``
                for this call only.

        Returns:
            A ``ChatResult`` holding the single generation parsed by
            ``content_formatter``.

        Example:
            .. code-block:: python

                response = azureml_model.invoke("Tell me a joke.")
        """
        # Build a fresh dict. Updating ``self.model_kwargs`` in place would leak
        # this call's kwargs and stop words into every subsequent call.
        model_kwargs: Dict[str, Any] = {**(self.model_kwargs or {}), **kwargs}
        if stop:
            model_kwargs["stop"] = stop

        request_payload = self.content_formatter.format_messages_request_payload(
            messages, model_kwargs, self.endpoint_api_type
        )
        response_payload = self.http_client.call(
            body=request_payload, run_manager=run_manager
        )
        generation = self.content_formatter.format_response_payload(
            response_payload, self.endpoint_api_type
        )
        return ChatResult(generations=[generation])

    def _openai_client_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Build constructor kwargs for an ``openai`` client aimed at this endpoint.

        The ``/chat/completions`` suffix is stripped from a local copy of
        ``endpoint_url`` because the client's ``chat.completions.create`` adds
        that path itself. ``self.endpoint_url`` is left untouched.
        """
        return {
            "api_key": self.endpoint_api_key.get_secret_value(),
            "base_url": self.endpoint_url.replace("/chat/completions", ""),
            "timeout": kwargs.get("timeout"),
            "default_headers": None,
            "default_query": None,
            "http_client": None,
        }

    @staticmethod
    def _first_choice(raw_chunk: Any) -> Optional[Mapping[str, Any]]:
        """Return the first choice of a streamed chunk, or ``None`` if it has none."""
        chunk_dict = raw_chunk if isinstance(raw_chunk, dict) else raw_chunk.dict()
        choices = chunk_dict["choices"]
        return choices[0] if choices else None

    @staticmethod
    def _to_generation_chunk(
        choice: Mapping[str, Any], default_chunk_class: Type[BaseMessageChunk]
    ) -> ChatGenerationChunk:
        """Convert one streamed choice into a ``ChatGenerationChunk``.

        ``finish_reason`` and ``logprobs`` are copied into ``generation_info``
        only when truthy. ``generation_info`` is ``None`` when neither is set.
        """
        message_chunk = _convert_delta_to_message_chunk(
            choice["delta"], default_chunk_class
        )
        generation_info: Dict[str, Any] = {}
        if finish_reason := choice.get("finish_reason"):
            generation_info["finish_reason"] = finish_reason
        if logprobs := choice.get("logprobs"):
            generation_info["logprobs"] = logprobs
        return ChatGenerationChunk(
            message=message_chunk, generation_info=generation_info or None
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream chat chunks through the endpoint's OpenAI-compatible API.

        Requires the ``openai`` package. ``kwargs`` (including ``timeout``) are
        forwarded to ``chat.completions.create``. ``model_kwargs`` is not
        applied on this path.
        """
        import openai

        client = openai.OpenAI(**self._openai_client_params(kwargs))
        message_dicts = [
            CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
            for m in messages
        ]
        params = {"stream": True, "stop": stop, "model": None, **kwargs}

        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        for raw_chunk in client.chat.completions.create(
            messages=message_dicts, **params
        ):
            choice = self._first_choice(raw_chunk)
            if choice is None:
                continue
            chunk = self._to_generation_chunk(choice, default_chunk_class)
            # Deltas after the first may omit "role"; fall back to the class
            # produced so far.
            default_chunk_class = chunk.message.__class__
            if run_manager:
                run_manager.on_llm_new_token(
                    chunk.text, chunk=chunk, logprobs=choice.get("logprobs")
                )
            yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async counterpart of ``_stream`` using ``openai.AsyncOpenAI``."""
        import openai

        async_client = openai.AsyncOpenAI(**self._openai_client_params(kwargs))
        message_dicts = [
            CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
            for m in messages
        ]
        params = {"stream": True, "stop": stop, "model": None, **kwargs}

        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        async for raw_chunk in await async_client.chat.completions.create(
            messages=message_dicts, **params
        ):
            choice = self._first_choice(raw_chunk)
            if choice is None:
                continue
            chunk = self._to_generation_chunk(choice, default_chunk_class)
            # Deltas after the first may omit "role"; fall back to the class
            # produced so far.
            default_chunk_class = chunk.message.__class__
            if run_manager:
                await run_manager.on_llm_new_token(
                    token=chunk.text, chunk=chunk, logprobs=choice.get("logprobs")
                )
            yield chunk
|
|
|
|
|
|
def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Build a message chunk from one streamed ``delta`` mapping.

    Candidate chunk classes are tried in this order: human, AI, system,
    function, tool, then generic chat. A candidate wins if the delta's ``role``
    names it *or* it equals ``default_class``. An earlier match on
    ``default_class`` therefore takes precedence over a role that names a
    later candidate. If nothing matches, ``default_class`` is instantiated
    directly.

    Args:
        _dict: The ``delta`` payload of one streamed choice.
        default_class: Chunk class to fall back on. The streaming callers in
            this module pass the class of the previously produced chunk.

    Returns:
        The message chunk for this delta. ``function_call`` and ``tool_calls``
        from the delta are attached only to an ``AIMessageChunk``.
    """
    role = cast(str, _dict.get("role"))
    content = cast(str, _dict.get("content") or "")

    extra_kwargs: Dict = {}
    if _dict.get("function_call"):
        call = dict(_dict["function_call"])
        # Replace an explicit null function name with an empty string.
        if call.get("name", "") is None:
            call["name"] = ""
        extra_kwargs["function_call"] = call
    if _dict.get("tool_calls"):
        extra_kwargs["tool_calls"] = _dict["tool_calls"]

    def picks(expected_role: str, chunk_class: Type[BaseMessageChunk]) -> bool:
        return role == expected_role or default_class == chunk_class

    if picks("user", HumanMessageChunk):
        return HumanMessageChunk(content=content)
    if picks("assistant", AIMessageChunk):
        return AIMessageChunk(content=content, additional_kwargs=extra_kwargs)
    if picks("system", SystemMessageChunk):
        return SystemMessageChunk(content=content)
    if picks("function", FunctionMessageChunk):
        return FunctionMessageChunk(content=content, name=_dict["name"])
    if picks("tool", ToolMessageChunk):
        return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
    if role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    return default_class(content=content)  # type: ignore[call-arg]
|