mirror of https://github.com/hwchase17/langchain
experimental[minor]: adds mixtral wrapper (#17423)
**Description:** Adds a chat wrapper for Mixtral models using the [prompt template](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1#instruction-format). --------- Co-authored-by: Bagatur <baskaryan@gmail.com> (branch: pull/18756/head^2)
parent
4f4300723b
commit
66576948e0
@ -0,0 +1,31 @@
|
||||
import pytest
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from langchain_experimental.chat_models import Mixtral
|
||||
from tests.unit_tests.chat_models.test_llm_wrapper_llama2chat import FakeLLM
|
||||
|
||||
|
||||
@pytest.fixture
def model() -> Mixtral:
    """Mixtral wrapper around a fake LLM, with no preset system message."""
    fake_llm = FakeLLM()
    return Mixtral(llm=fake_llm)
|
||||
|
||||
|
||||
@pytest.fixture
def model_cfg_sys_msg() -> Mixtral:
    """Mixtral wrapper preconfigured with a "sys-msg" system message."""
    preset_system = SystemMessage(content="sys-msg")
    return Mixtral(llm=FakeLLM(), system_message=preset_system)
|
||||
|
||||
|
||||
def test_prompt(model: Mixtral) -> None:
    """Verify the wrapper renders a conversation in the Mixtral instruct format.

    The system message is folded into the first [INST] block, and each
    assistant turn is closed with </s> before the next instruction opens.
    """
    conversation = [
        SystemMessage(content="sys-msg"),
        HumanMessage(content="usr-msg-1"),
        AIMessage(content="ai-msg-1"),
        HumanMessage(content="usr-msg-2"),
    ]

    rendered = model.predict_messages(conversation).content  # type: ignore
    want = (
        "<s>[INST] sys-msg\nusr-msg-1 [/INST] ai-msg-1 </s> [INST] usr-msg-2 [/INST]"  # noqa: E501
    )

    assert rendered == want
|
Loading…
Reference in New Issue