mirror of
https://github.com/hwchase17/langchain
synced 2024-11-13 19:10:52 +00:00
anthropic[patch]: allow multiple sys not at start (#27725)
This commit is contained in:
parent
1ed3cd252e
commit
6691202998
@ -139,7 +139,10 @@ def _merge_messages(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
last = merged[-1] if merged else None
|
last = merged[-1] if merged else None
|
||||||
if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
|
if any(
|
||||||
|
all(isinstance(m, c) for m in (curr, last))
|
||||||
|
for c in (SystemMessage, HumanMessage)
|
||||||
|
):
|
||||||
if isinstance(last.content, str):
|
if isinstance(last.content, str):
|
||||||
new_content: List = [{"type": "text", "text": last.content}]
|
new_content: List = [{"type": "text", "text": last.content}]
|
||||||
else:
|
else:
|
||||||
@ -148,7 +151,7 @@ def _merge_messages(
|
|||||||
new_content.append({"type": "text", "text": curr.content})
|
new_content.append({"type": "text", "text": curr.content})
|
||||||
else:
|
else:
|
||||||
new_content.extend(curr.content)
|
new_content.extend(curr.content)
|
||||||
merged[-1] = curr.model_copy(update={"content": new_content}, deep=False)
|
merged[-1] = curr.model_copy(update={"content": new_content})
|
||||||
else:
|
else:
|
||||||
merged.append(curr)
|
merged.append(curr)
|
||||||
return merged
|
return merged
|
||||||
@ -174,14 +177,14 @@ def _format_messages(
|
|||||||
merged_messages = _merge_messages(messages)
|
merged_messages = _merge_messages(messages)
|
||||||
for i, message in enumerate(merged_messages):
|
for i, message in enumerate(merged_messages):
|
||||||
if message.type == "system":
|
if message.type == "system":
|
||||||
if i != 0:
|
if system is not None:
|
||||||
raise ValueError("System message must be at beginning of message list.")
|
raise ValueError("Received multiple non-consecutive system messages.")
|
||||||
if isinstance(message.content, list):
|
elif isinstance(message.content, list):
|
||||||
system = [
|
system = [
|
||||||
(
|
(
|
||||||
block
|
block
|
||||||
if isinstance(block, dict)
|
if isinstance(block, dict)
|
||||||
else {"type": "text", "text": "block"}
|
else {"type": "text", "text": block}
|
||||||
)
|
)
|
||||||
for block in message.content
|
for block in message.content
|
||||||
]
|
]
|
||||||
|
@ -695,6 +695,28 @@ def test__format_messages_with_cache_control() -> None:
|
|||||||
assert expected_messages == actual_messages
|
assert expected_messages == actual_messages
|
||||||
|
|
||||||
|
|
||||||
|
def test__format_messages_with_multiple_system() -> None:
|
||||||
|
messages = [
|
||||||
|
HumanMessage("baz"),
|
||||||
|
SystemMessage("bar"),
|
||||||
|
SystemMessage("baz"),
|
||||||
|
SystemMessage(
|
||||||
|
[
|
||||||
|
{"type": "text", "text": "foo", "cache_control": {"type": "ephemeral"}},
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
expected_system = [
|
||||||
|
{"type": "text", "text": "bar"},
|
||||||
|
{"type": "text", "text": "baz"},
|
||||||
|
{"type": "text", "text": "foo", "cache_control": {"type": "ephemeral"}},
|
||||||
|
]
|
||||||
|
expected_messages = [{"role": "user", "content": "baz"}]
|
||||||
|
actual_system, actual_messages = _format_messages(messages)
|
||||||
|
assert expected_system == actual_system
|
||||||
|
assert expected_messages == actual_messages
|
||||||
|
|
||||||
|
|
||||||
def test_anthropic_api_key_is_secret_string() -> None:
|
def test_anthropic_api_key_is_secret_string() -> None:
|
||||||
"""Test that the API key is stored as a SecretStr."""
|
"""Test that the API key is stored as a SecretStr."""
|
||||||
chat_model = ChatAnthropic( # type: ignore[call-arg, call-arg]
|
chat_model = ChatAnthropic( # type: ignore[call-arg, call-arg]
|
||||||
|
Loading…
Reference in New Issue
Block a user