|
|
|
@ -25,12 +25,12 @@ from langchain_core.messages.function import FunctionMessage, FunctionMessageChu
|
|
|
|
|
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
|
|
|
|
|
from langchain_core.messages.system import SystemMessage, SystemMessageChunk
|
|
|
|
|
from langchain_core.messages.tool import ToolMessage, ToolMessageChunk
|
|
|
|
|
from langchain_core.runnables import Runnable, RunnableLambda
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from langchain_text_splitters import TextSplitter
|
|
|
|
|
|
|
|
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
|
|
|
from langchain_core.runnables.base import Runnable
|
|
|
|
|
|
|
|
|
|
AnyMessage = Union[
|
|
|
|
|
AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage
|
|
|
|
@ -279,6 +279,8 @@ def _runnable_support(func: Callable) -> Callable:
|
|
|
|
|
List[BaseMessage],
|
|
|
|
|
Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]],
|
|
|
|
|
]:
|
|
|
|
|
from langchain_core.runnables.base import RunnableLambda
|
|
|
|
|
|
|
|
|
|
if messages is not None:
|
|
|
|
|
return func(messages, **kwargs)
|
|
|
|
|
else:
|
|
|
|
@ -486,9 +488,7 @@ def trim_messages(
|
|
|
|
|
] = None,
|
|
|
|
|
include_system: bool = False,
|
|
|
|
|
text_splitter: Optional[Union[Callable[[str], List[str]], TextSplitter]] = None,
|
|
|
|
|
) -> Union[
|
|
|
|
|
List[BaseMessage], Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]]
|
|
|
|
|
]:
|
|
|
|
|
) -> List[BaseMessage]:
|
|
|
|
|
"""Trim messages to be below a token count.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
@ -734,53 +734,6 @@ def trim_messages(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" # noqa: E501
|
|
|
|
|
if messages is not None:
|
|
|
|
|
return _trim_messages_helper(
|
|
|
|
|
messages,
|
|
|
|
|
max_tokens=max_tokens,
|
|
|
|
|
token_counter=token_counter,
|
|
|
|
|
strategy=strategy,
|
|
|
|
|
allow_partial=allow_partial,
|
|
|
|
|
end_on=end_on,
|
|
|
|
|
start_on=start_on,
|
|
|
|
|
include_system=include_system,
|
|
|
|
|
text_splitter=text_splitter,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
trimmer = partial(
|
|
|
|
|
_trim_messages_helper,
|
|
|
|
|
max_tokens=max_tokens,
|
|
|
|
|
token_counter=token_counter,
|
|
|
|
|
strategy=strategy,
|
|
|
|
|
allow_partial=allow_partial,
|
|
|
|
|
end_on=end_on,
|
|
|
|
|
start_on=start_on,
|
|
|
|
|
include_system=include_system,
|
|
|
|
|
text_splitter=text_splitter,
|
|
|
|
|
)
|
|
|
|
|
return RunnableLambda(trimmer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _trim_messages_helper(
|
|
|
|
|
messages: Sequence[MessageLikeRepresentation],
|
|
|
|
|
*,
|
|
|
|
|
max_tokens: int,
|
|
|
|
|
token_counter: Union[
|
|
|
|
|
Callable[[List[BaseMessage]], int],
|
|
|
|
|
Callable[[BaseMessage], int],
|
|
|
|
|
BaseLanguageModel,
|
|
|
|
|
],
|
|
|
|
|
strategy: Literal["first", "last"] = "last",
|
|
|
|
|
allow_partial: bool = False,
|
|
|
|
|
end_on: Optional[
|
|
|
|
|
Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
|
|
|
|
|
] = None,
|
|
|
|
|
start_on: Optional[
|
|
|
|
|
Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
|
|
|
|
|
] = None,
|
|
|
|
|
include_system: bool = False,
|
|
|
|
|
text_splitter: Optional[Union[Callable[[str], List[str]], TextSplitter]] = None,
|
|
|
|
|
) -> List[BaseMessage]:
|
|
|
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
|
|
|
|
|
|
|
|
if start_on and strategy == "first":
|
|
|
|
|