Mirror of https://github.com/hwchase17/langchain, synced 2024-11-16 06:13:16 +00:00
community[minor]: Bedrock async methods (#12477)
Description: Adds support for asynchronous streaming in the Bedrock class, along with corresponding tests. The main additions are:
- async def aprepare_output_stream
- async def _aprepare_input_and_invoke_stream
- async def _astream
- async def _acall
The code adheres to the project's linting and formatting standards (verified with make format, make lint, and make test).

Issue: #12054, #11589
Dependencies: None
Tag maintainer: @baskaryan
Twitter handle: @dominic_lovric

Co-authored-by: Piyush Jain <piyushjain@duck.com>
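A minimal sketch of how the new async methods could be exercised (the model id mirrors the one used in the tests; AWS credentials/region setup is assumed and is not part of this change):

import asyncio

from langchain_community.llms.bedrock import Bedrock


async def main() -> None:
    # streaming=True is required for the async path added in this commit; the
    # boto3 client is created from your AWS configuration (credentials and
    # region are assumed to be set up).
    llm = Bedrock(model_id="anthropic.claude-v2", streaming=True)

    # Stream chunks as they are produced.
    async for chunk in llm._astream("Tell me a joke."):
        print(chunk.text, end="", flush=True)

    # Or gather the whole completion at once.
    print(await llm._acall("Tell me a joke."))


asyncio.run(main())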
parent d6275e47f2
commit b9e7f6f38a
@@ -1,11 +1,25 @@
 from __future__ import annotations

+import asyncio
 import json
 import warnings
 from abc import ABC
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Mapping, Optional
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+)

-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
 from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
@@ -128,27 +142,57 @@ class LLMInputOutputAdapter:
         if not stream:
             return

-        if provider not in cls.provider_to_output_key_map:
+        output_key = cls.provider_to_output_key_map.get(provider, None)
+
+        if not output_key:
             raise ValueError(
                 f"Unknown streaming response output key for provider: {provider}"
             )

         for event in stream:
             chunk = event.get("chunk")
-            if chunk:
+            if not chunk:
+                continue
+
             chunk_obj = json.loads(chunk.get("bytes").decode())

             if provider == "cohere" and (
-                chunk_obj["is_finished"]
-                or chunk_obj[cls.provider_to_output_key_map[provider]]
-                == "<EOS_TOKEN>"
+                chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
             ):
                 return

             # chunk obj format varies with provider
-            yield GenerationChunk(
-                text=chunk_obj[cls.provider_to_output_key_map[provider]]
-            )
+            yield GenerationChunk(text=chunk_obj[output_key])
+
+    @classmethod
+    async def aprepare_output_stream(
+        cls, provider: str, response: Any, stop: Optional[List[str]] = None
+    ) -> AsyncIterator[GenerationChunk]:
+        stream = response.get("body")
+
+        if not stream:
+            return
+
+        output_key = cls.provider_to_output_key_map.get(provider, None)
+
+        if not output_key:
+            raise ValueError(
+                f"Unknown streaming response output key for provider: {provider}"
+            )
+
+        for event in stream:
+            chunk = event.get("chunk")
+            if not chunk:
+                continue
+
+            chunk_obj = json.loads(chunk.get("bytes").decode())
+
+            if provider == "cohere" and (
+                chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
+            ):
+                return
+
+            yield GenerationChunk(text=chunk_obj[output_key])


 class BedrockBase(BaseModel, ABC):
     """Base class for Bedrock models."""
@@ -332,6 +376,51 @@ class BedrockBase(BaseModel, ABC):
             if run_manager is not None:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)

+    async def _aprepare_input_and_invoke_stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        _model_kwargs = self.model_kwargs or {}
+        provider = self._get_provider()
+
+        if stop:
+            if provider not in self.provider_stop_sequence_key_name_map:
+                raise ValueError(
+                    f"Stop sequence key name for {provider} is not supported."
+                )
+            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
+
+        if provider == "cohere":
+            _model_kwargs["stream"] = True
+
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        body = json.dumps(input_body)
+
+        response = await asyncio.get_running_loop().run_in_executor(
+            None,
+            lambda: self.client.invoke_model_with_response_stream(
+                body=body,
+                modelId=self.model_id,
+                accept="application/json",
+                contentType="application/json",
+            ),
+        )
+
+        async for chunk in LLMInputOutputAdapter.aprepare_output_stream(
+            provider, response, stop
+        ):
+            yield chunk
+            if run_manager is not None and asyncio.iscoroutinefunction(
+                run_manager.on_llm_new_token
+            ):
+                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            elif run_manager is not None:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+

 class Bedrock(LLM, BedrockBase):
     """Bedrock models.
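The new _aprepare_input_and_invoke_stream above off-loads the blocking invoke_model_with_response_stream SDK call to the event loop's default executor via run_in_executor. A standalone sketch of that pattern, with a purely illustrative blocking_call standing in for the SDK call:

import asyncio
import time
from typing import Any


def blocking_call() -> Any:
    # Stand-in for a synchronous SDK call such as
    # client.invoke_model_with_response_stream(...).
    time.sleep(1)
    return {"body": []}


async def invoke_without_blocking_the_loop() -> Any:
    loop = asyncio.get_running_loop()
    # None selects the default ThreadPoolExecutor; the lambda defers the call
    # so it runs inside the worker thread, not on the event loop.
    return await loop.run_in_executor(None, lambda: blocking_call())


print(asyncio.run(invoke_without_blocking_the_loop()))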
@@ -449,6 +538,65 @@ class Bedrock(LLM, BedrockBase):

         return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)

+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncGenerator[GenerationChunk, None]:
+        """Call out to Bedrock service with streaming.
+
+        Args:
+            prompt (str): The prompt to pass into the model
+            stop (Optional[List[str]], optional): Stop sequences. These will
+                override any stop sequences in the `model_kwargs` attribute.
+                Defaults to None.
+            run_manager (Optional[CallbackManagerForLLMRun], optional): Callback
+                run managers used to process the output. Defaults to None.
+
+        Yields:
+            AsyncGenerator[GenerationChunk, None]: Generator that asynchronously
+                yields the streamed responses.
+        """
+        async for chunk in self._aprepare_input_and_invoke_stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        ):
+            yield chunk
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to Bedrock service model.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = await llm._acall("Tell me a joke.")
+        """
+
+        if not self.streaming:
+            raise ValueError("Streaming must be set to True for async operations. ")
+
+        chunks = [
+            chunk.text
+            async for chunk in self._astream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            )
+        ]
+        return "".join(chunks)
+
     def get_num_tokens(self, text: str) -> int:
         if self._model_is_anthropic:
             return get_num_tokens_anthropic(text)
@@ -1,6 +1,14 @@
+import json
+from typing import AsyncGenerator, Dict
+from unittest.mock import MagicMock, patch
+
 import pytest

-from langchain_community.llms.bedrock import ALTERNATION_ERROR, _human_assistant_format
+from langchain_community.llms.bedrock import (
+    ALTERNATION_ERROR,
+    Bedrock,
+    _human_assistant_format,
+)

 TEST_CASES = {
     """Hey""": """
@@ -250,3 +258,51 @@ def test__human_assistant_format() -> None:
         else:
             output = _human_assistant_format(input_text)
             assert output == expected_output
+
+
+# Sample mock streaming response data
+MOCK_STREAMING_RESPONSE = [
+    {"chunk": {"bytes": b'{"text": "nice"}'}},
+    {"chunk": {"bytes": b'{"text": " to meet"}'}},
+    {"chunk": {"bytes": b'{"text": " you"}'}},
+]
+
+
+async def async_gen_mock_streaming_response() -> AsyncGenerator[Dict, None]:
+    for item in MOCK_STREAMING_RESPONSE:
+        yield item
+
+
+@pytest.mark.asyncio
+async def test_bedrock_async_streaming_call() -> None:
+    # Mock boto3 import
+    mock_boto3 = MagicMock()
+    mock_boto3.Session.return_value.client.return_value = (
+        MagicMock()
+    )  # Mocking the client method of the Session object
+
+    with patch.dict(
+        "sys.modules", {"boto3": mock_boto3}
+    ):  # Mocking boto3 at the top level using patch.dict
+        # Mock the `Bedrock` class's method that invokes the model
+        mock_invoke_method = MagicMock(return_value=async_gen_mock_streaming_response())
+        with patch.object(
+            Bedrock, "_aprepare_input_and_invoke_stream", mock_invoke_method
+        ):
+            # Instantiate the Bedrock LLM
+            llm = Bedrock(
+                client=None,
+                model_id="anthropic.claude-v2",
+                streaming=True,
+            )
+            # Call the _astream method
+            chunks = [
+                json.loads(chunk["chunk"]["bytes"])["text"]  # type: ignore
+                async for chunk in llm._astream("Hey, how are you?")
+            ]
+
+            # Assertions
+            assert len(chunks) == 3
+            assert chunks[0] == "nice"
+            assert chunks[1] == " to meet"
+            assert chunks[2] == " you"