mirror of
https://github.com/hwchase17/langchain
synced 2024-11-13 19:10:52 +00:00
anthropic[patch]: allow multiple sys not at start (#27725)
This commit is contained in:
parent
1ed3cd252e
commit
6691202998
@ -139,7 +139,10 @@ def _merge_messages(
|
||||
]
|
||||
)
|
||||
last = merged[-1] if merged else None
|
||||
if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
|
||||
if any(
|
||||
all(isinstance(m, c) for m in (curr, last))
|
||||
for c in (SystemMessage, HumanMessage)
|
||||
):
|
||||
if isinstance(last.content, str):
|
||||
new_content: List = [{"type": "text", "text": last.content}]
|
||||
else:
|
||||
@ -148,7 +151,7 @@ def _merge_messages(
|
||||
new_content.append({"type": "text", "text": curr.content})
|
||||
else:
|
||||
new_content.extend(curr.content)
|
||||
merged[-1] = curr.model_copy(update={"content": new_content}, deep=False)
|
||||
merged[-1] = curr.model_copy(update={"content": new_content})
|
||||
else:
|
||||
merged.append(curr)
|
||||
return merged
|
||||
@ -174,14 +177,14 @@ def _format_messages(
|
||||
merged_messages = _merge_messages(messages)
|
||||
for i, message in enumerate(merged_messages):
|
||||
if message.type == "system":
|
||||
if i != 0:
|
||||
raise ValueError("System message must be at beginning of message list.")
|
||||
if isinstance(message.content, list):
|
||||
if system is not None:
|
||||
raise ValueError("Received multiple non-consecutive system messages.")
|
||||
elif isinstance(message.content, list):
|
||||
system = [
|
||||
(
|
||||
block
|
||||
if isinstance(block, dict)
|
||||
else {"type": "text", "text": "block"}
|
||||
else {"type": "text", "text": block}
|
||||
)
|
||||
for block in message.content
|
||||
]
|
||||
|
@ -695,6 +695,28 @@ def test__format_messages_with_cache_control() -> None:
|
||||
assert expected_messages == actual_messages
|
||||
|
||||
|
||||
def test__format_messages_with_multiple_system() -> None:
|
||||
messages = [
|
||||
HumanMessage("baz"),
|
||||
SystemMessage("bar"),
|
||||
SystemMessage("baz"),
|
||||
SystemMessage(
|
||||
[
|
||||
{"type": "text", "text": "foo", "cache_control": {"type": "ephemeral"}},
|
||||
]
|
||||
),
|
||||
]
|
||||
expected_system = [
|
||||
{"type": "text", "text": "bar"},
|
||||
{"type": "text", "text": "baz"},
|
||||
{"type": "text", "text": "foo", "cache_control": {"type": "ephemeral"}},
|
||||
]
|
||||
expected_messages = [{"role": "user", "content": "baz"}]
|
||||
actual_system, actual_messages = _format_messages(messages)
|
||||
assert expected_system == actual_system
|
||||
assert expected_messages == actual_messages
|
||||
|
||||
|
||||
def test_anthropic_api_key_is_secret_string() -> None:
|
||||
"""Test that the API key is stored as a SecretStr."""
|
||||
chat_model = ChatAnthropic( # type: ignore[call-arg, call-arg]
|
||||
|
Loading…
Reference in New Issue
Block a user