Amazon Bedrock Support Streaming (#10393)

### Description

- Add support for streaming with the `Bedrock` LLM and the `BedrockChat`
chat model.
- Bedrock currently supports streaming only for the `anthropic.claude-*` and
`amazon.titan-*` models, so support has been added for those.
- Also increased the default `max_tokens_to_sample` for the Bedrock
`anthropic` model provider from `50` to `256`, in line with the
`Anthropic` defaults.
- Added examples of streaming responses to the Bedrock example
notebooks.

**_NOTE:_** This PR fixes the issues mentioned in #9897 and makes that
PR redundant.
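
For reviewers, here is a minimal usage sketch of the new behavior. It mirrors the notebook examples added in this PR; the `bedrock-admin` profile name is only a placeholder for whatever AWS credentials profile is configured locally:

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import Bedrock

# With streaming=True, _call() delegates to _stream(), so each token is pushed
# to the callback handler as it arrives instead of being returned in one block.
llm = Bedrock(
    credentials_profile_name="bedrock-admin",  # placeholder AWS profile
    model_id="amazon.titan-tg1-large",
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
)

llm("Tell me a joke.")
```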

@@ -22,7 +22,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
    "metadata": {
     "tags": []
@@ -73,13 +73,46 @@
     "chat(messages)"
    ]
   },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "a4a4f4d4",
+   "metadata": {},
+   "source": [
+    "### For BedrockChat with Streaming"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "c253883f",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
+    "\n",
+    "chat = BedrockChat(\n",
+    "    model_id=\"anthropic.claude-v2\",\n",
+    "    streaming=True,\n",
+    "    callbacks=[StreamingStdOutCallbackHandler()],\n",
+    "    model_kwargs={\"temperature\": 0.1},\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d9e52838",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "messages = [\n",
+    "    HumanMessage(\n",
+    "        content=\"Translate this sentence from English to French. I love programming.\"\n",
+    "    )\n",
+    "]\n",
+    "chat(messages)"
+   ]
   }
  ],
  "metadata": {
@@ -98,7 +131,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.10.9"
   }
  },
 "nbformat": 4,

@@ -61,6 +61,46 @@
     "\n",
     "conversation.predict(input=\"Hi there!\")"
    ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Conversation Chain With Streaming"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.llms import Bedrock\n",
+    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
+    "\n",
+    "\n",
+    "llm = Bedrock(\n",
+    "    credentials_profile_name=\"bedrock-admin\",\n",
+    "    model_id=\"amazon.titan-tg1-large\",\n",
+    "    streaming=True,\n",
+    "    callbacks=[StreamingStdOutCallbackHandler()],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "conversation = ConversationChain(\n",
+    "    llm=llm, verbose=True, memory=ConversationBufferMemory()\n",
+    ")\n",
+    "\n",
+    "conversation.predict(input=\"Hi there!\")"
+   ]
   }
  ],
 "metadata": {

@@ -1,6 +0,0 @@
-{
-  "name": "docs",
-  "lockfileVersion": 3,
-  "requires": true,
-  "packages": {}
-}

@@ -8,7 +8,7 @@ from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.bedrock import BedrockBase
 from langchain.pydantic_v1 import Extra
-from langchain.schema.messages import AIMessage, BaseMessage
+from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
@@ -48,10 +48,17 @@ class BedrockChat(BaseChatModel, BedrockBase):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        raise NotImplementedError(
-            """Bedrock doesn't support stream requests at the moment."""
-        )
+        provider = self._get_provider()
+        prompt = ChatPromptAdapter.convert_messages_to_prompt(
+            provider=provider, messages=messages
+        )
+
+        for chunk in self._prepare_input_and_invoke_stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        ):
+            delta = chunk.text
+            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
 
     def _astream(
         self,
         messages: List[BaseMessage],
@@ -70,18 +77,24 @@ class BedrockChat(BaseChatModel, BedrockBase):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        provider = self._get_provider()
-        prompt = ChatPromptAdapter.convert_messages_to_prompt(
-            provider=provider, messages=messages
-        )
+        completion = ""
 
-        params: Dict[str, Any] = {**kwargs}
-        if stop:
-            params["stop_sequences"] = stop
+        if self.streaming:
+            for chunk in self._stream(messages, stop, run_manager, **kwargs):
+                completion += chunk.text
+        else:
+            provider = self._get_provider()
+            prompt = ChatPromptAdapter.convert_messages_to_prompt(
+                provider=provider, messages=messages
+            )
 
-        completion = self._prepare_input_and_invoke(
-            prompt=prompt, stop=stop, run_manager=run_manager, **params
-        )
+            params: Dict[str, Any] = {**kwargs}
+            if stop:
+                params["stop_sequences"] = stop
+
+            completion = self._prepare_input_and_invoke(
+                prompt=prompt, stop=stop, run_manager=run_manager, **params
+            )
 
         message = AIMessage(content=completion)
         return ChatResult(generations=[ChatGeneration(message=message)])

@@ -1,11 +1,12 @@
 import json
 from abc import ABC
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain.schema.output import GenerationChunk
 
 
 class LLMInputOutputAdapter:
@@ -15,6 +16,11 @@ class LLMInputOutputAdapter:
     It also provides helper function to extract
     the generated text from the model response."""
 
+    provider_to_output_key_map = {
+        "anthropic": "completion",
+        "amazon": "outputText",
+    }
+
     @classmethod
     def prepare_input(
         cls, provider: str, prompt: str, model_kwargs: Dict[str, Any]
@@ -30,7 +36,7 @@
             input_body["inputText"] = prompt
 
         if provider == "anthropic" and "max_tokens_to_sample" not in input_body:
-            input_body["max_tokens_to_sample"] = 50
+            input_body["max_tokens_to_sample"] = 256
 
         return input_body
@@ -47,6 +53,30 @@
         else:
             return response_body.get("results")[0].get("outputText")
 
+    @classmethod
+    def prepare_output_stream(
+        cls, provider: str, response: Any, stop: Optional[List[str]] = None
+    ) -> Iterator[GenerationChunk]:
+        stream = response.get("body")
+
+        if not stream:
+            return
+
+        if provider not in cls.provider_to_output_key_map:
+            raise ValueError(
+                f"Unknown streaming response output key for provider: {provider}"
+            )
+
+        for event in stream:
+            chunk = event.get("chunk")
+            if chunk:
+                chunk_obj = json.loads(chunk.get("bytes").decode())
+
+                # chunk obj format varies with provider
+                yield GenerationChunk(
+                    text=chunk_obj[cls.provider_to_output_key_map[provider]]
+                )
+
 
 class BedrockBase(BaseModel, ABC):
     client: Any  #: :meta private:
@@ -74,6 +104,15 @@
     endpoint_url: Optional[str] = None
     """Needed if you don't want to default to us-east-1 endpoint"""
 
+    streaming: bool = False
+    """Whether to stream the results."""
+
+    provider_stop_sequence_key_name_map: Mapping[str, str] = {
+        "anthropic": "stop_sequences",
+        "amazon": "stopSequences",
+        "ai21": "stop_sequences",
+    }
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that AWS credentials to and python package exists in environment."""
@@ -154,6 +193,49 @@
 
         return text
 
+    def _prepare_input_and_invoke_stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        _model_kwargs = self.model_kwargs or {}
+        provider = self._get_provider()
+
+        if stop:
+            if provider not in self.provider_stop_sequence_key_name_map:
+                raise ValueError(
+                    f"Stop sequence key name for {provider} is not supported."
+                )
+
+            # stop sequence from _generate() overrides
+            # stop sequences in the class attribute
+            _model_kwargs[
+                self.provider_stop_sequence_key_name_map.get(provider),
+            ] = stop
+
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        body = json.dumps(input_body)
+
+        try:
+            response = self.client.invoke_model_with_response_stream(
+                body=body,
+                modelId=self.model_id,
+                accept="application/json",
+                contentType="application/json",
+            )
+        except Exception as e:
+            raise ValueError(f"Error raised by bedrock service: {e}")
+
+        for chunk in LLMInputOutputAdapter.prepare_output_stream(
+            provider, response, stop
+        ):
+            yield chunk
+            if run_manager is not None:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
 
 class Bedrock(LLM, BedrockBase):
     """Bedrock models.
@@ -177,7 +259,8 @@ class Bedrock(LLM, BedrockBase):
             llm = BedrockLLM(
                 credentials_profile_name="default",
-                model_id="amazon.titan-tg1-large"
+                model_id="amazon.titan-tg1-large",
+                streaming=True
             )
 
     """
@@ -192,6 +275,33 @@ class Bedrock(LLM, BedrockBase):
 
         extra = Extra.forbid
 
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        """Call out to Bedrock service with streaming.
+
+        Args:
+            prompt (str): The prompt to pass into the model
+            stop (Optional[List[str]], optional): Stop sequences. These will
+                override any stop sequences in the `model_kwargs` attribute.
+                Defaults to None.
+            run_manager (Optional[CallbackManagerForLLMRun], optional): Callback
+                run managers used to process the output. Defaults to None.
+
+        Returns:
+            Iterator[GenerationChunk]: Generator that yields the streamed responses.
+
+        Yields:
+            Iterator[GenerationChunk]: Responses from the model.
+        """
+        return self._prepare_input_and_invoke_stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        )
+
     def _call(
         self,
         prompt: str,
@@ -211,9 +321,15 @@ class Bedrock(LLM, BedrockBase):
         Example:
             .. code-block:: python
 
-                response = se("Tell me a joke.")
+                response = llm("Tell me a joke.")
         """
 
-        text = self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
-
-        return text
+        if self.streaming:
+            completion = ""
+            for chunk in self._stream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                completion += chunk.text
+            return completion
+
+        return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
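
To make the new streaming plumbing concrete, here is a small sketch of how `LLMInputOutputAdapter.prepare_output_stream` consumes the Bedrock response. The `fake_response` below is a hand-built stand-in (an assumption for illustration only) for the boto3 `invoke_model_with_response_stream` result, which the adapter reads as a `"body"` iterable of `{"chunk": {"bytes": ...}}` events:

```python
import json

from langchain.llms.bedrock import LLMInputOutputAdapter

# Hand-built stand-in for the boto3 streaming response: the adapter only needs
# a mapping with a "body" iterable of {"chunk": {"bytes": <json bytes>}} events.
fake_response = {
    "body": [
        {"chunk": {"bytes": json.dumps({"completion": "Hello"}).encode()}},
        {"chunk": {"bytes": json.dumps({"completion": ", world!"}).encode()}},
    ]
}

# For the "anthropic" provider the adapter pulls the "completion" key out of
# each decoded chunk and yields it as a GenerationChunk.
for gen in LLMInputOutputAdapter.prepare_output_stream("anthropic", fake_response):
    print(gen.text, end="")  # -> Hello, world!
```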
