Mirror of https://github.com/hwchase17/langchain, synced 2024-11-16 06:13:16 +00:00
community[minor]: Bedrock async methods (#12477)
Description: Added support for asynchronous streaming in the Bedrock class and corresponding tests. Primarily:
- async def aprepare_output_stream
- async def _aprepare_input_and_invoke_stream
- async def _astream
- async def _acall

I've ensured that the code adheres to the project's linting and formatting standards by running make format, make lint, and make test.

Issue: #12054, #11589
Dependencies: None
Tag maintainer: @baskaryan
Twitter handle: @dominic_lovric

---------

Co-authored-by: Piyush Jain <piyushjain@duck.com>
This commit is contained in:
parent d6275e47f2
commit b9e7f6f38a
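For context, a minimal usage sketch of the new async path (not part of this commit): it assumes AWS credentials, access to the anthropic.claude-v2 model used in the tests, and a region of your choosing. The standard ainvoke/astream entry points on the LLM base class dispatch to the _acall/_astream methods added below; streaming=True is required because _acall is implemented by joining the streamed chunks.

    import asyncio

    import boto3
    from langchain_community.llms.bedrock import Bedrock


    async def main() -> None:
        # Assumption: AWS credentials and Bedrock model access are already configured;
        # the region is illustrative.
        client = boto3.client("bedrock-runtime", region_name="us-east-1")
        llm = Bedrock(
            client=client,
            model_id="anthropic.claude-v2",
            streaming=True,  # _acall raises a ValueError unless streaming is enabled
        )

        # Token-by-token streaming: astream -> _astream -> _aprepare_input_and_invoke_stream.
        async for chunk in llm.astream("Tell me a joke."):
            print(chunk, end="", flush=True)
        print()

        # Full completion: ainvoke -> _acall, which joins the streamed chunks.
        print(await llm.ainvoke("Tell me a joke."))


    asyncio.run(main())

Because _acall only concatenates the streamed output, there is no separate non-streaming async request to Bedrock in this commit, which matches the ValueError raised when streaming is False.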
langchain_community/llms/bedrock.py

@@ -1,11 +1,25 @@
 from __future__ import annotations
 
+import asyncio
 import json
 import warnings
 from abc import ABC
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Mapping, Optional
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+)
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
 from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
@@ -128,26 +142,56 @@ class LLMInputOutputAdapter:
         if not stream:
             return
 
-        if provider not in cls.provider_to_output_key_map:
+        output_key = cls.provider_to_output_key_map.get(provider, None)
+
+        if not output_key:
             raise ValueError(
                 f"Unknown streaming response output key for provider: {provider}"
             )
 
         for event in stream:
             chunk = event.get("chunk")
-            if chunk:
-                chunk_obj = json.loads(chunk.get("bytes").decode())
-                if provider == "cohere" and (
-                    chunk_obj["is_finished"]
-                    or chunk_obj[cls.provider_to_output_key_map[provider]]
-                    == "<EOS_TOKEN>"
-                ):
-                    return
-
-                # chunk obj format varies with provider
-                yield GenerationChunk(
-                    text=chunk_obj[cls.provider_to_output_key_map[provider]]
-                )
+            if not chunk:
+                continue
+
+            chunk_obj = json.loads(chunk.get("bytes").decode())
+
+            if provider == "cohere" and (
+                chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
+            ):
+                return
+
+            yield GenerationChunk(text=chunk_obj[output_key])
+
+    @classmethod
+    async def aprepare_output_stream(
+        cls, provider: str, response: Any, stop: Optional[List[str]] = None
+    ) -> AsyncIterator[GenerationChunk]:
+        stream = response.get("body")
+
+        if not stream:
+            return
+
+        output_key = cls.provider_to_output_key_map.get(provider, None)
+
+        if not output_key:
+            raise ValueError(
+                f"Unknown streaming response output key for provider: {provider}"
+            )
+
+        for event in stream:
+            chunk = event.get("chunk")
+            if not chunk:
+                continue
+
+            chunk_obj = json.loads(chunk.get("bytes").decode())
+
+            if provider == "cohere" and (
+                chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
+            ):
+                return
+
+            yield GenerationChunk(text=chunk_obj[output_key])
 
 
 class BedrockBase(BaseModel, ABC):
@@ -332,6 +376,51 @@ class BedrockBase(BaseModel, ABC):
             if run_manager is not None:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
 
+    async def _aprepare_input_and_invoke_stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        _model_kwargs = self.model_kwargs or {}
+        provider = self._get_provider()
+
+        if stop:
+            if provider not in self.provider_stop_sequence_key_name_map:
+                raise ValueError(
+                    f"Stop sequence key name for {provider} is not supported."
+                )
+            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
+
+        if provider == "cohere":
+            _model_kwargs["stream"] = True
+
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        body = json.dumps(input_body)
+
+        response = await asyncio.get_running_loop().run_in_executor(
+            None,
+            lambda: self.client.invoke_model_with_response_stream(
+                body=body,
+                modelId=self.model_id,
+                accept="application/json",
+                contentType="application/json",
+            ),
+        )
+
+        async for chunk in LLMInputOutputAdapter.aprepare_output_stream(
+            provider, response, stop
+        ):
+            yield chunk
+            if run_manager is not None and asyncio.iscoroutinefunction(
+                run_manager.on_llm_new_token
+            ):
+                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            elif run_manager is not None:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
 
 class Bedrock(LLM, BedrockBase):
     """Bedrock models.
@@ -449,6 +538,65 @@ class Bedrock(LLM, BedrockBase):
 
         return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
 
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncGenerator[GenerationChunk, None]:
+        """Call out to Bedrock service with streaming.
+
+        Args:
+            prompt (str): The prompt to pass into the model
+            stop (Optional[List[str]], optional): Stop sequences. These will
+                override any stop sequences in the `model_kwargs` attribute.
+                Defaults to None.
+            run_manager (Optional[CallbackManagerForLLMRun], optional): Callback
+                run managers used to process the output. Defaults to None.
+
+        Yields:
+            AsyncGenerator[GenerationChunk, None]: Generator that asynchronously yields
+            the streamed responses.
+        """
+        async for chunk in self._aprepare_input_and_invoke_stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        ):
+            yield chunk
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to Bedrock service model.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = await llm._acall("Tell me a joke.")
+        """
+
+        if not self.streaming:
+            raise ValueError("Streaming must be set to True for async operations. ")
+
+        chunks = [
+            chunk.text
+            async for chunk in self._astream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            )
+        ]
+        return "".join(chunks)
+
     def get_num_tokens(self, text: str) -> int:
         if self._model_is_anthropic:
             return get_num_tokens_anthropic(text)
tests/unit_tests/llms/test_bedrock.py

@@ -1,6 +1,14 @@
+import json
+from typing import AsyncGenerator, Dict
+from unittest.mock import MagicMock, patch
+
 import pytest
 
-from langchain_community.llms.bedrock import ALTERNATION_ERROR, _human_assistant_format
+from langchain_community.llms.bedrock import (
+    ALTERNATION_ERROR,
+    Bedrock,
+    _human_assistant_format,
+)
 
 TEST_CASES = {
     """Hey""": """
@@ -250,3 +258,51 @@ def test__human_assistant_format() -> None:
     else:
         output = _human_assistant_format(input_text)
         assert output == expected_output
+
+
+# Sample mock streaming response data
+MOCK_STREAMING_RESPONSE = [
+    {"chunk": {"bytes": b'{"text": "nice"}'}},
+    {"chunk": {"bytes": b'{"text": " to meet"}'}},
+    {"chunk": {"bytes": b'{"text": " you"}'}},
+]
+
+
+async def async_gen_mock_streaming_response() -> AsyncGenerator[Dict, None]:
+    for item in MOCK_STREAMING_RESPONSE:
+        yield item
+
+
+@pytest.mark.asyncio
+async def test_bedrock_async_streaming_call() -> None:
+    # Mock boto3 import
+    mock_boto3 = MagicMock()
+    mock_boto3.Session.return_value.client.return_value = (
+        MagicMock()
+    )  # Mocking the client method of the Session object
+
+    with patch.dict(
+        "sys.modules", {"boto3": mock_boto3}
+    ):  # Mocking boto3 at the top level using patch.dict
+        # Mock the `Bedrock` class's method that invokes the model
+        mock_invoke_method = MagicMock(return_value=async_gen_mock_streaming_response())
+        with patch.object(
+            Bedrock, "_aprepare_input_and_invoke_stream", mock_invoke_method
+        ):
+            # Instantiate the Bedrock LLM
+            llm = Bedrock(
+                client=None,
+                model_id="anthropic.claude-v2",
+                streaming=True,
+            )
+            # Call the _astream method
+            chunks = [
+                json.loads(chunk["chunk"]["bytes"])["text"]  # type: ignore
+                async for chunk in llm._astream("Hey, how are you?")
+            ]
+
+    # Assertions
+    assert len(chunks) == 3
+    assert chunks[0] == "nice"
+    assert chunks[1] == " to meet"
+    assert chunks[2] == " you"
Loading…
Reference in New Issue
Block a user