community[patch]: BedrockChat -> Support Titan express as chat model (#15408)

Titan Express model was not supported as a chat model because LangChain
messages were not "translated" to a text prompt.

Co-authored-by: Guillem Orellana Trullols <guillem.orellana_trullols@siemens.com>
This commit is contained in:
Guillem Orellana Trullols 2024-01-22 20:37:23 +01:00 committed by GitHub
parent 1b9001db47
commit aad2aa7188
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 11 deletions

View File

@@ -32,6 +32,12 @@ class ChatPromptAdapter:
prompt = convert_messages_to_prompt_anthropic(messages=messages)
elif provider == "meta":
prompt = convert_messages_to_prompt_llama(messages=messages)
elif provider == "amazon":
prompt = convert_messages_to_prompt_anthropic(
messages=messages,
human_prompt="\n\nUser:",
ai_prompt="\n\nBot:",
)
else:
raise NotImplementedError(
f"Provider {provider} model does not support chat."

View File

@@ -272,10 +272,12 @@ class BedrockBase(BaseModel, ABC):
try:
response = self.client.invoke_model(
body=body, modelId=self.model_id, accept=accept, contentType=contentType
body=body,
modelId=self.model_id,
accept=accept,
contentType=contentType,
)
text = LLMInputOutputAdapter.prepare_output(provider, response)
except Exception as e:
raise ValueError(f"Error raised by bedrock service: {e}").with_traceback(
e.__traceback__

View File

@@ -37,17 +37,24 @@ def test_formatting(messages: List[BaseMessage], expected: str) -> None:
assert result == expected
def test_anthropic_bedrock() -> None:
@pytest.mark.parametrize(
"model_id",
["anthropic.claude-v2", "amazon.titan-text-express-v1"],
)
def test_different_models_bedrock(model_id: str) -> None:
provider = model_id.split(".")[0]
client = MagicMock()
respbody = MagicMock(
read=MagicMock(
return_value=MagicMock(
decode=MagicMock(return_value=b'{"completion":"Hi back"}')
)
)
respbody = MagicMock()
if provider == "anthropic":
respbody.read.return_value = MagicMock(
decode=MagicMock(return_value=b'{"completion":"Hi back"}'),
)
client.invoke_model.return_value = {"body": respbody}
model = BedrockChat(model_id="anthropic.claude-v2", client=client)
elif provider == "amazon":
respbody.read.return_value = '{"results": [{"outputText": "Hi back"}]}'
client.invoke_model.return_value = {"body": respbody}
model = BedrockChat(model_id=model_id, client=client)
# should not throw an error
model.invoke("hello there")