@@ -58,6 +58,7 @@ from langchain_core.messages import (
    ToolMessage,
    ToolMessageChunk,
)
+from langchain_core.messages.ai import TokenUsage
 from langchain_core.output_parsers import (
    JsonOutputParser,
    PydanticOutputParser,
@@ -512,24 +513,30 @@ class ChatOpenAI(BaseChatModel):
            choice = chunk["choices"][0]
            if choice["delta"] is None:
                continue
-           chunk = _convert_delta_to_message_chunk(
+           msg_chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
+           if chunk.get("usage"):
+               token_usage = TokenUsage(
+                   prompt_tokens=chunk["usage"]["prompt_tokens"],
+                   completion_tokens=chunk["usage"]["completion_tokens"],
+               )
+               msg_chunk.token_usage = token_usage
            generation_info = {}
            if finish_reason := choice.get("finish_reason"):
                generation_info["finish_reason"] = finish_reason
            logprobs = choice.get("logprobs")
            if logprobs:
                generation_info["logprobs"] = logprobs
-           default_chunk_class = chunk.__class__
-           chunk = ChatGenerationChunk(
-               message=chunk, generation_info=generation_info or None
+           default_chunk_class = msg_chunk.__class__
+           gen_chunk = ChatGenerationChunk(
+               message=msg_chunk, generation_info=generation_info or None
            )
            if run_manager:
                run_manager.on_llm_new_token(
-                   chunk.text, chunk=chunk, logprobs=logprobs
+                   gen_chunk.text, chunk=gen_chunk, logprobs=logprobs
                )
-           yield chunk
+           yield gen_chunk

    def _generate(
        self,
@@ -574,7 +581,15 @@ class ChatOpenAI(BaseChatModel):
            raise ValueError(response.get("error"))

        for res in response["choices"]:
-           message = _convert_dict_to_message(res["message"])
+           message: AIMessage = cast(
+               AIMessage, _convert_dict_to_message(res["message"])
+           )
+           if response.get("usage"):
+               token_usage = TokenUsage(
+                   prompt_tokens=response["usage"]["prompt_tokens"],
+                   completion_tokens=response["usage"]["completion_tokens"],
+               )
+               message.token_usage = token_usage
            generation_info = dict(finish_reason=res.get("finish_reason"))
            if "logprobs" in res:
                generation_info["logprobs"] = res["logprobs"]
@@ -583,9 +598,8 @@ class ChatOpenAI(BaseChatModel):
                generation_info=generation_info,
            )
            generations.append(gen)
-       token_usage = response.get("usage", {})
        llm_output = {
-           "token_usage": token_usage,
+           "token_usage": response.get("usage", {}),
            "model_name": self.model_name,
            "system_fingerprint": response.get("system_fingerprint", ""),
        }
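
A minimal call-site sketch of what this change enables, not part of the diff. It assumes the `langchain_openai` package layout, an illustrative model name, and that `token_usage` is `None` when the API returns no usage payload:

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")

# Non-streaming: _generate now attaches a TokenUsage built from
# response["usage"] to the parsed AIMessage.
msg = llm.invoke("Say hello.")
if msg.token_usage is not None:
    print(msg.token_usage.prompt_tokens, msg.token_usage.completion_tokens)

# Streaming: _stream attaches usage to whichever chunk carries a
# "usage" payload (typically the final chunk, if the API sends one).
for chunk in llm.stream("Say hello."):
    if getattr(chunk, "token_usage", None) is not None:
        print(chunk.token_usage)
```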