langchain_mistralai[patch]: Invoke callback prior to yielding token (#16986)

- **Description:** Invoke the callback prior to yielding the token in the
`stream` and `astream` methods of `ChatMistralAI`, so that
`on_llm_new_token` handlers observe each chunk before the consumer
receives it (see the sketch below).
- **Issue:** https://github.com/langchain-ai/langchain/issues/16913
Author: ccurme, 2024-02-03 19:30:50 -05:00, committed by GitHub
parent 267e71606e
commit 0826d87ecd
2 changed files with 57 additions and 2 deletions
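
A minimal, self-contained sketch of the ordering this commit enforces (the names below are illustrative, not the actual `ChatMistralAI` implementation): the callback runs before the `yield`, so an observer sees each token no later than the consumer does.

# Illustrative sketch only, not the actual implementation: the callback
# fires first, then the token is yielded to the consumer.
from typing import Callable, Iterator, Optional


def stream_tokens(
    tokens: list[str],
    on_new_token: Optional[Callable[[str], None]] = None,
) -> Iterator[str]:
    for token in tokens:
        if on_new_token:
            on_new_token(token)  # callback fires first...
        yield token  # ...then the consumer receives the same token


seen: list[str] = []
for tok in stream_tokens(["Hello", " world"], on_new_token=seen.append):
    assert seen[-1] == tok  # the callback has already observed this token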

langchain_mistralai/chat_models.py

@@ -317,9 +317,9 @@ class ChatMistralAI(BaseChatModel):
                 continue
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
-            yield ChatGenerationChunk(message=chunk)
             if run_manager:
                 run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
+            yield ChatGenerationChunk(message=chunk)
 
     async def _astream(
         self,
@@ -342,9 +342,9 @@ class ChatMistralAI(BaseChatModel):
                 continue
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
-            yield ChatGenerationChunk(message=chunk)
             if run_manager:
                 await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
+            yield ChatGenerationChunk(message=chunk)
 
     async def _agenerate(
         self,

tests/unit_tests/test_chat_models.py

@@ -1,7 +1,10 @@
 """Test MistralAI Chat API wrapper."""
 import os
+from typing import Any, AsyncGenerator, Generator
+from unittest.mock import patch
 
 import pytest
+from langchain_core.callbacks.base import BaseCallbackHandler
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
@@ -12,6 +15,11 @@ from langchain_core.messages import (
 
 # TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
 from mistralai.models.chat_completion import (  # type: ignore[import]
+    ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse,
+    DeltaMessage,
+)
+from mistralai.models.chat_completion import (
     ChatMessage as MistralChatMessage,
 )
 
@@ -63,3 +71,50 @@ def test_convert_message_to_mistral_chat_message(
 ) -> None:
     result = _convert_message_to_mistral_chat_message(message)
     assert result == expected
+
+
+def _make_completion_response_from_token(token: str) -> ChatCompletionStreamResponse:
+    return ChatCompletionStreamResponse(
+        id="abc123",
+        model="fake_model",
+        choices=[
+            ChatCompletionResponseStreamChoice(
+                index=0,
+                delta=DeltaMessage(content=token),
+                finish_reason=None,
+            )
+        ],
+    )
+
+
+def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator:
+    for token in ["Hello", " how", " can", " I", " help", "?"]:
+        yield _make_completion_response_from_token(token)
+
+
+async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator:
+    for token in ["Hello", " how", " can", " I", " help", "?"]:
+        yield _make_completion_response_from_token(token)
+
+
+class MyCustomHandler(BaseCallbackHandler):
+    last_token: str = ""
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        self.last_token = token
+
+
+@patch("mistralai.client.MistralClient.chat_stream", new=mock_chat_stream)
+def test_stream_with_callback() -> None:
+    callback = MyCustomHandler()
+    chat = ChatMistralAI(callbacks=[callback])
+    for token in chat.stream("Hello"):
+        assert callback.last_token == token.content
+
+
+@patch("mistralai.async_client.MistralAsyncClient.chat_stream", new=mock_chat_astream)
+async def test_astream_with_callback() -> None:
+    callback = MyCustomHandler()
+    chat = ChatMistralAI(callbacks=[callback])
+    async for token in chat.astream("Hello"):
+        assert callback.last_token == token.content
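
For context, a hedged usage sketch (not part of the commit; assumes `langchain-mistralai` is installed and `MISTRAL_API_KEY` is set in the environment, and `PrintingHandler` is a hypothetical handler name): after this change, a streaming handler's `on_llm_new_token` fires before the corresponding chunk reaches the caller's loop body.

# Illustrative usage only (not part of the commit): requires
# langchain-mistralai and a valid MISTRAL_API_KEY; PrintingHandler is
# a hypothetical handler introduced here for demonstration.
from typing import Any

from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_mistralai import ChatMistralAI


class PrintingHandler(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        print(f"callback saw: {token!r}")  # runs before the chunk is yielded


chat = ChatMistralAI(callbacks=[PrintingHandler()])
for chunk in chat.stream("Hello"):
    print(f"caller saw:   {chunk.content!r}")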