mistral[patch]: add usage_metadata to (a)invoke and (a)stream (#22781)

ccurme 4 weeks ago committed by GitHub
parent 20e3662acf
commit 936aedd10c

langchain_mistralai/chat_models.py

@@ -186,9 +186,10 @@ async def acompletion_with_retry(
     return await _completion_with_retry(**kwargs)
 
 
-def _convert_delta_to_message_chunk(
-    _delta: Dict, default_class: Type[BaseMessageChunk]
+def _convert_chunk_to_message_chunk(
+    chunk: Dict, default_class: Type[BaseMessageChunk]
 ) -> BaseMessageChunk:
+    _delta = chunk["choices"][0]["delta"]
     role = _delta.get("role")
     content = _delta.get("content") or ""
     if role == "user" or default_class == HumanMessageChunk:
@@ -216,10 +217,19 @@ def _convert_delta_to_message_chunk(
                 pass
         else:
             tool_call_chunks = []
+        if token_usage := chunk.get("usage"):
+            usage_metadata = {
+                "input_tokens": token_usage.get("prompt_tokens", 0),
+                "output_tokens": token_usage.get("completion_tokens", 0),
+                "total_tokens": token_usage.get("total_tokens", 0),
+            }
+        else:
+            usage_metadata = None
         return AIMessageChunk(
             content=content,
             additional_kwargs=additional_kwargs,
             tool_call_chunks=tool_call_chunks,
+            usage_metadata=usage_metadata,
         )
     elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
@@ -484,14 +494,21 @@ class ChatMistralAI(BaseChatModel):
 
     def _create_chat_result(self, response: Dict) -> ChatResult:
         generations = []
+        token_usage = response.get("usage", {})
         for res in response["choices"]:
             finish_reason = res.get("finish_reason")
+            message = _convert_mistral_chat_message_to_message(res["message"])
+            if token_usage and isinstance(message, AIMessage):
+                message.usage_metadata = {
+                    "input_tokens": token_usage.get("prompt_tokens", 0),
+                    "output_tokens": token_usage.get("completion_tokens", 0),
+                    "total_tokens": token_usage.get("total_tokens", 0),
+                }
             gen = ChatGeneration(
-                message=_convert_mistral_chat_message_to_message(res["message"]),
+                message=message,
                 generation_info={"finish_reason": finish_reason},
             )
             generations.append(gen)
 
-        token_usage = response.get("usage", {})
         llm_output = {"token_usage": token_usage, "model": self.model}
         return ChatResult(generations=generations, llm_output=llm_output)
@@ -525,8 +542,7 @@ class ChatMistralAI(BaseChatModel):
         ):
             if len(chunk["choices"]) == 0:
                 continue
-            delta = chunk["choices"][0]["delta"]
-            new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             # make future chunks same type as first chunk
             default_chunk_class = new_chunk.__class__
             gen_chunk = ChatGenerationChunk(message=new_chunk)
@@ -552,8 +568,7 @@ class ChatMistralAI(BaseChatModel):
         ):
             if len(chunk["choices"]) == 0:
                 continue
-            delta = chunk["choices"][0]["delta"]
-            new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             # make future chunks same type as first chunk
             default_chunk_class = new_chunk.__class__
             gen_chunk = ChatGenerationChunk(message=new_chunk)
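
Usage sketch (not part of this commit): with this change, usage_metadata is populated on the AIMessage returned by (a)invoke and on the streamed chunk that carries the API's usage payload (typically the last one). A minimal, illustrative example of how a caller might read it, assuming a configured Mistral API key; the model name is only an example:

from langchain_mistralai import ChatMistralAI

llm = ChatMistralAI(model="mistral-large-latest")

# invoke: token counts are attached to the returned AIMessage
msg = llm.invoke("Hello")
print(msg.usage_metadata)  # {"input_tokens": ..., "output_tokens": ..., "total_tokens": ...}

# stream: only the chunk carrying the usage payload has counts; adding chunks
# with `+` aggregates usage_metadata onto the combined AIMessageChunk
full = None
for chunk in llm.stream("Hello"):
    full = chunk if full is None else full + chunk
print(full.usage_metadata)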

poetry.lock

@@ -392,7 +392,7 @@ files = [
 
 [[package]]
 name = "langchain-core"
-version = "0.2.0"
+version = "0.2.5"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -401,15 +401,12 @@ develop = true
 
 [package.dependencies]
 jsonpatch = "^1.33"
-langsmith = "^0.1.0"
+langsmith = "^0.1.75"
 packaging = "^23.2"
 pydantic = ">=1,<3"
 PyYAML = ">=5.3"
 tenacity = "^8.1.0"
 
-[package.extras]
-extended-testing = ["jinja2 (>=3,<4)"]
-
 [package.source]
 type = "directory"
 url = "../../core"
@@ -433,13 +430,13 @@ url = "../../standard-tests"
 
 [[package]]
 name = "langsmith"
-version = "0.1.58"
+version = "0.1.76"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "langsmith-0.1.58-py3-none-any.whl", hash = "sha256:1148cc836ec99d1b2f37cd2fa3014fcac213bb6bad798a2b21bb9111c18c9768"},
-    {file = "langsmith-0.1.58.tar.gz", hash = "sha256:a5060933c1fb3006b498ec849677993329d7e6138bdc2ec044068ab806e09c39"},
+    {file = "langsmith-0.1.76-py3-none-any.whl", hash = "sha256:4b8cb14f2233d9673ce9e6e3d545359946d9690a2c1457ab01e7459ec97b964e"},
+    {file = "langsmith-0.1.76.tar.gz", hash = "sha256:5829f997495c0f9a39f91fe0a57e0cb702e8642e6948945f5bb9f46337db7732"},
 ]
 
 [package.dependencies]
@@ -1051,4 +1048,4 @@ zstd = ["zstandard (>=0.18.0)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "4a5a57d01c791de831f03fb309541443dc8bb51f5068ccfb7bcb77490c2eb6c3"
+content-hash = "af4576b4e41d3e01716cff9476d6130dd0c5ef7b98bfd02fefd1f5b730574b6e"

@ -12,7 +12,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = ">=0.2.0,<0.3"
langchain-core = ">=0.2.2,<0.3"
tokenizers = ">=0.15.1,<1"
httpx = ">=0.25.2,<1"
httpx-sse = ">=0.3.1,<1"

tests/integration_tests/test_chat_models.py

@@ -1,11 +1,12 @@
 """Test ChatMistral chat model."""
 
 import json
-from typing import Any
+from typing import Any, Optional
 
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
+    BaseMessageChunk,
     HumanMessage,
 )
 from langchain_core.pydantic_v1 import BaseModel
@@ -25,8 +26,28 @@ async def test_astream() -> None:
     """Test streaming tokens from ChatMistralAI."""
     llm = ChatMistralAI()
 
+    full: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
     async for token in llm.astream("I'm Pickle Rick"):
         assert isinstance(token, AIMessageChunk)
         assert isinstance(token.content, str)
+        full = token if full is None else full + token
+        if token.usage_metadata is not None:
+            chunks_with_token_counts += 1
+    if chunks_with_token_counts != 1:
+        raise AssertionError(
+            "Expected exactly one chunk with token counts. "
+            "AIMessageChunk aggregation adds counts. Check that "
+            "this is behaving properly."
+        )
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert (
+        full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
+        == full.usage_metadata["total_tokens"]
+    )
 
 
 async def test_abatch() -> None:
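
The test insists on exactly one chunk with token counts because, as its assertion message notes, AIMessageChunk aggregation adds counts: combining chunks with `+` sums their usage_metadata, so counts appearing on more than one chunk would be double-counted in the aggregate. A small illustrative sketch (values are made up):

from langchain_core.messages import AIMessageChunk

a = AIMessageChunk(
    content="Hello ",
    usage_metadata={"input_tokens": 5, "output_tokens": 1, "total_tokens": 6},
)
b = AIMessageChunk(content="world")  # no usage_metadata on this chunk

combined = a + b
print(combined.content)         # "Hello world"
print(combined.usage_metadata)  # counts from `a` carry through unchanged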

tests/integration_tests/test_standard.py

@@ -20,14 +20,3 @@ class TestMistralStandard(ChatModelIntegrationTests):
             "model": "mistral-large-latest",
             "temperature": 0,
         }
-
-    @pytest.mark.xfail(reason="Not implemented.")
-    def test_usage_metadata(
-        self,
-        chat_model_class: Type[BaseChatModel],
-        chat_model_params: dict,
-    ) -> None:
-        super().test_usage_metadata(
-            chat_model_class,
-            chat_model_params,
-        )
