From 0826d87ecd5f186124af7a26180507c1691c1a72 Mon Sep 17 00:00:00 2001 From: ccurme Date: Sat, 3 Feb 2024 19:30:50 -0500 Subject: [PATCH] langchain_mistralai[patch]: Invoke callback prior to yielding token (#16986) - **Description:** Invoke callback prior to yielding token in stream and astream methods for ChatMistralAI. - **Issue:** https://github.com/langchain-ai/langchain/issues/16913 --- .../langchain_mistralai/chat_models.py | 4 +- .../tests/unit_tests/test_chat_models.py | 55 +++++++++++++++++++ 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index a13308e5d5..d4cd631b9b 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -317,9 +317,9 @@ class ChatMistralAI(BaseChatModel): continue chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) default_chunk_class = chunk.__class__ - yield ChatGenerationChunk(message=chunk) if run_manager: run_manager.on_llm_new_token(token=chunk.content, chunk=chunk) + yield ChatGenerationChunk(message=chunk) async def _astream( self, @@ -342,9 +342,9 @@ class ChatMistralAI(BaseChatModel): continue chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) default_chunk_class = chunk.__class__ - yield ChatGenerationChunk(message=chunk) if run_manager: await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk) + yield ChatGenerationChunk(message=chunk) async def _agenerate( self, diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index d965725ba1..8a28a916f5 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -1,7 +1,10 @@ """Test MistralAI Chat API wrapper.""" import os +from typing import Any, AsyncGenerator, Generator +from unittest.mock import patch import pytest +from langchain_core.callbacks.base import BaseCallbackHandler from langchain_core.messages import ( AIMessage, BaseMessage, @@ -12,6 +15,11 @@ from langchain_core.messages import ( # TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker. from mistralai.models.chat_completion import ( # type: ignore[import] + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + DeltaMessage, +) +from mistralai.models.chat_completion import ( ChatMessage as MistralChatMessage, ) @@ -63,3 +71,50 @@ def test_convert_message_to_mistral_chat_message( ) -> None: result = _convert_message_to_mistral_chat_message(message) assert result == expected + + +def _make_completion_response_from_token(token: str) -> ChatCompletionStreamResponse: + return ChatCompletionStreamResponse( + id="abc123", + model="fake_model", + choices=[ + ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(content=token), + finish_reason=None, + ) + ], + ) + + +def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator: + for token in ["Hello", " how", " can", " I", " help", "?"]: + yield _make_completion_response_from_token(token) + + +async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator: + for token in ["Hello", " how", " can", " I", " help", "?"]: + yield _make_completion_response_from_token(token) + + +class MyCustomHandler(BaseCallbackHandler): + last_token: str = "" + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + self.last_token = token + + +@patch("mistralai.client.MistralClient.chat_stream", new=mock_chat_stream) +def test_stream_with_callback() -> None: + callback = MyCustomHandler() + chat = ChatMistralAI(callbacks=[callback]) + for token in chat.stream("Hello"): + assert callback.last_token == token.content + + +@patch("mistralai.async_client.MistralAsyncClient.chat_stream", new=mock_chat_astream) +async def test_astream_with_callback() -> None: + callback = MyCustomHandler() + chat = ChatMistralAI(callbacks=[callback]) + async for token in chat.astream("Hello"): + assert callback.last_token == token.content