diff --git a/libs/community/langchain_community/llms/bedrock.py b/libs/community/langchain_community/llms/bedrock.py
index 3f10b09b63..5ec60e8496 100644
--- a/libs/community/langchain_community/llms/bedrock.py
+++ b/libs/community/langchain_community/llms/bedrock.py
@@ -1,11 +1,25 @@
 from __future__ import annotations
 
+import asyncio
 import json
 import warnings
 from abc import ABC
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Mapping, Optional
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+)
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
 from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
@@ -128,26 +142,56 @@ class LLMInputOutputAdapter:
         if not stream:
             return
 
-        if provider not in cls.provider_to_output_key_map:
+        output_key = cls.provider_to_output_key_map.get(provider, None)
+
+        if not output_key:
             raise ValueError(
                 f"Unknown streaming response output key for provider: {provider}"
             )
 
         for event in stream:
             chunk = event.get("chunk")
-            if chunk:
-                chunk_obj = json.loads(chunk.get("bytes").decode())
-                if provider == "cohere" and (
-                    chunk_obj["is_finished"]
-                    or chunk_obj[cls.provider_to_output_key_map[provider]]
-                    == ""
-                ):
-                    return
+            if not chunk:
+                continue
 
-                # chunk obj format varies with provider
-                yield GenerationChunk(
-                    text=chunk_obj[cls.provider_to_output_key_map[provider]]
-                )
+            chunk_obj = json.loads(chunk.get("bytes").decode())
+
+            if provider == "cohere" and (
+                chunk_obj["is_finished"] or chunk_obj[output_key] == ""
+            ):
+                return
+
+            yield GenerationChunk(text=chunk_obj[output_key])
+
+    @classmethod
+    async def aprepare_output_stream(
+        cls, provider: str, response: Any, stop: Optional[List[str]] = None
+    ) -> AsyncIterator[GenerationChunk]:
+        stream = response.get("body")
+
+        if not stream:
+            return
+
+        output_key = cls.provider_to_output_key_map.get(provider, None)
+
+        if not output_key:
+            raise ValueError(
+                f"Unknown streaming response output key for provider: {provider}"
+            )
+
+        for event in stream:
+            chunk = event.get("chunk")
+            if not chunk:
+                continue
+
+            chunk_obj = json.loads(chunk.get("bytes").decode())
+
+            if provider == "cohere" and (
+                chunk_obj["is_finished"] or chunk_obj[output_key] == ""
+            ):
+                return
+
+            yield GenerationChunk(text=chunk_obj[output_key])
 
 
 class BedrockBase(BaseModel, ABC):
@@ -332,6 +376,51 @@ class BedrockBase(BaseModel, ABC):
             if run_manager is not None:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
 
+    async def _aprepare_input_and_invoke_stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        _model_kwargs = self.model_kwargs or {}
+        provider = self._get_provider()
+
+        if stop:
+            if provider not in self.provider_stop_sequence_key_name_map:
+                raise ValueError(
+                    f"Stop sequence key name for {provider} is not supported."
+                )
+            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
+
+        if provider == "cohere":
+            _model_kwargs["stream"] = True
+
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        body = json.dumps(input_body)
+
+        response = await asyncio.get_running_loop().run_in_executor(
+            None,
+            lambda: self.client.invoke_model_with_response_stream(
+                body=body,
+                modelId=self.model_id,
+                accept="application/json",
+                contentType="application/json",
+            ),
+        )
+
+        async for chunk in LLMInputOutputAdapter.aprepare_output_stream(
+            provider, response, stop
+        ):
+            yield chunk
+            if run_manager is not None and asyncio.iscoroutinefunction(
+                run_manager.on_llm_new_token
+            ):
+                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            elif run_manager is not None:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
 
 class Bedrock(LLM, BedrockBase):
     """Bedrock models.
@@ -449,6 +538,65 @@ class Bedrock(LLM, BedrockBase):
 
         return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
 
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncGenerator[GenerationChunk, None]:
+        """Call out to Bedrock service with streaming.
+
+        Args:
+            prompt (str): The prompt to pass into the model
+            stop (Optional[List[str]], optional): Stop sequences. These will
+                override any stop sequences in the `model_kwargs` attribute.
+                Defaults to None.
+            run_manager (Optional[AsyncCallbackManagerForLLMRun], optional): Callback
+                run managers used to process the output. Defaults to None.
+
+        Yields:
+            AsyncGenerator[GenerationChunk, None]: Generator that asynchronously yields
+                the streamed responses.
+        """
+        async for chunk in self._aprepare_input_and_invoke_stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        ):
+            yield chunk
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to Bedrock service model.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = await llm._acall("Tell me a joke.")
+        """
+
+        if not self.streaming:
+            raise ValueError("Streaming must be set to True for async operations.")
") + + chunks = [ + chunk.text + async for chunk in self._astream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ) + ] + return "".join(chunks) + def get_num_tokens(self, text: str) -> int: if self._model_is_anthropic: return get_num_tokens_anthropic(text) diff --git a/libs/community/tests/unit_tests/llms/test_bedrock.py b/libs/community/tests/unit_tests/llms/test_bedrock.py index 12467b4f42..6001218293 100644 --- a/libs/community/tests/unit_tests/llms/test_bedrock.py +++ b/libs/community/tests/unit_tests/llms/test_bedrock.py @@ -1,6 +1,14 @@ +import json +from typing import AsyncGenerator, Dict +from unittest.mock import MagicMock, patch + import pytest -from langchain_community.llms.bedrock import ALTERNATION_ERROR, _human_assistant_format +from langchain_community.llms.bedrock import ( + ALTERNATION_ERROR, + Bedrock, + _human_assistant_format, +) TEST_CASES = { """Hey""": """ @@ -250,3 +258,51 @@ def test__human_assistant_format() -> None: else: output = _human_assistant_format(input_text) assert output == expected_output + + +# Sample mock streaming response data +MOCK_STREAMING_RESPONSE = [ + {"chunk": {"bytes": b'{"text": "nice"}'}}, + {"chunk": {"bytes": b'{"text": " to meet"}'}}, + {"chunk": {"bytes": b'{"text": " you"}'}}, +] + + +async def async_gen_mock_streaming_response() -> AsyncGenerator[Dict, None]: + for item in MOCK_STREAMING_RESPONSE: + yield item + + +@pytest.mark.asyncio +async def test_bedrock_async_streaming_call() -> None: + # Mock boto3 import + mock_boto3 = MagicMock() + mock_boto3.Session.return_value.client.return_value = ( + MagicMock() + ) # Mocking the client method of the Session object + + with patch.dict( + "sys.modules", {"boto3": mock_boto3} + ): # Mocking boto3 at the top level using patch.dict + # Mock the `Bedrock` class's method that invokes the model + mock_invoke_method = MagicMock(return_value=async_gen_mock_streaming_response()) + with patch.object( + Bedrock, "_aprepare_input_and_invoke_stream", mock_invoke_method + ): + # Instantiate the Bedrock LLM + llm = Bedrock( + client=None, + model_id="anthropic.claude-v2", + streaming=True, + ) + # Call the _astream method + chunks = [ + json.loads(chunk["chunk"]["bytes"])["text"] # type: ignore + async for chunk in llm._astream("Hey, how are you?") + ] + + # Assertions + assert len(chunks) == 3 + assert chunks[0] == "nice" + assert chunks[1] == " to meet" + assert chunks[2] == " you"