anthropic[patch]: streaming param (#18819)

pull/18822/head
Erick Friis 7 months ago committed by GitHub
parent 8c0b215c02
commit a5bcddc738
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -8,7 +8,11 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@ -174,6 +178,9 @@ class ChatAnthropic(BaseChatModel):
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
streaming: bool = False
"""Whether to use streaming or not."""
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
@ -271,6 +278,11 @@ class ChatAnthropic(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
params = self._format_params(messages=messages, stop=stop, **kwargs)
data = self._client.messages.create(**params)
return self._format_output(data)
@ -282,6 +294,11 @@ class ChatAnthropic(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
params = self._format_params(messages=messages, stop=stop, **kwargs)
data = await self._async_client.messages.create(**params)
return self._format_output(data)

@ -184,3 +184,31 @@ def test_anthropic_multimodal() -> None:
response = chat.invoke(messages)
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_streaming() -> None:
"""Test streaming tokens from Anthropic."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = ChatAnthropicMessages(
model_name=MODEL_NAME, streaming=True, callback_manager=callback_manager
)
response = llm.generate([[HumanMessage(content="I'm Pickle Rick")]])
assert callback_handler.llm_streams > 0
assert isinstance(response, LLMResult)
async def test_astreaming() -> None:
"""Test streaming tokens from Anthropic."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = ChatAnthropicMessages(
model_name=MODEL_NAME, streaming=True, callback_manager=callback_manager
)
response = await llm.agenerate([[HumanMessage(content="I'm Pickle Rick")]])
assert callback_handler.llm_streams > 0
assert isinstance(response, LLMResult)

Loading…
Cancel
Save