Add Minimax chat model (#10776)

Resolves the merge conflicts for
https://github.com/langchain-ai/langchain/pull/6757

---------

Co-authored-by: 何涛 <taohe@bytedance.com>
Authored by HeTaoPKU · committed by GitHub · 1 year ago
parent c656a6b966
commit f505320a73

@@ -0,0 +1,70 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MiniMax\n",
    "\n",
    "[Minimax](https://api.minimax.chat) is a Chinese startup that provides LLM services for companies and individuals.\n",
    "\n",
    "This example goes over how to use LangChain to interact with MiniMax Inference for Chat."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"MINIMAX_GROUP_ID\"] = \"MINIMAX_GROUP_ID\"\n",
    "os.environ[\"MINIMAX_API_KEY\"] = \"MINIMAX_API_KEY\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chat_models import MiniMaxChat\n",
    "from langchain.schema import HumanMessage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chat = MiniMaxChat()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chat(\n",
    "    [\n",
    "        HumanMessage(\n",
    "            content=\"Translate this sentence from English to French. I love programming.\"\n",
    "        )\n",
    "    ]\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
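
Since the wrapper maps `HumanMessage` to Minimax's `USER` sender and `AIMessage` to `BOT` (see `_parse_chat_history` in the module added below), multi-turn history can be replayed as context. A minimal sketch, assuming the same environment variables as in the notebook above:

```python
from langchain.chat_models import MiniMaxChat
from langchain.schema import AIMessage, HumanMessage

chat = MiniMaxChat()

# Each prior turn is passed as a plain message object; the wrapper converts
# HumanMessage -> {"sender_type": "USER", ...} and AIMessage -> {"sender_type": "BOT", ...}.
chat(
    [
        HumanMessage(
            content="Translate this sentence from English to French. I love programming."
        ),
        AIMessage(content="J'aime programmer."),
        HumanMessage(content="Now translate the same sentence into German."),
    ]
)
```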

@@ -94,7 +94,8 @@
   "outputs": [],
   "source": [
    "from langchain.llms import Minimax\n",
-    "from langchain.prompts import PromptTemplate\nfrom langchain.chains import LLMChain"
+    "from langchain.prompts import PromptTemplate\n",
+    "from langchain.chains import LLMChain"
   ],
   "metadata": {
    "collapsed": false

@@ -108,7 +108,8 @@
   "outputs": [],
   "source": [
    "from langchain.llms import Modal\n",
-    "from langchain.prompts import PromptTemplate\nfrom langchain.chains import LLMChain"
+    "from langchain.prompts import PromptTemplate\n",
+    "from langchain.chains import LLMChain"
   ]
  },
  {

@@ -17,6 +17,14 @@ See a [usage example](/docs/modules/model_io/models/llms/integrations/minimax.html)
from langchain.llms import Minimax
```
+
+## Chat Models
+
+See a [usage example](/docs/modules/model_io/models/chat/integrations/minimax.html)
+
+```python
+from langchain.chat_models import MiniMaxChat
+```

## Text Embedding Model

There exists a Minimax Embedding model, which you can access with

@@ -29,6 +29,7 @@ from langchain.chat_models.human import HumanInputChatModel
from langchain.chat_models.jinachat import JinaChat
from langchain.chat_models.konko import ChatKonko
from langchain.chat_models.litellm import ChatLiteLLM
+from langchain.chat_models.minimax import MiniMaxChat
from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
from langchain.chat_models.ollama import ChatOllama
from langchain.chat_models.openai import ChatOpenAI
@@ -48,6 +49,7 @@ __all__ = [
    "ChatVertexAI",
    "JinaChat",
    "HumanInputChatModel",
+    "MiniMaxChat",
    "ChatAnyscale",
    "ChatLiteLLM",
    "ErnieBotChat",

@@ -0,0 +1,93 @@
"""Wrapper around Minimax chat models."""
import logging
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.minimax import MinimaxCommon
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    ChatResult,
    HumanMessage,
)

logger = logging.getLogger(__name__)


def _parse_message(msg_type: str, text: str) -> Dict:
    return {"sender_type": msg_type, "text": text}


def _parse_chat_history(history: List[BaseMessage]) -> List:
    """Parse a sequence of messages into history."""
    chat_history = []
    for message in history:
        if isinstance(message, HumanMessage):
            chat_history.append(_parse_message("USER", message.content))
        if isinstance(message, AIMessage):
            chat_history.append(_parse_message("BOT", message.content))
    return chat_history


class MiniMaxChat(MinimaxCommon, BaseChatModel):
    """Wrapper around Minimax large language models.

    To use, you should have the environment variables ``MINIMAX_GROUP_ID`` and
    ``MINIMAX_API_KEY`` set with your API token, or pass them as named
    parameters to the constructor.

    Example:
        .. code-block:: python

            from langchain.chat_models import MiniMaxChat
            llm = MiniMaxChat(model="abab5.5-chat")
    """

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate the next turn in the conversation.

        Args:
            messages: The history of the conversation as a list of messages.
            stop: The list of stop words (optional).
            run_manager: The CallbackManager for the LLM run; not used at the moment.

        Returns:
            The ChatResult that contains outputs generated by the model.

        Raises:
            ValueError: if no messages are provided.
        """
        if not messages:
            raise ValueError(
                "You should provide at least one message to start the chat!"
            )
        history = _parse_chat_history(messages)
        payload = self._default_params
        payload["messages"] = history
        text = self._client.post(payload)
        # This is required since the stop tokens are not enforced
        # by the model parameters.
        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        # Wrap the raw string so the return type matches the
        # BaseChatModel contract.
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=text))]
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise NotImplementedError(
            """Minimax AI doesn't support async requests at the moment."""
        )
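
For reference, `_parse_chat_history` flattens LangChain messages into Minimax's `sender_type`/`text` wire format. A small illustration derived from the function above:

```python
from langchain.schema import AIMessage, HumanMessage

# HumanMessage becomes a USER entry, AIMessage becomes a BOT entry;
# any other message type is silently dropped by the loop above.
history = _parse_chat_history(
    [HumanMessage(content="Hi"), AIMessage(content="Hello! How can I help?")]
)
assert history == [
    {"sender_type": "USER", "text": "Hi"},
    {"sender_type": "BOT", "text": "Hello! How can I help?"},
]
```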

@@ -15,7 +15,8 @@ from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
-from langchain.pydantic_v1 import BaseModel, Extra, Field, PrivateAttr, root_validator
+from langchain.llms.utils import enforce_stop_tokens
+from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)
@@ -29,7 +30,7 @@ class _MinimaxEndpointClient(BaseModel):
    api_key: str
    api_url: str

-    @root_validator(pre=True)
+    @root_validator(pre=True, allow_reuse=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "api_url" not in values:
            host = values["host"]
@@ -52,19 +53,8 @@ class _MinimaxEndpointClient(BaseModel):
        return response.json()["reply"]

-class Minimax(LLM):
-    """Wrapper around Minimax large language models.
-
-    To use, you should have the environment variable
-    ``MINIMAX_API_KEY`` and ``MINIMAX_GROUP_ID`` set with your API key,
-    or pass them as a named parameter to the constructor.
-
-    Example:
-        .. code-block:: python
-
-            from langchain.llms.minimax import Minimax
-            minimax = Minimax(model="<model_name>", minimax_api_key="my-api-key",
-                minimax_group_id="my-group-id")
-    """
-
-    _client: _MinimaxEndpointClient = PrivateAttr()
+class MinimaxCommon(BaseModel):
+    _client: _MinimaxEndpointClient
    model: str = "abab5.5-chat"
    """Model name to use."""
    max_tokens: int = 256
@@ -79,11 +69,6 @@ class Minimax(LLM):
    minimax_group_id: Optional[str] = None
    minimax_api_key: Optional[str] = None

-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
@@ -131,6 +116,19 @@ class Minimax(LLM):
            group_id=self.minimax_group_id,
        )

+class Minimax(MinimaxCommon, LLM):
+    """Wrapper around Minimax large language models.
+
+    To use, you should have the environment variables
+    ``MINIMAX_API_KEY`` and ``MINIMAX_GROUP_ID`` set with your API key,
+    or pass them as named parameters to the constructor.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms.minimax import Minimax
+            minimax = Minimax(model="<model_name>", minimax_api_key="my-api-key",
+                minimax_group_id="my-group-id")
+    """
+
    def _call(
        self,
        prompt: str,
@@ -150,6 +148,10 @@ class Minimax(LLM):
        request = self._default_params
        request["messages"] = [{"sender_type": "USER", "text": prompt}]
        request.update(kwargs)
-        response = self._client.post(request)
+        text = self._client.post(request)
+        if stop is not None:
+            # This is required since the stop tokens
+            # are not enforced by the model parameters
+            text = enforce_stop_tokens(text, stop)
-        return response
+        return text
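
Both `Minimax._call` and `MiniMaxChat._generate` now route their output through `enforce_stop_tokens`, since the Minimax API does not enforce stop sequences via request parameters. Conceptually the helper truncates at the first stop sequence; a minimal sketch of that behavior (the real helper lives in `langchain.llms.utils`):

```python
import re
from typing import List


def _enforce_stop_tokens_sketch(text: str, stop: List[str]) -> str:
    # Cut the text at the first occurrence of any stop sequence.
    return re.split("|".join(map(re.escape, stop)), text)[0]


assert _enforce_stop_tokens_sketch("Bonjour!\nHuman:", ["\nHuman:"]) == "Bonjour!"
```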

@@ -7,3 +7,16 @@ def test_minimax_call() -> None:
    llm = Minimax(max_tokens=10)
    output = llm("Hello world!")
    assert isinstance(output, str)
+
+
+def test_minimax_call_successful() -> None:
+    """Test valid call to minimax."""
+    llm = Minimax()
+    output = llm(
+        "A chain is a serial assembly of connected pieces, called links, \
+        typically made of metal, with an overall character similar to that \
+        of a rope in that it is flexible and curved in compression but \
+        linear, rigid, and load-bearing in tension. A chain may consist \
+        of two or more links."
+    )
+    assert isinstance(output, str)
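
These tests call the live Minimax endpoint, so they fail when credentials are absent. A minimal sketch of one way to guard them locally (standard pytest; `requires_minimax` is a helper name invented here for illustration):

```python
import os

import pytest

from langchain.llms.minimax import Minimax

# Skip the live tests unless both Minimax credentials are present.
requires_minimax = pytest.mark.skipif(
    not (os.environ.get("MINIMAX_API_KEY") and os.environ.get("MINIMAX_GROUP_ID")),
    reason="MINIMAX_API_KEY / MINIMAX_GROUP_ID not set",
)


@requires_minimax
def test_minimax_call_guarded() -> None:
    llm = Minimax(max_tokens=10)
    assert isinstance(llm("Hello world!"), str)
```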
