Mirror of https://github.com/hwchase17/langchain
Synced 2024-11-06 03:20:49 +00:00
mistralai[patch]: add missing _combine_llm_outputs implementation in ChatMistralAI (#18603)
# Description

Implements `_combine_llm_outputs` in `ChatMistralAI`, overriding the default implementation in `BaseChatModel`, which returns `{}`. The implementation is inspired by the one in `ChatOpenAI` from the `langchain-openai` package.

# Issue

None

# Dependencies

None

# Twitter handle

None

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
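For context, here is a minimal usage sketch (not part of the diff) of what the override changes for callers. It mirrors the integration tests added below and assumes a `MISTRAL_API_KEY` is configured in the environment:

```python
from langchain_core.messages import HumanMessage
from langchain_mistralai.chat_models import ChatMistralAI

# With _combine_llm_outputs implemented, generate() now returns an
# llm_output carrying aggregated token counts and the model name,
# instead of the empty dict produced by the BaseChatModel default.
chat = ChatMistralAI(max_tokens=10)
llm_result = chat.generate([[HumanMessage(content="Hello")]])

print(llm_result.llm_output)
# Expected shape (exact numbers vary per call):
# {"token_usage": {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...},
#  "model_name": chat.model}
```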
This commit is contained in:
parent 0175906437
commit ace7b66261
@@ -250,6 +250,22 @@ class ChatMistralAI(BaseChatModel):
         rtn = _completion_with_retry(**kwargs)
         return rtn
 
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                # Happens in streaming
+                continue
+            token_usage = output["token_usage"]
+            if token_usage is not None:
+                for k, v in token_usage.items():
+                    if k in overall_token_usage:
+                        overall_token_usage[k] += v
+                    else:
+                        overall_token_usage[k] = v
+        combined = {"token_usage": overall_token_usage, "model_name": self.model}
+        return combined
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate api key, python package exists, temperature, and top_p."""
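To illustrate the merging behaviour added above, here is a small standalone sketch (with made-up token counts) of the same aggregation logic: matching `token_usage` keys are summed across outputs, and `None` outputs, which occur when streaming, are skipped:

```python
from typing import List, Optional

# Made-up per-generation outputs of the shape _combine_llm_outputs receives.
llm_outputs: List[Optional[dict]] = [
    {"token_usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}},
    None,  # happens in streaming: no usage info for this generation
    {"token_usage": {"prompt_tokens": 4, "completion_tokens": 6, "total_tokens": 10}},
]

overall_token_usage: dict = {}
for output in llm_outputs:
    if output is None:
        continue
    token_usage = output["token_usage"]
    if token_usage is not None:
        for k, v in token_usage.items():
            overall_token_usage[k] = overall_token_usage.get(k, 0) + v

print(overall_token_usage)
# {'prompt_tokens': 9, 'completion_tokens': 13, 'total_tokens': 22}
```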
@@ -3,7 +3,7 @@
 import json
 from typing import Any
 
-from langchain_core.messages import AIMessageChunk
+from langchain_core.messages import AIMessageChunk, HumanMessage
 from langchain_core.pydantic_v1 import BaseModel
 
 from langchain_mistralai.chat_models import ChatMistralAI
@@ -70,6 +70,50 @@ def test_invoke() -> None:
     assert isinstance(result.content, str)
 
 
+def test_chat_mistralai_llm_output_contains_model_name() -> None:
+    """Test llm_output contains model_name."""
+    chat = ChatMistralAI(max_tokens=10)
+    message = HumanMessage(content="Hello")
+    llm_result = chat.generate([[message]])
+    assert llm_result.llm_output is not None
+    assert llm_result.llm_output["model_name"] == chat.model
+
+
+def test_chat_mistralai_streaming_llm_output_contains_model_name() -> None:
+    """Test llm_output contains model_name."""
+    chat = ChatMistralAI(max_tokens=10, streaming=True)
+    message = HumanMessage(content="Hello")
+    llm_result = chat.generate([[message]])
+    assert llm_result.llm_output is not None
+    assert llm_result.llm_output["model_name"] == chat.model
+
+
+def test_chat_mistralai_llm_output_contains_token_usage() -> None:
+    """Test llm_output contains token_usage."""
+    chat = ChatMistralAI(max_tokens=10)
+    message = HumanMessage(content="Hello")
+    llm_result = chat.generate([[message]])
+    assert llm_result.llm_output is not None
+    assert "token_usage" in llm_result.llm_output
+    token_usage = llm_result.llm_output["token_usage"]
+    assert "prompt_tokens" in token_usage
+    assert "completion_tokens" in token_usage
+    assert "total_tokens" in token_usage
+
+
+def test_chat_mistralai_streaming_llm_output_contains_token_usage() -> None:
+    """Test llm_output contains token_usage."""
+    chat = ChatMistralAI(max_tokens=10, streaming=True)
+    message = HumanMessage(content="Hello")
+    llm_result = chat.generate([[message]])
+    assert llm_result.llm_output is not None
+    assert "token_usage" in llm_result.llm_output
+    token_usage = llm_result.llm_output["token_usage"]
+    assert "prompt_tokens" in token_usage
+    assert "completion_tokens" in token_usage
+    assert "total_tokens" in token_usage
+
+
 def test_structured_output() -> None:
     llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
     schema = {