diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py
index 271941def8..5ae84c6d08 100644
--- a/libs/partners/mistralai/langchain_mistralai/chat_models.py
+++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -186,9 +186,10 @@ async def acompletion_with_retry(
     return await _completion_with_retry(**kwargs)
 
 
-def _convert_delta_to_message_chunk(
-    _delta: Dict, default_class: Type[BaseMessageChunk]
+def _convert_chunk_to_message_chunk(
+    chunk: Dict, default_class: Type[BaseMessageChunk]
 ) -> BaseMessageChunk:
+    _delta = chunk["choices"][0]["delta"]
     role = _delta.get("role")
     content = _delta.get("content") or ""
     if role == "user" or default_class == HumanMessageChunk:
@@ -216,10 +217,19 @@ def _convert_delta_to_message_chunk(
                 pass
         else:
             tool_call_chunks = []
+        if token_usage := chunk.get("usage"):
+            usage_metadata = {
+                "input_tokens": token_usage.get("prompt_tokens", 0),
+                "output_tokens": token_usage.get("completion_tokens", 0),
+                "total_tokens": token_usage.get("total_tokens", 0),
+            }
+        else:
+            usage_metadata = None
         return AIMessageChunk(
             content=content,
             additional_kwargs=additional_kwargs,
             tool_call_chunks=tool_call_chunks,
+            usage_metadata=usage_metadata,
         )
     elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
@@ -484,14 +494,21 @@ class ChatMistralAI(BaseChatModel):
 
     def _create_chat_result(self, response: Dict) -> ChatResult:
         generations = []
+        token_usage = response.get("usage", {})
         for res in response["choices"]:
             finish_reason = res.get("finish_reason")
+            message = _convert_mistral_chat_message_to_message(res["message"])
+            if token_usage and isinstance(message, AIMessage):
+                message.usage_metadata = {
+                    "input_tokens": token_usage.get("prompt_tokens", 0),
+                    "output_tokens": token_usage.get("completion_tokens", 0),
+                    "total_tokens": token_usage.get("total_tokens", 0),
+                }
             gen = ChatGeneration(
-                message=_convert_mistral_chat_message_to_message(res["message"]),
+                message=message,
                 generation_info={"finish_reason": finish_reason},
             )
             generations.append(gen)
-        token_usage = response.get("usage", {})
         llm_output = {"token_usage": token_usage, "model": self.model}
         return ChatResult(generations=generations, llm_output=llm_output)
 
@@ -525,8 +542,7 @@ class ChatMistralAI(BaseChatModel):
         ):
             if len(chunk["choices"]) == 0:
                 continue
-            delta = chunk["choices"][0]["delta"]
-            new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             # make future chunks same type as first chunk
             default_chunk_class = new_chunk.__class__
             gen_chunk = ChatGenerationChunk(message=new_chunk)
@@ -552,8 +568,7 @@ class ChatMistralAI(BaseChatModel):
         ):
             if len(chunk["choices"]) == 0:
                 continue
-            delta = chunk["choices"][0]["delta"]
-            new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             # make future chunks same type as first chunk
             default_chunk_class = new_chunk.__class__
             gen_chunk = ChatGenerationChunk(message=new_chunk)
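
Note (illustrative, not part of the patch): a minimal sketch of the renamed
helper's new behavior. Mistral reports token usage at the chunk level rather
than inside the delta, which is why the helper now receives the whole chunk.
The dict below is hand-built to mimic the final stream payload, and
_convert_chunk_to_message_chunk is a private helper imported here only for
demonstration.

    from langchain_core.messages import AIMessageChunk

    from langchain_mistralai.chat_models import _convert_chunk_to_message_chunk

    # Hand-built chunk mimicking Mistral's final stream payload, which
    # carries a "usage" block alongside the last delta.
    chunk = {
        "choices": [{"delta": {"role": "assistant", "content": "Done"}}],
        "usage": {"prompt_tokens": 9, "completion_tokens": 4, "total_tokens": 13},
    }
    message = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)
    # The new walrus branch copies the usage block into usage_metadata.
    assert message.usage_metadata == {
        "input_tokens": 9,
        "output_tokens": 4,
        "total_tokens": 13,
    }
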
diff --git a/libs/partners/mistralai/poetry.lock b/libs/partners/mistralai/poetry.lock
index 72a47e3235..34d3cd3b61 100644
--- a/libs/partners/mistralai/poetry.lock
+++ b/libs/partners/mistralai/poetry.lock
@@ -392,7 +392,7 @@ files = [
 
 [[package]]
 name = "langchain-core"
-version = "0.2.0"
+version = "0.2.5"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -401,15 +401,12 @@ develop = true
 
 [package.dependencies]
 jsonpatch = "^1.33"
-langsmith = "^0.1.0"
+langsmith = "^0.1.75"
 packaging = "^23.2"
 pydantic = ">=1,<3"
 PyYAML = ">=5.3"
 tenacity = "^8.1.0"
 
-[package.extras]
-extended-testing = ["jinja2 (>=3,<4)"]
-
 [package.source]
 type = "directory"
 url = "../../core"
@@ -433,13 +430,13 @@ url = "../../standard-tests"
 
 [[package]]
 name = "langsmith"
-version = "0.1.58"
+version = "0.1.76"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "langsmith-0.1.58-py3-none-any.whl", hash = "sha256:1148cc836ec99d1b2f37cd2fa3014fcac213bb6bad798a2b21bb9111c18c9768"},
-    {file = "langsmith-0.1.58.tar.gz", hash = "sha256:a5060933c1fb3006b498ec849677993329d7e6138bdc2ec044068ab806e09c39"},
+    {file = "langsmith-0.1.76-py3-none-any.whl", hash = "sha256:4b8cb14f2233d9673ce9e6e3d545359946d9690a2c1457ab01e7459ec97b964e"},
+    {file = "langsmith-0.1.76.tar.gz", hash = "sha256:5829f997495c0f9a39f91fe0a57e0cb702e8642e6948945f5bb9f46337db7732"},
 ]
 
 [package.dependencies]
@@ -1051,4 +1048,4 @@ zstd = ["zstandard (>=0.18.0)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "4a5a57d01c791de831f03fb309541443dc8bb51f5068ccfb7bcb77490c2eb6c3"
+content-hash = "af4576b4e41d3e01716cff9476d6130dd0c5ef7b98bfd02fefd1f5b730574b6e"
diff --git a/libs/partners/mistralai/pyproject.toml b/libs/partners/mistralai/pyproject.toml
index 3d1b00519a..8b8e2f075a 100644
--- a/libs/partners/mistralai/pyproject.toml
+++ b/libs/partners/mistralai/pyproject.toml
@@ -12,7 +12,7 @@ license = "MIT"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-langchain-core = ">=0.2.0,<0.3"
+langchain-core = ">=0.2.2,<0.3"
 tokenizers = ">=0.15.1,<1"
 httpx = ">=0.25.2,<1"
 httpx-sse = ">=0.3.1,<1"
diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
index 4bf576ac53..30725fa1f4 100644
--- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
@@ -1,11 +1,12 @@
 """Test ChatMistral chat model."""
 
 import json
-from typing import Any
+from typing import Any, Optional
 
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
+    BaseMessageChunk,
     HumanMessage,
 )
 from langchain_core.pydantic_v1 import BaseModel
@@ -25,8 +26,28 @@ async def test_astream() -> None:
     """Test streaming tokens from ChatMistralAI."""
     llm = ChatMistralAI()
 
+    full: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
     async for token in llm.astream("I'm Pickle Rick"):
+        assert isinstance(token, AIMessageChunk)
         assert isinstance(token.content, str)
+        full = token if full is None else full + token
+        if token.usage_metadata is not None:
+            chunks_with_token_counts += 1
+    if chunks_with_token_counts != 1:
+        raise AssertionError(
+            "Expected exactly one chunk with token counts. "
+            "AIMessageChunk aggregation adds counts. Check that "
+            "this is behaving properly."
+        )
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert (
+        full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
+        == full.usage_metadata["total_tokens"]
+    )
 
 
 async def test_abatch() -> None:
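
Note (illustrative, not part of the patch): the updated test relies on the
AIMessageChunk + operator concatenating content and summing usage_metadata
counts, so aggregating the stream into full yields complete totals even
though only one chunk carries usage. A self-contained sketch:

    from langchain_core.messages import AIMessageChunk

    first = AIMessageChunk(content="Hello")
    last = AIMessageChunk(
        content=" world",
        usage_metadata={"input_tokens": 2, "output_tokens": 3, "total_tokens": 5},
    )
    merged = first + last
    # Content concatenates; counts from the one chunk that has them survive
    # the merge (and would be summed if both sides carried counts).
    assert merged.content == "Hello world"
    assert merged.usage_metadata == {
        "input_tokens": 2,
        "output_tokens": 3,
        "total_tokens": 5,
    }
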
diff --git a/libs/partners/mistralai/tests/integration_tests/test_standard.py b/libs/partners/mistralai/tests/integration_tests/test_standard.py
index 7ea8f1bee8..d9b8ff1969 100644
--- a/libs/partners/mistralai/tests/integration_tests/test_standard.py
+++ b/libs/partners/mistralai/tests/integration_tests/test_standard.py
@@ -20,14 +20,3 @@ class TestMistralStandard(ChatModelIntegrationTests):
             "model": "mistral-large-latest",
             "temperature": 0,
         }
-
-    @pytest.mark.xfail(reason="Not implemented.")
-    def test_usage_metadata(
-        self,
-        chat_model_class: Type[BaseChatModel],
-        chat_model_params: dict,
-    ) -> None:
-        super().test_usage_metadata(
-            chat_model_class,
-            chat_model_params,
-        )
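
Note (illustrative, not part of the patch): with the xfail marker removed,
the standard test_usage_metadata now runs against the real implementation.
Assuming MISTRAL_API_KEY is set, the feature is visible end to end (the
printed counts below are hypothetical):

    from langchain_mistralai import ChatMistralAI

    llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
    response = llm.invoke("Say hi")
    # usage_metadata is populated from the response's "usage" block.
    print(response.usage_metadata)
    # e.g. {'input_tokens': 6, 'output_tokens': 4, 'total_tokens': 10}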