Amazon Bedrock Support Streaming (#10393)

### Description

- Add support for streaming with `Bedrock` LLM and `BedrockChat` Chat
Model.
- Bedrock as of now supports streaming for the `anthropic.claude-*` and
`amazon.titan-*` models only, hence support for those has been built.
- Also increased the default `max_tokens_to_sample` for the Bedrock
`anthropic` model provider from `50` to `256` to keep in line with the
`Anthropic` defaults.
- Added examples for streaming responses to the bedrock example
notebooks.

**_NOTE:_** This PR fixes the issues mentioned in #9897 and makes that
PR redundant.
pull/8937/merge
Mukit Momin 10 months ago committed by GitHub
parent 0749a642f5
commit 67c5950df3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -22,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
"metadata": {
"tags": []
@ -73,13 +73,46 @@
"chat(messages)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "a4a4f4d4",
"metadata": {},
"source": [
"### For BedrockChat with Streaming"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c253883f",
"metadata": {},
"outputs": [],
"source": []
"source": [
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"\n",
"chat = BedrockChat(\n",
" model_id=\"anthropic.claude-v2\",\n",
" streaming=True,\n",
" callbacks=[StreamingStdOutCallbackHandler()],\n",
" model_kwargs={\"temperature\": 0.1},\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9e52838",
"metadata": {},
"outputs": [],
"source": [
"messages = [\n",
" HumanMessage(\n",
" content=\"Translate this sentence from English to French. I love programming.\"\n",
" )\n",
"]\n",
"chat(messages)"
]
}
],
"metadata": {
@ -98,7 +131,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.9"
}
},
"nbformat": 4,

@ -61,6 +61,46 @@
"\n",
"conversation.predict(input=\"Hi there!\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Conversation Chain With Streaming"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import Bedrock\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"\n",
"\n",
"llm = Bedrock(\n",
" credentials_profile_name=\"bedrock-admin\",\n",
" model_id=\"amazon.titan-tg1-large\",\n",
" streaming=True,\n",
" callbacks=[StreamingStdOutCallbackHandler()],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"conversation = ConversationChain(\n",
" llm=llm, verbose=True, memory=ConversationBufferMemory()\n",
")\n",
"\n",
"conversation.predict(input=\"Hi there!\")"
]
}
],
"metadata": {

@ -1,6 +0,0 @@
{
"name": "docs",
"lockfileVersion": 3,
"requires": true,
"packages": {}
}

@ -8,7 +8,7 @@ from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
from langchain.chat_models.base import BaseChatModel
from langchain.llms.bedrock import BedrockBase
from langchain.pydantic_v1 import Extra
from langchain.schema.messages import AIMessage, BaseMessage
from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
@ -48,10 +48,17 @@ class BedrockChat(BaseChatModel, BedrockBase):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
raise NotImplementedError(
"""Bedrock doesn't support stream requests at the moment."""
provider = self._get_provider()
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages
)
for chunk in self._prepare_input_and_invoke_stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
delta = chunk.text
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
def _astream(
self,
messages: List[BaseMessage],
@ -70,18 +77,24 @@ class BedrockChat(BaseChatModel, BedrockBase):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
provider = self._get_provider()
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages
)
completion = ""
params: Dict[str, Any] = {**kwargs}
if stop:
params["stop_sequences"] = stop
if self.streaming:
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
else:
provider = self._get_provider()
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages
)
completion = self._prepare_input_and_invoke(
prompt=prompt, stop=stop, run_manager=run_manager, **params
)
params: Dict[str, Any] = {**kwargs}
if stop:
params["stop_sequences"] = stop
completion = self._prepare_input_and_invoke(
prompt=prompt, stop=stop, run_manager=run_manager, **params
)
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])

@ -1,11 +1,12 @@
import json
from abc import ABC
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Dict, Iterator, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.schema.output import GenerationChunk
class LLMInputOutputAdapter:
@ -15,6 +16,11 @@ class LLMInputOutputAdapter:
It also provides helper function to extract
the generated text from the model response."""
provider_to_output_key_map = {
"anthropic": "completion",
"amazon": "outputText",
}
@classmethod
def prepare_input(
cls, provider: str, prompt: str, model_kwargs: Dict[str, Any]
@ -30,7 +36,7 @@ class LLMInputOutputAdapter:
input_body["inputText"] = prompt
if provider == "anthropic" and "max_tokens_to_sample" not in input_body:
input_body["max_tokens_to_sample"] = 50
input_body["max_tokens_to_sample"] = 256
return input_body
@ -47,6 +53,30 @@ class LLMInputOutputAdapter:
else:
return response_body.get("results")[0].get("outputText")
@classmethod
def prepare_output_stream(
    cls, provider: str, response: Any, stop: Optional[List[str]] = None
) -> Iterator[GenerationChunk]:
    """Turn a streaming model response into a sequence of generation chunks.

    Args:
        provider: Provider name used to pick the key that holds the
            generated text in each decoded chunk.
        response: Mapping-like response whose ``"body"`` entry is an
            iterable of stream events.
        stop: Accepted for interface compatibility; not used here.

    Yields:
        GenerationChunk: One chunk per stream event that carries a
        non-empty ``"chunk"`` entry.

    Raises:
        ValueError: If the response has a body but ``provider`` has no
            entry in ``provider_to_output_key_map``.
    """
    event_stream = response.get("body")
    if not event_stream:
        # Nothing to stream: finish without yielding anything.
        return

    output_keys = cls.provider_to_output_key_map
    if provider not in output_keys:
        raise ValueError(
            f"Unknown streaming response output key for provider: {provider}"
        )

    for event in event_stream:
        payload = event.get("chunk")
        if not payload:
            # Events without a chunk payload carry no text; skip them.
            continue
        # The decoded JSON layout varies with the provider, hence the
        # per-provider lookup of the text field.
        decoded = json.loads(payload.get("bytes").decode())
        yield GenerationChunk(text=decoded[output_keys[provider]])
class BedrockBase(BaseModel, ABC):
client: Any #: :meta private:
@ -74,6 +104,15 @@ class BedrockBase(BaseModel, ABC):
endpoint_url: Optional[str] = None
"""Needed if you don't want to default to us-east-1 endpoint"""
streaming: bool = False
"""Whether to stream the results."""
provider_stop_sequence_key_name_map: Mapping[str, str] = {
"anthropic": "stop_sequences",
"amazon": "stopSequences",
"ai21": "stop_sequences",
}
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that AWS credentials to and python package exists in environment."""
@ -154,6 +193,49 @@ class BedrockBase(BaseModel, ABC):
return text
def _prepare_input_and_invoke_stream(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Iterator[GenerationChunk]:
    """Invoke the model with response streaming and yield its chunks.

    Args:
        prompt: The prompt to pass into the model.
        stop: Optional stop sequences. When given, they replace any stop
            sequences configured in ``model_kwargs`` for this call only.
        run_manager: Optional callback manager; its ``on_llm_new_token``
            is called after each chunk has been yielded.
        **kwargs: Extra model parameters; they take precedence over the
            entries of ``model_kwargs``.

    Yields:
        GenerationChunk: Chunks produced by
        ``LLMInputOutputAdapter.prepare_output_stream``.

    Raises:
        ValueError: If ``stop`` is given for a provider that has no entry
            in ``provider_stop_sequence_key_name_map``, or if the client
            call fails.
    """
    # Work on a copy so per-call stop sequences never leak into the
    # instance-level ``model_kwargs`` shared by later calls.
    _model_kwargs = dict(self.model_kwargs or {})
    provider = self._get_provider()

    if stop:
        if provider not in self.provider_stop_sequence_key_name_map:
            raise ValueError(
                f"Stop sequence key name for {provider} is not supported."
            )
        # Stop sequences passed by the caller override the ones in
        # ``model_kwargs``. The key must be the plain string: a trailing
        # comma inside the subscript would make it a 1-tuple, which
        # ``json.dumps`` rejects as a dictionary key.
        stop_key = self.provider_stop_sequence_key_name_map[provider]
        _model_kwargs[stop_key] = stop

    params = {**_model_kwargs, **kwargs}
    input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
    body = json.dumps(input_body)

    try:
        response = self.client.invoke_model_with_response_stream(
            body=body,
            modelId=self.model_id,
            accept="application/json",
            contentType="application/json",
        )
    except Exception as e:
        # Keep the original exception as the cause for easier debugging.
        raise ValueError(f"Error raised by bedrock service: {e}") from e

    for chunk in LLMInputOutputAdapter.prepare_output_stream(
        provider, response, stop
    ):
        yield chunk
        if run_manager is not None:
            run_manager.on_llm_new_token(chunk.text, chunk=chunk)
class Bedrock(LLM, BedrockBase):
"""Bedrock models.
@ -177,7 +259,8 @@ class Bedrock(LLM, BedrockBase):
llm = BedrockLLM(
credentials_profile_name="default",
model_id="amazon.titan-tg1-large"
model_id="amazon.titan-tg1-large",
streaming=True
)
"""
@ -192,6 +275,33 @@ class Bedrock(LLM, BedrockBase):
extra = Extra.forbid
def _stream(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Iterator[GenerationChunk]:
    """Stream a completion for ``prompt`` from the Bedrock service.

    Args:
        prompt (str): The prompt to pass into the model.
        stop (Optional[List[str]], optional): Stop sequences. These will
            override any stop sequences in the `model_kwargs` attribute.
            Defaults to None.
        run_manager (Optional[CallbackManagerForLLMRun], optional): Callback
            run manager used to process the output. Defaults to None.

    Returns:
        Iterator[GenerationChunk]: The iterator obtained from
        ``_prepare_input_and_invoke_stream``, which produces the streamed
        responses from the model.
    """
    # Pure delegation: the shared base-class helper does the actual work.
    chunk_iterator = self._prepare_input_and_invoke_stream(
        prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
    )
    return chunk_iterator
def _call(
self,
prompt: str,
@ -211,9 +321,15 @@ class Bedrock(LLM, BedrockBase):
Example:
.. code-block:: python
response = se("Tell me a joke.")
response = llm("Tell me a joke.")
"""
text = self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
if self.streaming:
completion = ""
for chunk in self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
completion += chunk.text
return completion
return text
return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)

Loading…
Cancel
Save