mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core[patch]: Allow bound models as token_counter in trim_messages (#25563)
This commit is contained in:
parent
e01c6789c4
commit
4bd005adb6
@ -514,6 +514,8 @@ def merge_message_runs(
|
||||
return merged
|
||||
|
||||
|
||||
# TODO: Update so validation errors (for token_counter, for example) are raised on
|
||||
# init not at runtime.
|
||||
@_runnable_support
|
||||
def trim_messages(
|
||||
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
|
||||
@ -759,24 +761,30 @@ def trim_messages(
|
||||
AIMessage("This is a 4 token text. The full message is 10 tokens.", id="fourth"),
|
||||
]
|
||||
""" # noqa: E501
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
if start_on and strategy == "first":
|
||||
raise ValueError
|
||||
if include_system and strategy == "first":
|
||||
raise ValueError
|
||||
messages = convert_to_messages(messages)
|
||||
if isinstance(token_counter, BaseLanguageModel):
|
||||
list_token_counter = token_counter.get_num_tokens_from_messages
|
||||
elif (
|
||||
list(inspect.signature(token_counter).parameters.values())[0].annotation
|
||||
is BaseMessage
|
||||
):
|
||||
if hasattr(token_counter, "get_num_tokens_from_messages"):
|
||||
list_token_counter = getattr(token_counter, "get_num_tokens_from_messages")
|
||||
elif callable(token_counter):
|
||||
if (
|
||||
list(inspect.signature(token_counter).parameters.values())[0].annotation
|
||||
is BaseMessage
|
||||
):
|
||||
|
||||
def list_token_counter(messages: Sequence[BaseMessage]) -> int:
|
||||
return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
|
||||
def list_token_counter(messages: Sequence[BaseMessage]) -> int:
|
||||
return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
|
||||
else:
|
||||
list_token_counter = token_counter # type: ignore[assignment]
|
||||
else:
|
||||
list_token_counter = token_counter # type: ignore[assignment]
|
||||
raise ValueError(
|
||||
f"'token_counter' expected ot be a model that implements "
|
||||
f"'get_num_tokens_from_messages()' or a function. Received object of type "
|
||||
f"{type(token_counter)}."
|
||||
)
|
||||
|
||||
try:
|
||||
from langchain_text_splitters import TextSplitter
|
||||
|
@ -2,6 +2,7 @@ from typing import Dict, List, Type
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.language_models.fake_chat_models import FakeChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
@ -316,6 +317,19 @@ def test_trim_messages_invoke() -> None:
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_trim_messages_bound_model_token_counter() -> None:
    """Check that a model wrapped with ``.bind(...)`` works as ``token_counter``.

    There is no explicit assertion: the test passes as long as building the
    trimmer and invoking it on a single message does not raise.
    """
    bound_model = FakeTokenCountingModel().bind(foo="bar")
    trimmer = trim_messages(max_tokens=10, token_counter=bound_model)
    trimmer.invoke([HumanMessage("foobar")])
|
||||
|
||||
|
||||
def test_trim_messages_bad_token_counter() -> None:
    """Check that an unusable ``token_counter`` is rejected with ``ValueError``.

    An empty dict is passed as the counter. Only the ``invoke`` call is placed
    under ``pytest.raises``, so the error is expected when the trimmer runs,
    not when it is constructed.
    """
    bad_trimmer = trim_messages(token_counter={}, max_tokens=10)
    with pytest.raises(ValueError):
        bad_trimmer.invoke([HumanMessage("foobar")])
|
||||
|
||||
|
||||
def dummy_token_counter(messages: List[BaseMessage]) -> int:
|
||||
# treat each message like it adds 3 default tokens at the beginning
|
||||
# of the message and at the end of the message. 3 + 4 + 3 = 10 tokens
|
||||
@ -338,3 +352,8 @@ def dummy_token_counter(messages: List[BaseMessage]) -> int:
|
||||
+ default_msg_suffix_len
|
||||
)
|
||||
return count
|
||||
|
||||
|
||||
class FakeTokenCountingModel(FakeChatModel):
    """Fake chat model that counts message tokens via ``dummy_token_counter``."""

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the token count ``dummy_token_counter`` reports for ``messages``."""
        token_total = dummy_token_counter(messages)
        return token_total
|
||||
|
Loading…
Reference in New Issue
Block a user