From b499de29269b9f109240b2a5a7b04e9a47a90408 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Wed, 4 Oct 2023 11:32:24 -0400 Subject: [PATCH] Anthropic system message fix (#11301) Removes human prompt prefix before system message for anthropic models Bedrock anthropic api enforces that Human and Assistant messages must be interleaved (cannot have same type twice in a row). We currently treat System Messages as human messages when converting messages -> string prompt. Our validation when using Bedrock/BedrockChat raises an error when this happens. For ChatAnthropic we don't validate this so no error is raised, but perhaps the behavior is still suboptimal --- .../langchain/chat_models/anthropic.py | 3 +- libs/langchain/langchain/llms/bedrock.py | 4 +-- .../unit_tests/chat_models/test_anthropic.py | 29 ++++++++++++++----- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/libs/langchain/langchain/chat_models/anthropic.py b/libs/langchain/langchain/chat_models/anthropic.py index 344f4f0b29..24d1d7936c 100644 --- a/libs/langchain/langchain/chat_models/anthropic.py +++ b/libs/langchain/langchain/chat_models/anthropic.py @@ -34,7 +34,7 @@ def _convert_one_message_to_text( elif isinstance(message, AIMessage): message_text = f"{ai_prompt} {message.content}" elif isinstance(message, SystemMessage): - message_text = f"{human_prompt} {message.content}" + message_text = message.content else: raise ValueError(f"Got unknown type {message}") return message_text @@ -56,7 +56,6 @@ def convert_messages_to_prompt_anthropic( """ messages = messages.copy() # don't mutate the original list - if not isinstance(messages[-1], AIMessage): messages.append(AIMessage(content="")) diff --git a/libs/langchain/langchain/llms/bedrock.py b/libs/langchain/langchain/llms/bedrock.py index 8bc1472633..6a0f355b34 100644 --- a/libs/langchain/langchain/llms/bedrock.py +++ b/libs/langchain/langchain/llms/bedrock.py @@ -42,12 +42,12 @@ def _human_assistant_format(input_text: str) -> str: if count % 2 == 0: count += 1 else: - raise ValueError(ALTERNATION_ERROR) + raise ValueError(ALTERNATION_ERROR + f" Received {input_text}") if input_text[i : i + len(ASSISTANT_PROMPT)] == ASSISTANT_PROMPT: if count % 2 == 1: count += 1 else: - raise ValueError(ALTERNATION_ERROR) + raise ValueError(ALTERNATION_ERROR + f" Received {input_text}") if count % 2 == 1: # Only saw Human, no Assistant input_text = input_text + ASSISTANT_PROMPT # SILENT CORRECTION diff --git a/libs/langchain/tests/unit_tests/chat_models/test_anthropic.py b/libs/langchain/tests/unit_tests/chat_models/test_anthropic.py index c60da340a0..d49a3f225d 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_anthropic.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_anthropic.py @@ -6,7 +6,7 @@ import pytest from langchain.chat_models import ChatAnthropic from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic -from langchain.schema import AIMessage, BaseMessage, HumanMessage +from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage os.environ["ANTHROPIC_API_KEY"] = "foo" @@ -50,11 +50,24 @@ def test_anthropic_initialization() -> None: ChatAnthropic(model="test", anthropic_api_key="test") -def test_formatting() -> None: - messages: List[BaseMessage] = [HumanMessage(content="Hello")] +@pytest.mark.parametrize( + ("messages", "expected"), + [ + ([HumanMessage(content="Hello")], "\n\nHuman: Hello\n\nAssistant:"), + ( + [HumanMessage(content="Hello"), AIMessage(content="Answer:")], + "\n\nHuman: Hello\n\nAssistant: Answer:", + ), + ( + [ + SystemMessage(content="You're an assistant"), + HumanMessage(content="Hello"), + AIMessage(content="Answer:"), + ], + "You're an assistant\n\nHuman: Hello\n\nAssistant: Answer:", + ), + ], +) +def test_formatting(messages: List[BaseMessage], expected: str) -> None: result = convert_messages_to_prompt_anthropic(messages) - assert result == "\n\nHuman: Hello\n\nAssistant:" - - messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")] - result = convert_messages_to_prompt_anthropic(messages) - assert result == "\n\nHuman: Hello\n\nAssistant: Answer:" + assert result == expected