diff --git a/libs/community/langchain_community/chat_models/bedrock.py b/libs/community/langchain_community/chat_models/bedrock.py
index 49b7acad19..5538372272 100644
--- a/libs/community/langchain_community/chat_models/bedrock.py
+++ b/libs/community/langchain_community/chat_models/bedrock.py
@@ -32,6 +32,12 @@ class ChatPromptAdapter:
             prompt = convert_messages_to_prompt_anthropic(messages=messages)
         elif provider == "meta":
             prompt = convert_messages_to_prompt_llama(messages=messages)
+        elif provider == "amazon":
+            prompt = convert_messages_to_prompt_anthropic(
+                messages=messages,
+                human_prompt="\n\nUser:",
+                ai_prompt="\n\nBot:",
+            )
         else:
             raise NotImplementedError(
                 f"Provider {provider} model does not support chat."
diff --git a/libs/community/langchain_community/llms/bedrock.py b/libs/community/langchain_community/llms/bedrock.py
index ddc15d5ef1..3f10b09b63 100644
--- a/libs/community/langchain_community/llms/bedrock.py
+++ b/libs/community/langchain_community/llms/bedrock.py
@@ -272,10 +272,12 @@ class BedrockBase(BaseModel, ABC):
 
         try:
             response = self.client.invoke_model(
-                body=body, modelId=self.model_id, accept=accept, contentType=contentType
+                body=body,
+                modelId=self.model_id,
+                accept=accept,
+                contentType=contentType,
             )
             text = LLMInputOutputAdapter.prepare_output(provider, response)
-
         except Exception as e:
             raise ValueError(f"Error raised by bedrock service: {e}").with_traceback(
                 e.__traceback__
diff --git a/libs/community/tests/unit_tests/chat_models/test_bedrock.py b/libs/community/tests/unit_tests/chat_models/test_bedrock.py
index 93f03b8ece..b515c99e5a 100644
--- a/libs/community/tests/unit_tests/chat_models/test_bedrock.py
+++ b/libs/community/tests/unit_tests/chat_models/test_bedrock.py
@@ -37,17 +37,24 @@ def test_formatting(messages: List[BaseMessage], expected: str) -> None:
     assert result == expected
 
 
-def test_anthropic_bedrock() -> None:
+@pytest.mark.parametrize(
+    "model_id",
+    ["anthropic.claude-v2", "amazon.titan-text-express-v1"],
+)
+def test_different_models_bedrock(model_id: str) -> None:
+    provider = model_id.split(".")[0]
     client = MagicMock()
-    respbody = MagicMock(
-        read=MagicMock(
-            return_value=MagicMock(
-                decode=MagicMock(return_value=b'{"completion":"Hi back"}')
-            )
+    respbody = MagicMock()
+    if provider == "anthropic":
+        respbody.read.return_value = MagicMock(
+            decode=MagicMock(return_value=b'{"completion":"Hi back"}'),
         )
-    )
-    client.invoke_model.return_value = {"body": respbody}
-    model = BedrockChat(model_id="anthropic.claude-v2", client=client)
+        client.invoke_model.return_value = {"body": respbody}
+    elif provider == "amazon":
+        respbody.read.return_value = '{"results": [{"outputText": "Hi back"}]}'
+        client.invoke_model.return_value = {"body": respbody}
+
+    model = BedrockChat(model_id=model_id, client=client)
     # should not throw an error
     model.invoke("hello there")
 