mistralai[patch]: add missing _combine_llm_outputs implementation in ChatMistralAI (#18603)

# Description
Implements `_combine_llm_outputs` in `ChatMistralAI`, overriding the default
implementation in `BaseChatModel`, which simply returns `{}`. The
implementation is inspired by the one in `ChatOpenAI` from the
`langchain-openai` package.
# Issue
None
# Dependencies
None
# Twitter handle
None
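
For illustration, a minimal usage sketch of the new behavior (assuming a
configured `MISTRAL_API_KEY`; the printed values are hypothetical):

```python
from langchain_core.messages import HumanMessage
from langchain_mistralai.chat_models import ChatMistralAI

chat = ChatMistralAI(max_tokens=10)
result = chat.generate([[HumanMessage(content="Hello")]])
# Before this change, llm_output was {}; now it carries aggregated usage.
print(result.llm_output)
# e.g. {'token_usage': {'prompt_tokens': 5, 'completion_tokens': 10,
#        'total_tokens': 15}, 'model_name': 'mistral-small'}
```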

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Pierre Véron, 2024-03-29 22:43:20 +01:00, committed by GitHub
commit ace7b66261 (parent 0175906437)
2 changed files with 61 additions and 1 deletion


@@ -250,6 +250,22 @@ class ChatMistralAI(BaseChatModel):
        rtn = _completion_with_retry(**kwargs)
        return rtn

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            if token_usage is not None:
                for k, v in token_usage.items():
                    if k in overall_token_usage:
                        overall_token_usage[k] += v
                    else:
                        overall_token_usage[k] = v
        combined = {"token_usage": overall_token_usage, "model_name": self.model}
        return combined

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists, temperature, and top_p."""


@@ -3,7 +3,7 @@
import json
from typing import Any

from langchain_core.messages import AIMessageChunk, HumanMessage
from langchain_core.pydantic_v1 import BaseModel

from langchain_mistralai.chat_models import ChatMistralAI
@@ -70,6 +70,50 @@ def test_invoke() -> None:
    assert isinstance(result.content, str)

def test_chat_mistralai_llm_output_contains_model_name() -> None:
    """Test llm_output contains model_name."""
    chat = ChatMistralAI(max_tokens=10)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["model_name"] == chat.model


def test_chat_mistralai_streaming_llm_output_contains_model_name() -> None:
    """Test llm_output contains model_name when streaming."""
    chat = ChatMistralAI(max_tokens=10, streaming=True)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["model_name"] == chat.model


def test_chat_mistralai_llm_output_contains_token_usage() -> None:
    """Test llm_output contains token_usage."""
    chat = ChatMistralAI(max_tokens=10)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert "token_usage" in llm_result.llm_output
    token_usage = llm_result.llm_output["token_usage"]
    assert "prompt_tokens" in token_usage
    assert "completion_tokens" in token_usage
    assert "total_tokens" in token_usage


def test_chat_mistralai_streaming_llm_output_contains_token_usage() -> None:
    """Test llm_output contains token_usage when streaming."""
    chat = ChatMistralAI(max_tokens=10, streaming=True)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert "token_usage" in llm_result.llm_output
    token_usage = llm_result.llm_output["token_usage"]
    assert "prompt_tokens" in token_usage
    assert "completion_tokens" in token_usage
    assert "total_tokens" in token_usage

def test_structured_output() -> None:
    llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
    schema = {
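
The four new tests assert the same invariants on both the streaming and
non-streaming paths; condensed into a single sketch (an integration check
that assumes a configured `MISTRAL_API_KEY`):

```python
from langchain_core.messages import HumanMessage
from langchain_mistralai.chat_models import ChatMistralAI

for streaming in (False, True):
    chat = ChatMistralAI(max_tokens=10, streaming=streaming)
    out = chat.generate([[HumanMessage(content="Hello")]]).llm_output
    assert out is not None
    assert out["model_name"] == chat.model
    # token_usage should expose the standard usage fields in both modes
    assert {"prompt_tokens", "completion_tokens", "total_tokens"} <= out["token_usage"].keys()
```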