mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
community[patch]: BedrockChat -> Support Titan express as chat model (#15408)
Titan Express model was not supported as a chat model because LangChain messages were not "translated" to a text prompt. Co-authored-by: Guillem Orellana Trullols <guillem.orellana_trullols@siemens.com>
This commit is contained in:
parent
1b9001db47
commit
aad2aa7188
@ -32,6 +32,12 @@ class ChatPromptAdapter:
|
||||
prompt = convert_messages_to_prompt_anthropic(messages=messages)
|
||||
elif provider == "meta":
|
||||
prompt = convert_messages_to_prompt_llama(messages=messages)
|
||||
elif provider == "amazon":
|
||||
prompt = convert_messages_to_prompt_anthropic(
|
||||
messages=messages,
|
||||
human_prompt="\n\nUser:",
|
||||
ai_prompt="\n\nBot:",
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Provider {provider} model does not support chat."
|
||||
|
@ -272,10 +272,12 @@ class BedrockBase(BaseModel, ABC):
|
||||
|
||||
try:
|
||||
response = self.client.invoke_model(
|
||||
body=body, modelId=self.model_id, accept=accept, contentType=contentType
|
||||
body=body,
|
||||
modelId=self.model_id,
|
||||
accept=accept,
|
||||
contentType=contentType,
|
||||
)
|
||||
text = LLMInputOutputAdapter.prepare_output(provider, response)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error raised by bedrock service: {e}").with_traceback(
|
||||
e.__traceback__
|
||||
|
@ -37,17 +37,24 @@ def test_formatting(messages: List[BaseMessage], expected: str) -> None:
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_anthropic_bedrock() -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"model_id",
|
||||
["anthropic.claude-v2", "amazon.titan-text-express-v1"],
|
||||
)
|
||||
def test_different_models_bedrock(model_id: str) -> None:
|
||||
provider = model_id.split(".")[0]
|
||||
client = MagicMock()
|
||||
respbody = MagicMock(
|
||||
read=MagicMock(
|
||||
return_value=MagicMock(
|
||||
decode=MagicMock(return_value=b'{"completion":"Hi back"}')
|
||||
)
|
||||
)
|
||||
respbody = MagicMock()
|
||||
if provider == "anthropic":
|
||||
respbody.read.return_value = MagicMock(
|
||||
decode=MagicMock(return_value=b'{"completion":"Hi back"}'),
|
||||
)
|
||||
client.invoke_model.return_value = {"body": respbody}
|
||||
model = BedrockChat(model_id="anthropic.claude-v2", client=client)
|
||||
elif provider == "amazon":
|
||||
respbody.read.return_value = '{"results": [{"outputText": "Hi back"}]}'
|
||||
client.invoke_model.return_value = {"body": respbody}
|
||||
|
||||
model = BedrockChat(model_id=model_id, client=client)
|
||||
|
||||
# should not throw an error
|
||||
model.invoke("hello there")
|
||||
|
Loading…
Reference in New Issue
Block a user