experimental[minor]: adds mixtral wrapper (#17423)
**Description:** Adds a chat wrapper for Mixtral models using the [prompt template](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1#instruction-format).

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
parent 4f4300723b
commit 66576948e0
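For context, a minimal usage sketch (not part of the commit): the new wrapper takes any LangChain `LLM` and applies the Mixtral instruction template to chat messages. The `HuggingFacePipeline` backend and model kwargs below are illustrative assumptions only; a Mixtral-8x7B model needs substantial hardware.

```python
# Hedged usage sketch: any LLM instance works; HuggingFacePipeline is just
# one plausible backend, and the model choice here is illustrative.
from langchain.schema import HumanMessage, SystemMessage
from langchain_community.llms import HuggingFacePipeline
from langchain_experimental.chat_models import Mixtral

llm = HuggingFacePipeline.from_model_id(
    model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    task="text-generation",
)
chat = Mixtral(llm=llm)  # wraps the raw LLM with the Mixtral chat template
response = chat.predict_messages(
    [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Explain mixture-of-experts in one sentence."),
    ]
)
print(response.content)
```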
libs/experimental/langchain_experimental/chat_models/__init__.py

```diff
@@ -17,10 +17,11 @@ an interface where "chat messages" are the inputs and outputs.
     AIMessage, BaseMessage, HumanMessage
 """  # noqa: E501
 
-from langchain_experimental.chat_models.llm_wrapper import Llama2Chat, Orca, Vicuna
+from langchain_experimental.chat_models.llm_wrapper import (
+    Llama2Chat,
+    Mixtral,
+    Orca,
+    Vicuna,
+)
 
-__all__ = [
-    "Llama2Chat",
-    "Orca",
-    "Vicuna",
-]
+__all__ = ["Llama2Chat", "Orca", "Vicuna", "Mixtral"]
```
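With this re-export in place, `Mixtral` can be imported from the package root alongside the other wrappers, which is exactly what the new unit test below relies on:

```python
from langchain_experimental.chat_models import Llama2Chat, Mixtral, Orca, Vicuna
```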
libs/experimental/langchain_experimental/chat_models/llm_wrapper.py

```diff
@@ -148,6 +148,23 @@ class Llama2Chat(ChatWrapper):
     usr_0_end: str = " [/INST]"
 
 
+class Mixtral(ChatWrapper):
+    """See https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1#instruction-format"""  # noqa: E501
+
+    @property
+    def _llm_type(self) -> str:
+        return "mixtral"
+
+    sys_beg: str = "<s>[INST] "
+    sys_end: str = "\n"
+    ai_n_beg: str = " "
+    ai_n_end: str = " </s>"
+    usr_n_beg: str = " [INST] "
+    usr_n_end: str = " [/INST]"
+    usr_0_beg: str = ""
+    usr_0_end: str = " [/INST]"
+
+
 class Orca(ChatWrapper):
     """Wrapper for Orca-style models."""
 
```
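To make the marker attributes concrete, here is an illustrative, self-contained sketch (not the library's actual `ChatWrapper` code) of how they compose a conversation into the Mixtral instruction format. The composition rule is inferred from the expected string in the unit test below: the system message and the first user turn are merged into one `[INST]` block, then alternating AI/user turns are appended.

```python
# Illustrative sketch only: names mirror the Mixtral class attributes above.
def format_mixtral(system: str, turns: list[tuple[str, str]]) -> str:
    """turns is an alternating list of ("human", text) / ("ai", text) pairs."""
    sys_beg, sys_end = "<s>[INST] ", "\n"
    ai_n_beg, ai_n_end = " ", " </s>"
    usr_n_beg, usr_n_end = " [INST] ", " [/INST]"
    usr_0_beg, usr_0_end = "", " [/INST]"

    # System prompt and first user message share one [INST] block.
    first_user, rest = turns[0][1], turns[1:]
    prompt = sys_beg + system + sys_end + usr_0_beg + first_user + usr_0_end
    for role, text in rest:
        if role == "ai":
            prompt += ai_n_beg + text + ai_n_end  # close the turn with </s>
        else:
            prompt += usr_n_beg + text + usr_n_end
    return prompt

# Reproduces the expected value asserted in the unit test below.
assert format_mixtral(
    "sys-msg", [("human", "usr-msg-1"), ("ai", "ai-msg-1"), ("human", "usr-msg-2")]
) == "<s>[INST] sys-msg\nusr-msg-1 [/INST] ai-msg-1 </s> [INST] usr-msg-2 [/INST]"
```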
libs/experimental/tests/unit_tests/chat_models/test_llm_wrapper_mixtral.py

```diff
@@ -0,0 +1,31 @@
+import pytest
+from langchain.schema import AIMessage, HumanMessage, SystemMessage
+
+from langchain_experimental.chat_models import Mixtral
+from tests.unit_tests.chat_models.test_llm_wrapper_llama2chat import FakeLLM
+
+
+@pytest.fixture
+def model() -> Mixtral:
+    return Mixtral(llm=FakeLLM())
+
+
+@pytest.fixture
+def model_cfg_sys_msg() -> Mixtral:
+    return Mixtral(llm=FakeLLM(), system_message=SystemMessage(content="sys-msg"))
+
+
+def test_prompt(model: Mixtral) -> None:
+    messages = [
+        SystemMessage(content="sys-msg"),
+        HumanMessage(content="usr-msg-1"),
+        AIMessage(content="ai-msg-1"),
+        HumanMessage(content="usr-msg-2"),
+    ]
+
+    actual = model.predict_messages(messages).content  # type: ignore
+    expected = (
+        "<s>[INST] sys-msg\nusr-msg-1 [/INST] ai-msg-1 </s> [INST] usr-msg-2 [/INST]"  # noqa: E501
+    )
+
+    assert actual == expected
```
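Note that the `model_cfg_sys_msg` fixture is not exercised in the hunk above. A companion test mirroring the pattern in `test_llm_wrapper_llama2chat.py` might look like the following sketch (hypothetical, not part of this commit), where the system message comes from the wrapper's configuration rather than the message list:

```python
# Hypothetical companion test, assuming ChatWrapper injects the configured
# system_message when the message list starts with a HumanMessage.
def test_prompt_config_sys_msg(model_cfg_sys_msg: Mixtral) -> None:
    messages = [
        HumanMessage(content="usr-msg-1"),
        AIMessage(content="ai-msg-1"),
        HumanMessage(content="usr-msg-2"),
    ]

    actual = model_cfg_sys_msg.predict_messages(messages).content  # type: ignore
    expected = (
        "<s>[INST] sys-msg\nusr-msg-1 [/INST] ai-msg-1 </s> [INST] usr-msg-2 [/INST]"
    )

    assert actual == expected
```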