diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index 1ffef5d638..2aab9c3a06 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -8,7 +8,11 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.language_models.chat_models import (
+    BaseChatModel,
+    agenerate_from_stream,
+    generate_from_stream,
+)
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -174,6 +178,9 @@ class ChatAnthropic(BaseChatModel):
 
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
 
+    streaming: bool = False
+    """Whether to use streaming or not."""
+
     @property
     def _llm_type(self) -> str:
         """Return type of chat model."""
@@ -271,6 +278,11 @@ class ChatAnthropic(BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return generate_from_stream(stream_iter)
         params = self._format_params(messages=messages, stop=stop, **kwargs)
         data = self._client.messages.create(**params)
         return self._format_output(data)
@@ -282,6 +294,11 @@ class ChatAnthropic(BaseChatModel):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
         params = self._format_params(messages=messages, stop=stop, **kwargs)
         data = await self._async_client.messages.create(**params)
         return self._format_output(data)
diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
index f1050f4fea..4347edd7b4 100644
--- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
@@ -184,3 +184,31 @@ def test_anthropic_multimodal() -> None:
     response = chat.invoke(messages)
     assert isinstance(response, AIMessage)
     assert isinstance(response.content, str)
+
+
+def test_streaming() -> None:
+    """Test streaming tokens from Anthropic."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+
+    llm = ChatAnthropicMessages(
+        model_name=MODEL_NAME, streaming=True, callback_manager=callback_manager
+    )
+
+    response = llm.generate([[HumanMessage(content="I'm Pickle Rick")]])
+    assert callback_handler.llm_streams > 0
+    assert isinstance(response, LLMResult)
+
+
+async def test_astreaming() -> None:
+    """Test streaming tokens from Anthropic."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+
+    llm = ChatAnthropicMessages(
+        model_name=MODEL_NAME, streaming=True, callback_manager=callback_manager
+    )
+
+    response = await llm.agenerate([[HumanMessage(content="I'm Pickle Rick")]])
+    assert callback_handler.llm_streams > 0
+    assert isinstance(response, LLMResult)
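
For context, a minimal usage sketch of the new flag (not part of the diff; the model name below is illustrative and an `ANTHROPIC_API_KEY` is assumed to be set): with `streaming=True`, blocking entry points such as `invoke` and `generate` are served internally by `_stream` and aggregated via `generate_from_stream`, so streaming callbacks like `on_llm_new_token` fire even though the caller still receives a single result.

```python
# Hypothetical usage sketch, not part of this diff.
# Assumes ANTHROPIC_API_KEY is set; the model name is illustrative.
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage

llm = ChatAnthropic(model_name="claude-2.1", streaming=True)

# invoke() still returns one AIMessage, but with streaming=True the response
# is assembled from streamed chunks by generate_from_stream, so streaming
# callbacks (e.g. on_llm_new_token) fire along the way.
message = llm.invoke([HumanMessage(content="Hello!")])
print(message.content)
```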