diff --git a/libs/langchain/langchain/llms/bedrock.py b/libs/langchain/langchain/llms/bedrock.py
index 7d07f92662..8fd22351a7 100644
--- a/libs/langchain/langchain/llms/bedrock.py
+++ b/libs/langchain/langchain/llms/bedrock.py
@@ -8,6 +8,52 @@ from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import BaseModel, Extra, root_validator
 from langchain.schema.output import GenerationChunk
 
+HUMAN_PROMPT = "\n\nHuman:"
+ASSISTANT_PROMPT = "\n\nAssistant:"
+ALTERNATION_ERROR = (
+    "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'."
+)
+
+
+def _add_newlines_before_ha(input_text: str) -> str:
+    new_text = input_text
+    for word in ["Human:", "Assistant:"]:
+        new_text = new_text.replace(word, "\n\n" + word)
+        for i in range(2):
+            new_text = new_text.replace("\n\n\n" + word, "\n\n" + word)
+    return new_text
+
+
+def _human_assistant_format(input_text: str) -> str:
+    if input_text.count("Human:") == 0 or (
+        input_text.find("Human:") > input_text.find("Assistant:")
+        and "Assistant:" in input_text
+    ):
+        input_text = HUMAN_PROMPT + " " + input_text  # SILENT CORRECTION
+    if input_text.count("Assistant:") == 0:
+        input_text = input_text + ASSISTANT_PROMPT  # SILENT CORRECTION
+    if input_text[: len("Human:")] == "Human:":
+        input_text = "\n\n" + input_text
+    input_text = _add_newlines_before_ha(input_text)
+    count = 0
+    # track alternation
+    for i in range(len(input_text)):
+        if input_text[i : i + len(HUMAN_PROMPT)] == HUMAN_PROMPT:
+            if count % 2 == 0:
+                count += 1
+            else:
+                raise ValueError(ALTERNATION_ERROR)
+        if input_text[i : i + len(ASSISTANT_PROMPT)] == ASSISTANT_PROMPT:
+            if count % 2 == 1:
+                count += 1
+            else:
+                raise ValueError(ALTERNATION_ERROR)
+
+    if count % 2 == 1:  # Only saw Human, no Assistant
+        input_text = input_text + ASSISTANT_PROMPT  # SILENT CORRECTION
+
+    return input_text
+
 
 class LLMInputOutputAdapter:
     """Adapter class to prepare the inputs from Langchain to a format
@@ -26,7 +72,9 @@ class LLMInputOutputAdapter:
         cls, provider: str, prompt: str, model_kwargs: Dict[str, Any]
     ) -> Dict[str, Any]:
         input_body = {**model_kwargs}
-        if provider == "anthropic" or provider == "ai21":
+        if provider == "anthropic":
+            input_body["prompt"] = _human_assistant_format(prompt)
+        elif provider == "ai21":
             input_body["prompt"] = prompt
         elif provider == "amazon":
             input_body = dict()
diff --git a/libs/langchain/tests/unit_tests/llms/test_bedrock.py b/libs/langchain/tests/unit_tests/llms/test_bedrock.py
new file mode 100644
index 0000000000..985011c277
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/llms/test_bedrock.py
@@ -0,0 +1,252 @@
+import pytest
+
+from langchain.llms.bedrock import ALTERNATION_ERROR, _human_assistant_format
+
+TEST_CASES = {
+    """Hey""": """
+
+Human: Hey
+
+Assistant:""",
+    """
+
+Human: Hello
+
+Assistant:""": """
+
+Human: Hello
+
+Assistant:""",
+    """Human: Hello
+
+Assistant:""": """
+
+Human: Hello
+
+Assistant:""",
+    """
+Human: Hello
+
+Assistant:""": """
+
+Human: Hello
+
+Assistant:""",
+    """
+
+Human: Human: Hello
+
+Assistant:""": (
+        "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'."
+    ),
+    """Human: Hello
+
+Assistant: Hello
+
+Human: Hello
+
+Assistant:""": """
+
+Human: Hello
+
+Assistant: Hello
+
+Human: Hello
+
+Assistant:""",
+    """
+
+Human: Hello
+
+Assistant: Hello
+
+Human: Hello
+
+Assistant:""": """
+
+Human: Hello
+
+Assistant: Hello
+
+Human: Hello
+
+Assistant:""",
+    """
+
+Human: Hello
+
+Assistant: Hello
+
+Human: Hello
+
+Assistant: Hello
+
+Assistant: Hello""": ALTERNATION_ERROR,
+    """
+
+Human: Hi.
+
+Assistant: Hi.
+
+Human: Hi.
+
+Human: Hi.
+
+Assistant:""": ALTERNATION_ERROR,
+    """
+Human: Hello""": """
+
+Human: Hello
+
+Assistant:""",
+    """
+
+Human: Hello
+Hello
+
+Assistant""": """
+
+Human: Hello
+Hello
+
+Assistant
+
+Assistant:""",
+    """Hello
+
+Assistant:""": """
+
+Human: Hello
+
+Assistant:""",
+    """Hello
+
+Human: Hello
+
+""": """Hello
+
+Human: Hello
+
+
+
+Assistant:""",
+    """
+
+Human: Assistant: Hello""": """
+
+Human: 
+
+Assistant: Hello""",
+    """
+
+Human: Human
+
+Assistant: Assistant
+
+Human: Assistant
+
+Assistant: Human""": """
+
+Human: Human
+
+Assistant: Assistant
+
+Human: Assistant
+
+Assistant: Human""",
+    """
+Assistant: Hello there, your name is:
+
+Human.
+
+Human: Hello there, your name is:
+
+Assistant.""": """
+
+Human: 
+
+Assistant: Hello there, your name is:
+
+Human.
+
+Human: Hello there, your name is:
+
+Assistant.
+
+Assistant:""",
+    """
+
+Human: Human: Hi
+
+Assistant: Hi""": ALTERNATION_ERROR,
+    """Human: Hi
+
+Human: Hi""": ALTERNATION_ERROR,
+    """
+
+Assistant: Hi
+
+Human: Hi""": """
+
+Human: 
+
+Assistant: Hi
+
+Human: Hi
+
+Assistant:""",
+    """
+
+Human: Hi
+
+Assistant: Yo
+
+Human: Hey
+
+Assistant: Sup
+
+Human: Hi
+
+Assistant: Hi
+Human: Hi
+Assistant:""": """
+
+Human: Hi
+
+Assistant: Yo
+
+Human: Hey
+
+Assistant: Sup
+
+Human: Hi
+
+Assistant: Hi
+
+Human: Hi
+
+Assistant:""",
+    """
+
+Hello.
+
+Human: Hello.
+
+Assistant:""": """
+
+Hello.
+
+Human: Hello.
+
+Assistant:""",
+}
+
+
+def test__human_assistant_format() -> None:
+    for input_text, expected_output in TEST_CASES.items():
+        if expected_output == ALTERNATION_ERROR:
+            with pytest.raises(ValueError):
+                _human_assistant_format(input_text)
+        else:
+            output = _human_assistant_format(input_text)
+            assert output == expected_output
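
A minimal usage sketch of the new _human_assistant_format helper (not part of the diff; it assumes the patched langchain checkout from this branch is importable), illustrating the silent corrections and the alternation check that the unit tests above exercise:

from langchain.llms.bedrock import _human_assistant_format

# A prompt with no Human/Assistant markers is silently wrapped so that it
# starts with "\n\nHuman:" and ends with "\n\nAssistant:".
print(repr(_human_assistant_format("Hey")))
# -> '\n\nHuman: Hey\n\nAssistant:'

# A prompt that already alternates correctly only has its "\n\n" separators
# normalized.
print(repr(_human_assistant_format("Human: Hi\n\nAssistant:")))
# -> '\n\nHuman: Hi\n\nAssistant:'

# A prompt that breaks Human/Assistant alternation raises ValueError with
# ALTERNATION_ERROR as the message.
try:
    _human_assistant_format("\n\nHuman: Hi\n\nHuman: Hi\n\nAssistant:")
except ValueError as err:
    print(err)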