experimental[minor]: adds mixtral wrapper (#17423)

**Description:** Adds a chat wrapper for Mixtral models using the
[prompt
template](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1#instruction-format).

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Alexander Dicke 2024-03-09 02:14:23 +01:00 committed by GitHub
parent 4f4300723b
commit 66576948e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 55 additions and 6 deletions

View File

@ -17,10 +17,11 @@ an interface where "chat messages" are the inputs and outputs.
AIMessage, BaseMessage, HumanMessage
""" # noqa: E501
from langchain_experimental.chat_models.llm_wrapper import Llama2Chat, Orca, Vicuna
from langchain_experimental.chat_models.llm_wrapper import (
Llama2Chat,
Mixtral,
Orca,
Vicuna,
)
# Public names of this package: the chat wrappers imported above from
# langchain_experimental.chat_models.llm_wrapper. Assigned exactly once so
# there is a single authoritative export list.
__all__ = ["Llama2Chat", "Orca", "Vicuna", "Mixtral"]

View File

@ -148,6 +148,23 @@ class Llama2Chat(ChatWrapper):
usr_0_end: str = " [/INST]"
class Mixtral(ChatWrapper):
    """Chat wrapper configured with the Mixtral instruct prompt markers.

    Only the marker strings and the LLM type identifier are defined here; how
    they are combined into a prompt is decided by ``ChatWrapper``.

    Template reference:
    https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1#instruction-format
    """  # noqa: E501

    # Markers placed around the system message. The opening "[INST]" is emitted
    # here, which is why the first user turn below has an empty start marker.
    sys_beg: str = "<s>[INST] "
    sys_end: str = "\n"

    # Markers for the first user message.
    usr_0_beg: str = ""
    usr_0_end: str = " [/INST]"

    # Markers for each AI reply.
    ai_n_beg: str = " "
    ai_n_end: str = " </s>"

    # Markers for every later user message.
    usr_n_beg: str = " [INST] "
    usr_n_end: str = " [/INST]"

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this wrapper type."""
        return "mixtral"
class Orca(ChatWrapper):
"""Wrapper for Orca-style models."""

View File

@ -0,0 +1,31 @@
import pytest
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain_experimental.chat_models import Mixtral
from tests.unit_tests.chat_models.test_llm_wrapper_llama2chat import FakeLLM
@pytest.fixture
def model() -> Mixtral:
    """Provide a Mixtral wrapper backed by a FakeLLM.

    No ``system_message`` is passed, so the wrapper keeps whatever default
    its base class defines.
    """
    fake_llm = FakeLLM()
    return Mixtral(llm=fake_llm)
@pytest.fixture
def model_cfg_sys_msg() -> Mixtral:
    """Provide a Mixtral wrapper whose ``system_message`` is set to "sys-msg".

    NOTE(review): no test visible in this file requests this fixture; confirm
    whether a configured-system-message test is still to be added.
    """
    fake_llm = FakeLLM()
    configured_system_message = SystemMessage(content="sys-msg")
    return Mixtral(llm=fake_llm, system_message=configured_system_message)
def test_prompt(model: Mixtral) -> None:
    """Check the text produced for a system/user/AI/user conversation.

    The assertion compares the returned message content with the fully
    formatted prompt, so FakeLLM presumably echoes the prompt it receives
    (confirm in ``test_llm_wrapper_llama2chat``).
    """
    expected_prompt = "<s>[INST] sys-msg\nusr-msg-1 [/INST] ai-msg-1 </s> [INST] usr-msg-2 [/INST]"  # noqa: E501

    # One system message followed by alternating human / AI turns, ending on
    # a human message.
    conversation = [
        SystemMessage(content="sys-msg"),
        HumanMessage(content="usr-msg-1"),
        AIMessage(content="ai-msg-1"),
        HumanMessage(content="usr-msg-2"),
    ]

    rendered = model.predict_messages(conversation).content  # type: ignore
    assert rendered == expected_prompt