diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py
index a1a10c8edd..cf7acd2616 100644
--- a/libs/partners/mistralai/langchain_mistralai/chat_models.py
+++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -250,6 +250,22 @@ class ChatMistralAI(BaseChatModel):
         rtn = _completion_with_retry(**kwargs)
         return rtn
 
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                # Happens in streaming
+                continue
+            token_usage = output["token_usage"]
+            if token_usage is not None:
+                for k, v in token_usage.items():
+                    if k in overall_token_usage:
+                        overall_token_usage[k] += v
+                    else:
+                        overall_token_usage[k] = v
+        combined = {"token_usage": overall_token_usage, "model_name": self.model}
+        return combined
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate api key, python package exists, temperature, and top_p."""
diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
index b292abf700..56646ee42f 100644
--- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
@@ -3,7 +3,7 @@
 import json
 from typing import Any
 
-from langchain_core.messages import AIMessageChunk
+from langchain_core.messages import AIMessageChunk, HumanMessage
 from langchain_core.pydantic_v1 import BaseModel
 
 from langchain_mistralai.chat_models import ChatMistralAI
@@ -70,6 +70,50 @@ def test_invoke() -> None:
     assert isinstance(result.content, str)
 
 
+def test_chat_mistralai_llm_output_contains_model_name() -> None:
+    """Test llm_output contains model_name."""
+    chat = ChatMistralAI(max_tokens=10)
+    message = HumanMessage(content="Hello")
+    llm_result = chat.generate([[message]])
+    assert llm_result.llm_output is not None
+    assert llm_result.llm_output["model_name"] == chat.model
+
+
+def test_chat_mistralai_streaming_llm_output_contains_model_name() -> None:
+    """Test llm_output contains model_name."""
+    chat = ChatMistralAI(max_tokens=10, streaming=True)
+    message = HumanMessage(content="Hello")
+    llm_result = chat.generate([[message]])
+    assert llm_result.llm_output is not None
+    assert llm_result.llm_output["model_name"] == chat.model
+
+
+def test_chat_mistralai_llm_output_contains_token_usage() -> None:
+    """Test llm_output contains token_usage."""
+    chat = ChatMistralAI(max_tokens=10)
+    message = HumanMessage(content="Hello")
+    llm_result = chat.generate([[message]])
+    assert llm_result.llm_output is not None
+    assert "token_usage" in llm_result.llm_output
+    token_usage = llm_result.llm_output["token_usage"]
+    assert "prompt_tokens" in token_usage
+    assert "completion_tokens" in token_usage
+    assert "total_tokens" in token_usage
+
+
+def test_chat_mistralai_streaming_llm_output_contains_token_usage() -> None:
+    """Test llm_output contains token_usage."""
+    chat = ChatMistralAI(max_tokens=10, streaming=True)
+    message = HumanMessage(content="Hello")
+    llm_result = chat.generate([[message]])
+    assert llm_result.llm_output is not None
+    assert "token_usage" in llm_result.llm_output
+    token_usage = llm_result.llm_output["token_usage"]
+    assert "prompt_tokens" in token_usage
+    assert "completion_tokens" in token_usage
+    assert "total_tokens" in token_usage
+
+
 def test_structured_output() -> None:
     llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
     schema = {