Added prompt wrapping for Claude with Bedrock (#11090)

- **Description:** AWS Bedrock has implemented prompt-wrapping requirements
on the service side for the Anthropic Claude models, to provide parity
between Anthropic's first-party offering and Bedrock's offering. This
overnight change broke most existing integrations of Claude via Bedrock
and LangChain. This PR borrows the Anthropic LLM implementation's
alias/role wrapping and applies it in the existing mechanism for building
the request body (see the sketch just after this list). This has been
tested and fixes the chat_model implementation as well. Happy to answer
any further questions or make changes where necessary to get things
patched and up to PyPI ASAP, TY.
- **Issue:** None opened at the moment; will update as reports roll in.
- **Dependencies:** None
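
As a quick illustration (mine, not part of the PR), here is a minimal sketch of
the wrapping Bedrock's Claude endpoint now expects, assuming the turn markers
this PR introduces:

```python
# Hypothetical example: a bare completion-style prompt must be rewritten into
# Claude's "\n\nHuman: ... \n\nAssistant:" turn format before it reaches Bedrock.
raw_prompt = "What is LangChain?"
wrapped = "\n\nHuman: " + raw_prompt + "\n\nAssistant:"
# wrapped == "\n\nHuman: What is LangChain?\n\nAssistant:"
```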

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Jonathan Evans committed via GitHub
parent b87cc8b31e
commit 23065f54c0

@@ -8,6 +8,52 @@ from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.schema.output import GenerationChunk

HUMAN_PROMPT = "\n\nHuman:"
ASSISTANT_PROMPT = "\n\nAssistant:"
ALTERNATION_ERROR = (
    "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'."
)


def _add_newlines_before_ha(input_text: str) -> str:
    new_text = input_text
    for word in ["Human:", "Assistant:"]:
        new_text = new_text.replace(word, "\n\n" + word)
        for i in range(2):
            new_text = new_text.replace("\n\n\n" + word, "\n\n" + word)
    return new_text


def _human_assistant_format(input_text: str) -> str:
    if input_text.count("Human:") == 0 or (
        input_text.find("Human:") > input_text.find("Assistant:")
        and "Assistant:" in input_text
    ):
        input_text = HUMAN_PROMPT + " " + input_text  # SILENT CORRECTION
    if input_text.count("Assistant:") == 0:
        input_text = input_text + ASSISTANT_PROMPT  # SILENT CORRECTION
    if input_text[: len("Human:")] == "Human:":
        input_text = "\n\n" + input_text
    input_text = _add_newlines_before_ha(input_text)
    count = 0
    # track alternation
    for i in range(len(input_text)):
        if input_text[i : i + len(HUMAN_PROMPT)] == HUMAN_PROMPT:
            if count % 2 == 0:
                count += 1
            else:
                raise ValueError(ALTERNATION_ERROR)
        if input_text[i : i + len(ASSISTANT_PROMPT)] == ASSISTANT_PROMPT:
            if count % 2 == 1:
                count += 1
            else:
                raise ValueError(ALTERNATION_ERROR)
    if count % 2 == 1:  # Only saw Human, no Assistant
        input_text = input_text + ASSISTANT_PROMPT  # SILENT CORRECTION
    return input_text


class LLMInputOutputAdapter:
    """Adapter class to prepare the inputs from Langchain to a format
@@ -26,7 +72,9 @@ class LLMInputOutputAdapter:
         cls, provider: str, prompt: str, model_kwargs: Dict[str, Any]
     ) -> Dict[str, Any]:
         input_body = {**model_kwargs}
-        if provider == "anthropic" or provider == "ai21":
+        if provider == "anthropic":
+            input_body["prompt"] = _human_assistant_format(prompt)
+        elif provider == "ai21":
             input_body["prompt"] = prompt
         elif provider == "amazon":
             input_body = dict()
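
Sketched standalone, the effect of the new anthropic branch on the request
body (an illustrative sketch, not part of the diff; `_prepare_input_sketch`
is a hypothetical stand-in for the method above):

from typing import Any, Dict

from langchain.llms.bedrock import _human_assistant_format


def _prepare_input_sketch(
    provider: str, prompt: str, model_kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    input_body = {**model_kwargs}
    if provider == "anthropic":
        # New behavior: wrap the prompt before the request body is built.
        input_body["prompt"] = _human_assistant_format(prompt)
    elif provider == "ai21":
        # ai21 prompts pass through unchanged, as before.
        input_body["prompt"] = prompt
    return input_body


_prepare_input_sketch("anthropic", "Hey", {"max_tokens_to_sample": 256})
# -> {'max_tokens_to_sample': 256, 'prompt': '\n\nHuman: Hey\n\nAssistant:'}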

@@ -0,0 +1,252 @@
import pytest

from langchain.llms.bedrock import ALTERNATION_ERROR, _human_assistant_format

TEST_CASES = {
    """Hey""": """

Human: Hey

Assistant:""",
    """

Human: Hello

Assistant:""": """

Human: Hello

Assistant:""",
    """Human: Hello

Assistant:""": """

Human: Hello

Assistant:""",
    """
Human: Hello

Assistant:""": """

Human: Hello

Assistant:""",
    """

Human: Human: Hello

Assistant:""": (
        "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'."
    ),
    """Human: Hello

Assistant: Hello

Human: Hello

Assistant:""": """

Human: Hello

Assistant: Hello

Human: Hello

Assistant:""",
    """

Human: Hello

Assistant: Hello

Human: Hello

Assistant:""": """

Human: Hello

Assistant: Hello

Human: Hello

Assistant:""",
    """

Human: Hello

Assistant: Hello

Human: Hello

Assistant: Hello

Assistant: Hello""": ALTERNATION_ERROR,
    """

Human: Hi.

Assistant: Hi.

Human: Hi.

Human: Hi.

Assistant:""": ALTERNATION_ERROR,
    """

Human: Hello""": """

Human: Hello

Assistant:""",
    """

Human: Hello
Hello

Assistant""": """

Human: Hello
Hello

Assistant

Assistant:""",
    """Hello

Assistant:""": """

Human: Hello

Assistant:""",
    """Hello

Human: Hello

""": """Hello

Human: Hello



Assistant:""",
    """

Human: Assistant: Hello""": """

Human: 

Assistant: Hello""",
    """

Human: Human

Assistant: Assistant

Human: Assistant

Assistant: Human""": """

Human: Human

Assistant: Assistant

Human: Assistant

Assistant: Human""",
    """
Assistant: Hello there, your name is:

Human.

Human: Hello there, your name is:

Assistant.""": """

Human: 

Assistant: Hello there, your name is:

Human.

Human: Hello there, your name is:

Assistant.

Assistant:""",
    """

Human: Human: Hi

Assistant: Hi""": ALTERNATION_ERROR,
    """Human: Hi

Human: Hi""": ALTERNATION_ERROR,
    """

Assistant: Hi

Human: Hi""": """

Human: 

Assistant: Hi

Human: Hi

Assistant:""",
    """

Human: Hi

Assistant: Yo

Human: Hey

Assistant: Sup

Human: Hi

Assistant: Hi

Human: Hi

Assistant:""": """

Human: Hi

Assistant: Yo

Human: Hey

Assistant: Sup

Human: Hi

Assistant: Hi

Human: Hi

Assistant:""",
    """

Hello.

Human: Hello.

Assistant:""": """

Hello.

Human: Hello.

Assistant:""",
}


def test__human_assistant_format() -> None:
    for input_text, expected_output in TEST_CASES.items():
        if expected_output == ALTERNATION_ERROR:
            with pytest.raises(ValueError):
                _human_assistant_format(input_text)
        else:
            output = _human_assistant_format(input_text)
            assert output == expected_output
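
To run just this table locally, something like the following should work (the
file path is my assumption based on LangChain's unit-test layout, not part of
the diff):

import pytest

pytest.main(["-q", "tests/unit_tests/llms/test_bedrock.py::test__human_assistant_format"])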