|
|
|
@ -1,4 +1,4 @@
|
|
|
|
|
from typing import Any, Dict, List, Literal
|
|
|
|
|
from typing import Any, Dict, List, Literal, Union
|
|
|
|
|
|
|
|
|
|
from langchain_core.messages.base import (
|
|
|
|
|
BaseMessage,
|
|
|
|
@ -69,6 +69,37 @@ class AIMessage(BaseMessage):
|
|
|
|
|
pass
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
def pretty_repr(self, html: bool = False) -> str:
|
|
|
|
|
"""Return a pretty representation of the message."""
|
|
|
|
|
base = super().pretty_repr(html=html)
|
|
|
|
|
lines = []
|
|
|
|
|
|
|
|
|
|
def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> List[str]:
|
|
|
|
|
lines = [
|
|
|
|
|
f" {tc.get('name', 'Tool')} ({tc.get('id')})",
|
|
|
|
|
f" Call ID: {tc.get('id')}",
|
|
|
|
|
]
|
|
|
|
|
if tc.get("error"):
|
|
|
|
|
lines.append(f" Error: {tc.get('error')}")
|
|
|
|
|
lines.append(" Args:")
|
|
|
|
|
args = tc.get("args")
|
|
|
|
|
if isinstance(args, str):
|
|
|
|
|
lines.append(f" {args}")
|
|
|
|
|
elif isinstance(args, dict):
|
|
|
|
|
for arg, value in args.items():
|
|
|
|
|
lines.append(f" {arg}: {value}")
|
|
|
|
|
return lines
|
|
|
|
|
|
|
|
|
|
if self.tool_calls:
|
|
|
|
|
lines.append("Tool Calls:")
|
|
|
|
|
for tc in self.tool_calls:
|
|
|
|
|
lines.extend(_format_tool_args(tc))
|
|
|
|
|
if self.invalid_tool_calls:
|
|
|
|
|
lines.append("Invalid Tool Calls:")
|
|
|
|
|
for itc in self.invalid_tool_calls:
|
|
|
|
|
lines.extend(_format_tool_args(itc))
|
|
|
|
|
return (base.strip() + "\n" + "\n".join(lines)).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AIMessage.update_forward_refs()
|
|
|
|
|
|
|
|
|
|