@@ -49,7 +49,7 @@ def model_cfg_sys_msg() -> Llama2Chat:
 def test_default_system_message(model: Llama2Chat) -> None:
     messages = [HumanMessage(content="usr-msg-1")]
 
-    actual = model.predict_messages(messages).content  # type: ignore
+    actual = model.invoke(messages).content  # type: ignore
     expected = (
         f"<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>>\n\nusr-msg-1 [/INST]"
     )
@@ -62,7 +62,7 @@ def test_configured_system_message(
 ) -> None:
     messages = [HumanMessage(content="usr-msg-1")]
 
-    actual = model_cfg_sys_msg.predict_messages(messages).content  # type: ignore
+    actual = model_cfg_sys_msg.invoke(messages).content  # type: ignore
     expected = "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"
 
     assert actual == expected
@@ -73,7 +73,7 @@ async def test_configured_system_message_async(
 ) -> None:
     messages = [HumanMessage(content="usr-msg-1")]
 
-    actual = await model_cfg_sys_msg.apredict_messages(messages)  # type: ignore
+    actual = await model_cfg_sys_msg.ainvoke(messages)  # type: ignore
     expected = "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"
 
     assert actual.content == expected
@@ -87,7 +87,7 @@ def test_provided_system_message(
         HumanMessage(content="usr-msg-1"),
     ]
 
-    actual = model_cfg_sys_msg.predict_messages(messages).content
+    actual = model_cfg_sys_msg.invoke(messages).content
     expected = "<s>[INST] <<SYS>>\ncustom-sys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"
 
     assert actual == expected
@@ -102,7 +102,7 @@ def test_human_ai_dialogue(model_cfg_sys_msg: Llama2Chat) -> None:
         HumanMessage(content="usr-msg-3"),
     ]
 
-    actual = model_cfg_sys_msg.predict_messages(messages).content
+    actual = model_cfg_sys_msg.invoke(messages).content
     expected = (
         "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST] ai-msg-1 </s>"
         "<s>[INST] usr-msg-2 [/INST] ai-msg-2 </s><s>[INST] usr-msg-3 [/INST]"
@@ -113,14 +113,14 @@ def test_human_ai_dialogue(model_cfg_sys_msg: Llama2Chat) -> None:
 
 def test_no_message(model: Llama2Chat) -> None:
     with pytest.raises(ValueError) as info:
-        model.predict_messages([])
+        model.invoke([])
 
     assert info.value.args[0] == "at least one HumanMessage must be provided"
 
 
 def test_ai_message_first(model: Llama2Chat) -> None:
     with pytest.raises(ValueError) as info:
-        model.predict_messages([AIMessage(content="ai-msg-1")])
+        model.invoke([AIMessage(content="ai-msg-1")])
 
     assert (
         info.value.args[0]
@@ -136,7 +136,7 @@ def test_human_ai_messages_not_alternating(model: Llama2Chat) -> None:
     ]
 
     with pytest.raises(ValueError) as info:
-        model.predict_messages(messages)  # type: ignore
+        model.invoke(messages)  # type: ignore
 
     assert info.value.args[0] == (
         "messages must be alternating human- and ai-messages, "
@@ -151,6 +151,6 @@ def test_last_message_not_human_message(model: Llama2Chat) -> None:
     ]
 
     with pytest.raises(ValueError) as info:
-        model.predict_messages(messages)
+        model.invoke(messages)
 
     assert info.value.args[0] == "last message must be a HumanMessage"