@@ -43,6 +43,7 @@ from langchain_core.messages import (
     ToolCall,
     ToolMessage,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.runnables import (
@@ -653,14 +654,20 @@ class ChatAnthropic(BaseChatModel):
                 message_chunk = AIMessageChunk(
                     content=message.content,
                     tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
+                    usage_metadata=message.usage_metadata,
                 )
                 yield ChatGenerationChunk(message=message_chunk)
             else:
                 yield cast(ChatGenerationChunk, result.generations[0])
             return
+        full_generation_info: dict = {}
         with self._client.messages.stream(**params) as stream:
             for text in stream.text_stream:
-                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
+                chunk, full_generation_info = _make_chat_generation_chunk(
+                    text,
+                    stream.current_message_snapshot.model_dump(),
+                    full_generation_info,
+                )
                 if run_manager:
                     run_manager.on_llm_new_token(text, chunk=chunk)
                 yield chunk
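A minimal sketch of how the streamed usage metadata surfaces to callers, assuming `ANTHROPIC_API_KEY` is set and a langchain-core version where message chunks merge `usage_metadata` on `+` (the model name, prompt, and printed numbers are illustrative):

```python
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-haiku-20240307")

full = None
for chunk in llm.stream("Hello"):
    # Input tokens are reported only on the first usage-bearing chunk (see the
    # dedup logic in _make_chat_generation_chunk below), so summing the chunks
    # does not double count them.
    full = chunk if full is None else full + chunk

print(full.usage_metadata)
# e.g. {'input_tokens': 8, 'output_tokens': 12, 'total_tokens': 20}
```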
@@ -692,14 +699,20 @@ class ChatAnthropic(BaseChatModel):
                 message_chunk = AIMessageChunk(
                     content=message.content,
                     tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
+                    usage_metadata=message.usage_metadata,
                 )
                 yield ChatGenerationChunk(message=message_chunk)
             else:
                 yield cast(ChatGenerationChunk, result.generations[0])
             return
+        full_generation_info: dict = {}
         async with self._async_client.messages.stream(**params) as stream:
             async for text in stream.text_stream:
-                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
+                chunk, full_generation_info = _make_chat_generation_chunk(
+                    text,
+                    stream.current_message_snapshot.model_dump(),
+                    full_generation_info,
+                )
                 if run_manager:
                     await run_manager.on_llm_new_token(text, chunk=chunk)
                 yield chunk
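The async path mirrors the sync one; a short sketch under the same assumptions:

```python
import asyncio

from langchain_anthropic import ChatAnthropic


async def main() -> None:
    llm = ChatAnthropic(model="claude-3-haiku-20240307")  # illustrative model
    full = None
    async for chunk in llm.astream("Hello"):
        full = chunk if full is None else full + chunk
    print(full.usage_metadata)


asyncio.run(main())
```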
@@ -1064,6 +1077,59 @@ def _lc_tool_calls_to_anthropic_tool_use_blocks(
     return blocks
 
 
+def _make_chat_generation_chunk(
+    text: str, message_dump: dict, full_generation_info: dict
+) -> Tuple[ChatGenerationChunk, dict]:
+    """Collect metadata and build a ChatGenerationChunk.
+
+    Args:
+        text: text of the message chunk
+        message_dump: dict with metadata of the message chunk
+        full_generation_info: dict collecting metadata for the full stream
+
+    Returns:
+        Tuple of the ChatGenerationChunk and the updated full_generation_info
+    """
+    generation_info = {}
+    usage_metadata: Optional[UsageMetadata] = None
+    for k, v in message_dump.items():
+        if k in ("content", "role", "type") or (
+            k in full_generation_info and k not in ("usage", "stop_reason")
+        ):
+            continue
+        elif k == "usage":
+            input_tokens = v.get("input_tokens", 0)
+            output_tokens = v.get("output_tokens", 0)
+            if "usage" not in full_generation_info:
+                full_generation_info[k] = v
+                usage_metadata = UsageMetadata(
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    total_tokens=input_tokens + output_tokens,
+                )
+            else:
+                seen_input_tokens = full_generation_info[k].get("input_tokens", 0)
+                # Anthropic returns the same input token count for each message
+                # snapshot in a stream. To avoid double counting, we only count
+                # the input tokens once; after that, they are reported as zero.
+                new_input_tokens = 0 if seen_input_tokens else input_tokens
+                usage_metadata = UsageMetadata(
+                    input_tokens=new_input_tokens,
+                    output_tokens=output_tokens,
+                    total_tokens=new_input_tokens + output_tokens,
+                )
+        else:
+            full_generation_info[k] = v
+            generation_info[k] = v
+    return (
+        ChatGenerationChunk(
+            message=AIMessageChunk(content=text, usage_metadata=usage_metadata),
+            generation_info=generation_info,
+        ),
+        full_generation_info,
+    )
+
+
 @deprecated(since="0.1.0", removal="0.3.0", alternative="ChatAnthropic")
 class ChatAnthropicMessages(ChatAnthropic):
     pass
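A quick sketch of the deduplication behavior, calling the private helper directly with hand-built message snapshots (the dicts below are invented for illustration, not real API payloads):

```python
from langchain_anthropic.chat_models import _make_chat_generation_chunk

full_info: dict = {}

# First snapshot: the prompt's input tokens are counted once.
chunk, full_info = _make_chat_generation_chunk(
    "Hel",
    {
        "content": [],
        "role": "assistant",
        "type": "message",
        "usage": {"input_tokens": 8, "output_tokens": 1},
    },
    full_info,
)
assert chunk.message.usage_metadata["input_tokens"] == 8

# Later snapshots repeat the same input count; it is zeroed so that summing
# usage_metadata across the stream's chunks does not double count.
chunk, full_info = _make_chat_generation_chunk(
    "lo",
    {
        "content": [],
        "role": "assistant",
        "type": "message",
        "usage": {"input_tokens": 8, "output_tokens": 2},
    },
    full_info,
)
assert chunk.message.usage_metadata["input_tokens"] == 0
```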