From 23065f54c047c71c9d71b724013f9652f9867930 Mon Sep 17 00:00:00 2001 From: Jonathan Evans Date: Wed, 27 Sep 2023 22:20:07 -0400 Subject: [PATCH] Added prompt wrapping for Claude with Bedrock (#11090) - **Description:** Prompt wrapping requirements have been implemented on the service side of AWS Bedrock for the Anthropic Claude models to provide parity between Anthropic's offering and Bedrock's offering. This overnight change broke most existing implementations of Claude, Bedrock and Langchain. This PR just steals the Anthropic LLM implementation to enforce alias/role wrapping and implements it in the existing mechanism for building the request body. This has also been tested to fix the chat_model implementation as well. Happy to answer any further questions or make changes where necessary to get things patched and up to PyPI ASAP, TY. - **Issue:** No issue opened at the moment, though will update when these roll in. - **Dependencies:** None --------- Co-authored-by: Harrison Chase Co-authored-by: Bagatur --- libs/langchain/langchain/llms/bedrock.py | 50 +++- .../tests/unit_tests/llms/test_bedrock.py | 252 ++++++++++++++++++ 2 files changed, 301 insertions(+), 1 deletion(-) create mode 100644 libs/langchain/tests/unit_tests/llms/test_bedrock.py diff --git a/libs/langchain/langchain/llms/bedrock.py b/libs/langchain/langchain/llms/bedrock.py index 7d07f92662..8fd22351a7 100644 --- a/libs/langchain/langchain/llms/bedrock.py +++ b/libs/langchain/langchain/llms/bedrock.py @@ -8,6 +8,52 @@ from langchain.llms.utils import enforce_stop_tokens from langchain.pydantic_v1 import BaseModel, Extra, root_validator from langchain.schema.output import GenerationChunk +HUMAN_PROMPT = "\n\nHuman:" +ASSISTANT_PROMPT = "\n\nAssistant:" +ALTERNATION_ERROR = ( + "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'." 
+) + + +def _add_newlines_before_ha(input_text: str) -> str: + new_text = input_text + for word in ["Human:", "Assistant:"]: + new_text = new_text.replace(word, "\n\n" + word) + for i in range(2): + new_text = new_text.replace("\n\n\n" + word, "\n\n" + word) + return new_text + + +def _human_assistant_format(input_text: str) -> str: + if input_text.count("Human:") == 0 or ( + input_text.find("Human:") > input_text.find("Assistant:") + and "Assistant:" in input_text + ): + input_text = HUMAN_PROMPT + " " + input_text # SILENT CORRECTION + if input_text.count("Assistant:") == 0: + input_text = input_text + ASSISTANT_PROMPT # SILENT CORRECTION + if input_text[: len("Human:")] == "Human:": + input_text = "\n\n" + input_text + input_text = _add_newlines_before_ha(input_text) + count = 0 + # track alternation + for i in range(len(input_text)): + if input_text[i : i + len(HUMAN_PROMPT)] == HUMAN_PROMPT: + if count % 2 == 0: + count += 1 + else: + raise ValueError(ALTERNATION_ERROR) + if input_text[i : i + len(ASSISTANT_PROMPT)] == ASSISTANT_PROMPT: + if count % 2 == 1: + count += 1 + else: + raise ValueError(ALTERNATION_ERROR) + + if count % 2 == 1: # Only saw Human, no Assistant + input_text = input_text + ASSISTANT_PROMPT # SILENT CORRECTION + + return input_text + class LLMInputOutputAdapter: """Adapter class to prepare the inputs from Langchain to a format @@ -26,7 +72,9 @@ class LLMInputOutputAdapter: cls, provider: str, prompt: str, model_kwargs: Dict[str, Any] ) -> Dict[str, Any]: input_body = {**model_kwargs} - if provider == "anthropic" or provider == "ai21": + if provider == "anthropic": + input_body["prompt"] = _human_assistant_format(prompt) + elif provider == "ai21": input_body["prompt"] = prompt elif provider == "amazon": input_body = dict() diff --git a/libs/langchain/tests/unit_tests/llms/test_bedrock.py b/libs/langchain/tests/unit_tests/llms/test_bedrock.py new file mode 100644 index 0000000000..985011c277 --- /dev/null +++ 
b/libs/langchain/tests/unit_tests/llms/test_bedrock.py @@ -0,0 +1,252 @@ +import pytest + +from langchain.llms.bedrock import ALTERNATION_ERROR, _human_assistant_format + +TEST_CASES = { + """Hey""": """ + +Human: Hey + +Assistant:""", + """ + +Human: Hello + +Assistant:""": """ + +Human: Hello + +Assistant:""", + """Human: Hello + +Assistant:""": """ + +Human: Hello + +Assistant:""", + """ +Human: Hello + +Assistant:""": """ + +Human: Hello + +Assistant:""", + """ + +Human: Human: Hello + +Assistant:""": ( + "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'." + ), + """Human: Hello + +Assistant: Hello + +Human: Hello + +Assistant:""": """ + +Human: Hello + +Assistant: Hello + +Human: Hello + +Assistant:""", + """ + +Human: Hello + +Assistant: Hello + +Human: Hello + +Assistant:""": """ + +Human: Hello + +Assistant: Hello + +Human: Hello + +Assistant:""", + """ + +Human: Hello + +Assistant: Hello + +Human: Hello + +Assistant: Hello + +Assistant: Hello""": ALTERNATION_ERROR, + """ + +Human: Hi. + +Assistant: Hi. + +Human: Hi. + +Human: Hi. + +Assistant:""": ALTERNATION_ERROR, + """ +Human: Hello""": """ + +Human: Hello + +Assistant:""", + """ + +Human: Hello +Hello + +Assistant""": """ + +Human: Hello +Hello + +Assistant + +Assistant:""", + """Hello + +Assistant:""": """ + +Human: Hello + +Assistant:""", + """Hello + +Human: Hello + +""": """Hello + +Human: Hello + + + +Assistant:""", + """ + +Human: Assistant: Hello""": """ + +Human: + +Assistant: Hello""", + """ + +Human: Human + +Assistant: Assistant + +Human: Assistant + +Assistant: Human""": """ + +Human: Human + +Assistant: Assistant + +Human: Assistant + +Assistant: Human""", + """ +Assistant: Hello there, your name is: + +Human. + +Human: Hello there, your name is: + +Assistant.""": """ + +Human: + +Assistant: Hello there, your name is: + +Human. + +Human: Hello there, your name is: + +Assistant. 
+ +Assistant:""", + """ + +Human: Human: Hi + +Assistant: Hi""": ALTERNATION_ERROR, + """Human: Hi + +Human: Hi""": ALTERNATION_ERROR, + """ + +Assistant: Hi + +Human: Hi""": """ + +Human: + +Assistant: Hi + +Human: Hi + +Assistant:""", + """ + +Human: Hi + +Assistant: Yo + +Human: Hey + +Assistant: Sup + +Human: Hi + +Assistant: Hi +Human: Hi +Assistant:""": """ + +Human: Hi + +Assistant: Yo + +Human: Hey + +Assistant: Sup + +Human: Hi + +Assistant: Hi + +Human: Hi + +Assistant:""", + """ + +Hello. + +Human: Hello. + +Assistant:""": """ + +Hello. + +Human: Hello. + +Assistant:""", +} + + +def test__human_assistant_format() -> None: + for input_text, expected_output in TEST_CASES.items(): + if expected_output == ALTERNATION_ERROR: + with pytest.raises(ValueError): + _human_assistant_format(input_text) + else: + output = _human_assistant_format(input_text) + assert output == expected_output