@@ -186,9 +186,10 @@ async def acompletion_with_retry(
     return await _completion_with_retry(**kwargs)


-def _convert_delta_to_message_chunk(
-    _delta: Dict, default_class: Type[BaseMessageChunk]
+def _convert_chunk_to_message_chunk(
+    chunk: Dict, default_class: Type[BaseMessageChunk]
 ) -> BaseMessageChunk:
+    _delta = chunk["choices"][0]["delta"]
     role = _delta.get("role")
     content = _delta.get("content") or ""
     if role == "user" or default_class == HumanMessageChunk:
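
The renamed helper takes the whole streaming chunk rather than the pre-extracted delta, because the usage payload sits next to `choices` in the response, not inside the delta itself. A rough sketch of the shape involved (field values are invented for illustration, not a captured API response):

```python
# Illustrative Mistral streaming chunk; values are made up for the example.
chunk = {
    "choices": [{"delta": {"role": "assistant", "content": "Hello"}}],
    "usage": {"prompt_tokens": 5, "completion_tokens": 1, "total_tokens": 6},
}

delta = chunk["choices"][0]["delta"]  # all the old helper ever saw
usage = chunk.get("usage")            # what the rename makes reachable
```
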
@@ -216,10 +217,19 @@ def _convert_delta_to_message_chunk(
                 pass
         else:
             tool_call_chunks = []
+        if token_usage := chunk.get("usage"):
+            usage_metadata = {
+                "input_tokens": token_usage.get("prompt_tokens", 0),
+                "output_tokens": token_usage.get("completion_tokens", 0),
+                "total_tokens": token_usage.get("total_tokens", 0),
+            }
+        else:
+            usage_metadata = None
         return AIMessageChunk(
             content=content,
             additional_kwargs=additional_kwargs,
             tool_call_chunks=tool_call_chunks,
+            usage_metadata=usage_metadata,
         )
     elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
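
A minimal way to exercise the converter directly, assuming it is imported from the patched module; the chunk dict is hand-written for the test, not a real response:

```python
from langchain_core.messages import AIMessageChunk
from langchain_mistralai.chat_models import _convert_chunk_to_message_chunk

# Hypothetical final chunk of a stream: usage arrives alongside the last delta.
final_chunk = {
    "choices": [{"delta": {"role": "assistant", "content": ""}}],
    "usage": {"prompt_tokens": 12, "completion_tokens": 7, "total_tokens": 19},
}

msg = _convert_chunk_to_message_chunk(final_chunk, AIMessageChunk)
assert msg.usage_metadata == {
    "input_tokens": 12,
    "output_tokens": 7,
    "total_tokens": 19,
}
```
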
@@ -484,14 +494,21 @@ class ChatMistralAI(BaseChatModel):

     def _create_chat_result(self, response: Dict) -> ChatResult:
         generations = []
+        token_usage = response.get("usage", {})
         for res in response["choices"]:
             finish_reason = res.get("finish_reason")
+            message = _convert_mistral_chat_message_to_message(res["message"])
+            if token_usage and isinstance(message, AIMessage):
+                message.usage_metadata = {
+                    "input_tokens": token_usage.get("prompt_tokens", 0),
+                    "output_tokens": token_usage.get("completion_tokens", 0),
+                    "total_tokens": token_usage.get("total_tokens", 0),
+                }
             gen = ChatGeneration(
-                message=_convert_mistral_chat_message_to_message(res["message"]),
+                message=message,
                 generation_info={"finish_reason": finish_reason},
             )
             generations.append(gen)
-        token_usage = response.get("usage", {})
         llm_output = {"token_usage": token_usage, "model": self.model}
         return ChatResult(generations=generations, llm_output=llm_output)

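
To check the non-streaming path end to end (assumes a `MISTRAL_API_KEY` in the environment; the model name is just an example):

```python
from langchain_mistralai import ChatMistralAI

llm = ChatMistralAI(model="mistral-small-latest")  # example model name
msg = llm.invoke("Say hello")
# With this change the returned AIMessage carries standardized token counts:
print(msg.usage_metadata)
# e.g. {'input_tokens': 6, 'output_tokens': 4, 'total_tokens': 10}
```
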
@@ -525,8 +542,7 @@ class ChatMistralAI(BaseChatModel):
         ):
             if len(chunk["choices"]) == 0:
                 continue
-            delta = chunk["choices"][0]["delta"]
-            new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             # make future chunks same type as first chunk
             default_chunk_class = new_chunk.__class__
             gen_chunk = ChatGenerationChunk(message=new_chunk)
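
With `_stream` now forwarding the whole chunk, usage metadata survives streaming. Aggregating chunks with `+` should merge the counts; a sketch under the same API-key and model-name assumptions as above:

```python
from langchain_mistralai import ChatMistralAI

llm = ChatMistralAI(model="mistral-small-latest")  # example model name
aggregate = None
for chunk in llm.stream("Say hello"):
    aggregate = chunk if aggregate is None else aggregate + chunk
# The aggregate should now include usage_metadata from the final chunk.
print(aggregate.usage_metadata)
```
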
@@ -552,8 +568,7 @@ class ChatMistralAI(BaseChatModel):
         ):
             if len(chunk["choices"]) == 0:
                 continue
-            delta = chunk["choices"][0]["delta"]
-            new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             # make future chunks same type as first chunk
             default_chunk_class = new_chunk.__class__
             gen_chunk = ChatGenerationChunk(message=new_chunk)
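
The async path gets the identical fix, so a mirror-image check for `_astream` (again an illustrative sketch, with the model name and API key assumed):

```python
import asyncio

from langchain_mistralai import ChatMistralAI


async def main() -> None:
    llm = ChatMistralAI(model="mistral-small-latest")  # example model name
    aggregate = None
    async for chunk in llm.astream("Say hello"):
        aggregate = chunk if aggregate is None else aggregate + chunk
    print(aggregate.usage_metadata)  # merged counts across the stream


asyncio.run(main())
```
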