community[minor]: Bedrock async methods (#12477)

Description: Added support for asynchronous streaming in the Bedrock
class and corresponding tests.

Primarily:
  async def aprepare_output_stream
  async def _aprepare_input_and_invoke_stream
  async def _astream
  async def _acall

I've ensured that the code adheres to the project's linting and
formatting standards by running make format, make lint, and make test.
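
Example of exercising the new async path through the public `astream`
API (a minimal sketch, not part of this diff; the region, credentials,
and model access are assumptions):

    import asyncio

    import boto3

    from langchain_community.llms.bedrock import Bedrock


    async def main() -> None:
        client = boto3.client("bedrock-runtime", region_name="us-east-1")
        llm = Bedrock(client=client, model_id="anthropic.claude-v2", streaming=True)
        # astream() is the public entry point that dispatches to _astream().
        async for token in llm.astream("Tell me a joke."):
            print(token, end="", flush=True)


    asyncio.run(main())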

Issue: #12054, #11589

Dependencies: None

Tag maintainer: @baskaryan 

Twitter handle: @dominic_lovric

---------

Co-authored-by: Piyush Jain <piyushjain@duck.com>
2 changed files with 220 additions and 16 deletions

libs/community/langchain_community/llms/bedrock.py

@@ -1,11 +1,25 @@
 from __future__ import annotations
 
+import asyncio
 import json
 import warnings
 from abc import ABC
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Mapping, Optional
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+)
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
 from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
@@ -128,26 +142,56 @@ class LLMInputOutputAdapter:
         if not stream:
             return
 
-        if provider not in cls.provider_to_output_key_map:
+        output_key = cls.provider_to_output_key_map.get(provider, None)
+
+        if not output_key:
             raise ValueError(
                 f"Unknown streaming response output key for provider: {provider}"
             )
 
         for event in stream:
             chunk = event.get("chunk")
-            if chunk:
-                chunk_obj = json.loads(chunk.get("bytes").decode())
-                if provider == "cohere" and (
-                    chunk_obj["is_finished"]
-                    or chunk_obj[cls.provider_to_output_key_map[provider]]
-                    == "<EOS_TOKEN>"
-                ):
-                    return
-
-                # chunk obj format varies with provider
-                yield GenerationChunk(
-                    text=chunk_obj[cls.provider_to_output_key_map[provider]]
-                )
+            if not chunk:
+                continue
+
+            chunk_obj = json.loads(chunk.get("bytes").decode())
+
+            if provider == "cohere" and (
+                chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
+            ):
+                return
+
+            yield GenerationChunk(text=chunk_obj[output_key])
+
+    @classmethod
+    async def aprepare_output_stream(
+        cls, provider: str, response: Any, stop: Optional[List[str]] = None
+    ) -> AsyncIterator[GenerationChunk]:
+        stream = response.get("body")
+
+        if not stream:
+            return
+
+        output_key = cls.provider_to_output_key_map.get(provider, None)
+
+        if not output_key:
+            raise ValueError(
+                f"Unknown streaming response output key for provider: {provider}"
+            )
+
+        for event in stream:
+            chunk = event.get("chunk")
+            if not chunk:
+                continue
+
+            chunk_obj = json.loads(chunk.get("bytes").decode())
+
+            if provider == "cohere" and (
+                chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
+            ):
+                return
+
+            yield GenerationChunk(text=chunk_obj[output_key])
 
 
 class BedrockBase(BaseModel, ABC):
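
For context, each event from the Bedrock response stream has the shape
{"chunk": {"bytes": b"..."}}, where the bytes decode to provider-specific
JSON. A standalone sketch of the parsing both generators apply, with the
provider mapping simplified to cohere's "text" output key (data is
illustrative):

    import json

    events = [
        {"chunk": {"bytes": b'{"text": "Hello", "is_finished": false}'}},
        {"chunk": {"bytes": b'{"text": "<EOS_TOKEN>", "is_finished": true}'}},
    ]

    for event in events:
        obj = json.loads(event["chunk"]["bytes"].decode())
        if obj["is_finished"] or obj["text"] == "<EOS_TOKEN>":
            break  # cohere signals end-of-stream in-band
        print(obj["text"])  # -> Hello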
@@ -332,6 +376,51 @@ class BedrockBase(BaseModel, ABC):
         if run_manager is not None:
             run_manager.on_llm_new_token(chunk.text, chunk=chunk)
 
+    async def _aprepare_input_and_invoke_stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        _model_kwargs = self.model_kwargs or {}
+        provider = self._get_provider()
+
+        if stop:
+            if provider not in self.provider_stop_sequence_key_name_map:
+                raise ValueError(
+                    f"Stop sequence key name for {provider} is not supported."
+                )
+            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
+
+        if provider == "cohere":
+            _model_kwargs["stream"] = True
+
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        body = json.dumps(input_body)
+
+        response = await asyncio.get_running_loop().run_in_executor(
+            None,
+            lambda: self.client.invoke_model_with_response_stream(
+                body=body,
+                modelId=self.model_id,
+                accept="application/json",
+                contentType="application/json",
+            ),
+        )
+
+        async for chunk in LLMInputOutputAdapter.aprepare_output_stream(
+            provider, response, stop
+        ):
+            yield chunk
+            if run_manager is not None and asyncio.iscoroutinefunction(
+                run_manager.on_llm_new_token
+            ):
+                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            elif run_manager is not None:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
 
 class Bedrock(LLM, BedrockBase):
     """Bedrock models.
@@ -449,6 +538,65 @@ class Bedrock(LLM, BedrockBase):
         return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
 
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncGenerator[GenerationChunk, None]:
+        """Call out to Bedrock service with streaming.
+
+        Args:
+            prompt (str): The prompt to pass into the model.
+            stop (Optional[List[str]], optional): Stop sequences. These will
+                override any stop sequences in the `model_kwargs` attribute.
+                Defaults to None.
+            run_manager (Optional[AsyncCallbackManagerForLLMRun], optional):
+                Callback run managers used to process the output.
+                Defaults to None.
+
+        Yields:
+            AsyncGenerator[GenerationChunk, None]: Generator that asynchronously
+                yields the streamed responses.
+        """
+        async for chunk in self._aprepare_input_and_invoke_stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        ):
+            yield chunk
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to Bedrock service model.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = await llm._acall("Tell me a joke.")
+        """
+        if not self.streaming:
+            raise ValueError("Streaming must be set to True for async operations.")
+
+        chunks = [
+            chunk.text
+            async for chunk in self._astream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            )
+        ]
+        return "".join(chunks)
+
     def get_num_tokens(self, text: str) -> int:
         if self._model_is_anthropic:
             return get_num_tokens_anthropic(text)
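
Since boto3's invoke_model_with_response_stream is a blocking call, the
new code above offloads it to the default thread-pool executor rather
than blocking the event loop. A minimal sketch of that pattern
(blocking_call is an illustrative stand-in for the boto3 call):

    import asyncio
    import time


    def blocking_call() -> str:
        time.sleep(0.1)  # simulate a synchronous network round trip
        return "response"


    async def main() -> None:
        # run_in_executor(None, ...) uses the loop's default ThreadPoolExecutor,
        # keeping the event loop responsive while the call blocks.
        result = await asyncio.get_running_loop().run_in_executor(None, blocking_call)
        print(result)


    asyncio.run(main())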

libs/community/tests/unit_tests/llms/test_bedrock.py

@@ -1,6 +1,14 @@
+import json
+from typing import AsyncGenerator, Dict
+from unittest.mock import MagicMock, patch
+
 import pytest
 
-from langchain_community.llms.bedrock import ALTERNATION_ERROR, _human_assistant_format
+from langchain_community.llms.bedrock import (
+    ALTERNATION_ERROR,
+    Bedrock,
+    _human_assistant_format,
+)
 
 TEST_CASES = {
     """Hey""": """
@@ -250,3 +258,51 @@ def test__human_assistant_format() -> None:
         else:
             output = _human_assistant_format(input_text)
             assert output == expected_output
+
+
+# Sample mock streaming response data
+MOCK_STREAMING_RESPONSE = [
+    {"chunk": {"bytes": b'{"text": "nice"}'}},
+    {"chunk": {"bytes": b'{"text": " to meet"}'}},
+    {"chunk": {"bytes": b'{"text": " you"}'}},
+]
+
+
+async def async_gen_mock_streaming_response() -> AsyncGenerator[Dict, None]:
+    for item in MOCK_STREAMING_RESPONSE:
+        yield item
+
+
+@pytest.mark.asyncio
+async def test_bedrock_async_streaming_call() -> None:
+    # Mock boto3 import
+    mock_boto3 = MagicMock()
+    mock_boto3.Session.return_value.client.return_value = (
+        MagicMock()
+    )  # Mocking the client method of the Session object
+
+    with patch.dict(
+        "sys.modules", {"boto3": mock_boto3}
+    ):  # Mocking boto3 at the top level using patch.dict
+        # Mock the `Bedrock` class's method that invokes the model
+        mock_invoke_method = MagicMock(
+            return_value=async_gen_mock_streaming_response()
+        )
+        with patch.object(
+            Bedrock, "_aprepare_input_and_invoke_stream", mock_invoke_method
+        ):
+            # Instantiate the Bedrock LLM
+            llm = Bedrock(
+                client=None,
+                model_id="anthropic.claude-v2",
+                streaming=True,
+            )
+            # Call the _astream method
+            chunks = [
+                json.loads(chunk["chunk"]["bytes"])["text"]  # type: ignore
+                async for chunk in llm._astream("Hey, how are you?")
+            ]
+
+            # Assertions
+            assert len(chunks) == 3
+            assert chunks[0] == "nice"
+            assert chunks[1] == " to meet"
+            assert chunks[2] == " you"
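
The test above relies on the pytest-asyncio plugin for
@pytest.mark.asyncio. Its core trick is reusable on its own: a MagicMock
whose return value is an async generator can stand in for any method
that returns an async iterator, so the test never touches the network. A
standalone sketch with illustrative names:

    import asyncio
    from typing import AsyncGenerator
    from unittest.mock import MagicMock


    async def fake_stream() -> AsyncGenerator[str, None]:
        for text in ["nice", " to meet", " you"]:
            yield text


    async def main() -> None:
        mocked = MagicMock(return_value=fake_stream())
        collected = [chunk async for chunk in mocked()]
        assert collected == ["nice", " to meet", " you"]


    asyncio.run(main())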