From 6691202998cc332506b94e3d2487ca3e4996569a Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Wed, 30 Oct 2024 16:56:47 -0700 Subject: [PATCH] anthropic[patch]: allow multiple sys not at start (#27725) --- .../langchain_anthropic/chat_models.py | 15 ++++++++----- .../tests/unit_tests/test_chat_models.py | 22 +++++++++++++++++++ 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index c4fb0b1775..c408cdc23d 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -139,7 +139,10 @@ def _merge_messages( ] ) last = merged[-1] if merged else None - if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage): + if any( + all(isinstance(m, c) for m in (curr, last)) + for c in (SystemMessage, HumanMessage) + ): if isinstance(last.content, str): new_content: List = [{"type": "text", "text": last.content}] else: @@ -148,7 +151,7 @@ def _merge_messages( new_content.append({"type": "text", "text": curr.content}) else: new_content.extend(curr.content) - merged[-1] = curr.model_copy(update={"content": new_content}, deep=False) + merged[-1] = curr.model_copy(update={"content": new_content}) else: merged.append(curr) return merged @@ -174,14 +177,14 @@ def _format_messages( merged_messages = _merge_messages(messages) for i, message in enumerate(merged_messages): if message.type == "system": - if i != 0: - raise ValueError("System message must be at beginning of message list.") - if isinstance(message.content, list): + if system is not None: + raise ValueError("Received multiple non-consecutive system messages.") + elif isinstance(message.content, list): system = [ ( block if isinstance(block, dict) - else {"type": "text", "text": "block"} + else {"type": "text", "text": block} ) for block in message.content ] diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index 859a2fa5ac..b37246e986 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -695,6 +695,28 @@ def test__format_messages_with_cache_control() -> None: assert expected_messages == actual_messages +def test__format_messages_with_multiple_system() -> None: + messages = [ + HumanMessage("baz"), + SystemMessage("bar"), + SystemMessage("baz"), + SystemMessage( + [ + {"type": "text", "text": "foo", "cache_control": {"type": "ephemeral"}}, + ] + ), + ] + expected_system = [ + {"type": "text", "text": "bar"}, + {"type": "text", "text": "baz"}, + {"type": "text", "text": "foo", "cache_control": {"type": "ephemeral"}}, + ] + expected_messages = [{"role": "user", "content": "baz"}] + actual_system, actual_messages = _format_messages(messages) + assert expected_system == actual_system + assert expected_messages == actual_messages + + def test_anthropic_api_key_is_secret_string() -> None: """Test that the API key is stored as a SecretStr.""" chat_model = ChatAnthropic( # type: ignore[call-arg, call-arg]