anthropic[patch]: allow multiple sys not at start (#27725)

This commit is contained in:
Bagatur 2024-10-30 16:56:47 -07:00 committed by GitHub
parent 1ed3cd252e
commit 6691202998
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 6 deletions

View File

@ -139,7 +139,10 @@ def _merge_messages(
]
)
last = merged[-1] if merged else None
if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
if any(
all(isinstance(m, c) for m in (curr, last))
for c in (SystemMessage, HumanMessage)
):
if isinstance(last.content, str):
new_content: List = [{"type": "text", "text": last.content}]
else:
@ -148,7 +151,7 @@ def _merge_messages(
new_content.append({"type": "text", "text": curr.content})
else:
new_content.extend(curr.content)
merged[-1] = curr.model_copy(update={"content": new_content}, deep=False)
merged[-1] = curr.model_copy(update={"content": new_content})
else:
merged.append(curr)
return merged
@ -174,14 +177,14 @@ def _format_messages(
merged_messages = _merge_messages(messages)
for i, message in enumerate(merged_messages):
if message.type == "system":
if i != 0:
raise ValueError("System message must be at beginning of message list.")
if isinstance(message.content, list):
if system is not None:
raise ValueError("Received multiple non-consecutive system messages.")
elif isinstance(message.content, list):
system = [
(
block
if isinstance(block, dict)
else {"type": "text", "text": "block"}
else {"type": "text", "text": block}
)
for block in message.content
]

View File

@ -695,6 +695,28 @@ def test__format_messages_with_cache_control() -> None:
assert expected_messages == actual_messages
def test__format_messages_with_multiple_system() -> None:
messages = [
HumanMessage("baz"),
SystemMessage("bar"),
SystemMessage("baz"),
SystemMessage(
[
{"type": "text", "text": "foo", "cache_control": {"type": "ephemeral"}},
]
),
]
expected_system = [
{"type": "text", "text": "bar"},
{"type": "text", "text": "baz"},
{"type": "text", "text": "foo", "cache_control": {"type": "ephemeral"}},
]
expected_messages = [{"role": "user", "content": "baz"}]
actual_system, actual_messages = _format_messages(messages)
assert expected_system == actual_system
assert expected_messages == actual_messages
def test_anthropic_api_key_is_secret_string() -> None:
"""Test that the API key is stored as a SecretStr."""
chat_model = ChatAnthropic( # type: ignore[call-arg, call-arg]