core[patch]: Allow bound models as token_counter in trim_messages (#25563)

This commit is contained in:
Bagatur 2024-08-20 00:21:22 -07:00 committed by GitHub
parent e01c6789c4
commit 4bd005adb6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 10 deletions

View File

@ -514,6 +514,8 @@ def merge_message_runs(
return merged
# TODO: Update so validation errors (for token_counter, for example) are raised on
# init not at runtime.
@_runnable_support
def trim_messages(
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
@ -759,24 +761,30 @@ def trim_messages(
AIMessage("This is a 4 token text. The full message is 10 tokens.", id="fourth"),
]
""" # noqa: E501
from langchain_core.language_models import BaseLanguageModel
if start_on and strategy == "first":
raise ValueError
if include_system and strategy == "first":
raise ValueError
messages = convert_to_messages(messages)
if isinstance(token_counter, BaseLanguageModel):
list_token_counter = token_counter.get_num_tokens_from_messages
elif (
list(inspect.signature(token_counter).parameters.values())[0].annotation
is BaseMessage
):
if hasattr(token_counter, "get_num_tokens_from_messages"):
list_token_counter = getattr(token_counter, "get_num_tokens_from_messages")
elif callable(token_counter):
if (
list(inspect.signature(token_counter).parameters.values())[0].annotation
is BaseMessage
):
def list_token_counter(messages: Sequence[BaseMessage]) -> int:
return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
def list_token_counter(messages: Sequence[BaseMessage]) -> int:
return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
else:
list_token_counter = token_counter # type: ignore[assignment]
else:
list_token_counter = token_counter # type: ignore[assignment]
raise ValueError(
f"'token_counter' expected to be a model that implements "
f"'get_num_tokens_from_messages()' or a function. Received object of type "
f"{type(token_counter)}."
)
try:
from langchain_text_splitters import TextSplitter

View File

@ -2,6 +2,7 @@ from typing import Dict, List, Type
import pytest
from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
@ -316,6 +317,19 @@ def test_trim_messages_invoke() -> None:
assert actual == expected
def test_trim_messages_bound_model_token_counter() -> None:
    """A chat model wrapped via ``.bind()`` is accepted as ``token_counter``."""
    bound_model = FakeTokenCountingModel().bind(foo="bar")
    runnable = trim_messages(max_tokens=10, token_counter=bound_model)
    # Should not raise: the bound model still exposes get_num_tokens_from_messages.
    runnable.invoke([HumanMessage("foobar")])
def test_trim_messages_bad_token_counter() -> None:
    """An object that is neither a model nor a callable raises at invoke time."""
    runnable = trim_messages(max_tokens=10, token_counter={})
    # Validation happens at runtime (see the TODO on trim_messages), not on init.
    with pytest.raises(ValueError):
        runnable.invoke([HumanMessage("foobar")])
def dummy_token_counter(messages: List[BaseMessage]) -> int:
# treat each message like it adds 3 default tokens at the beginning
# of the message and at the end of the message. 3 + 4 + 3 = 10 tokens
@ -338,3 +352,8 @@ def dummy_token_counter(messages: List[BaseMessage]) -> int:
+ default_msg_suffix_len
)
return count
class FakeTokenCountingModel(FakeChatModel):
    """Fake chat model whose token counting delegates to ``dummy_token_counter``.

    Used to exercise the trim_messages path that accepts a (possibly bound)
    model exposing ``get_num_tokens_from_messages`` as the token counter.
    """

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        # Delegate to the module-level deterministic counter defined above.
        return dummy_token_counter(messages)