From 67c5950df3db3110d54df4fc524d67f9995c9b61 Mon Sep 17 00:00:00 2001
From: Mukit Momin
Date: Wed, 20 Sep 2023 14:55:38 -0400
Subject: [PATCH] Amazon Bedrock Support Streaming (#10393)

### Description

- Add support for streaming with `Bedrock` LLM and `BedrockChat` Chat Model.
- Bedrock currently supports streaming for the `anthropic.claude-*` and `amazon.titan-*` models only, hence support has been built for those.
- Also increased the default `max_tokens_to_sample` for the Bedrock `anthropic` model provider from `50` to `256`, in line with the `Anthropic` defaults.
- Added examples of streaming responses to the Bedrock example notebooks.

**_NOTE:_** This PR fixes the issues mentioned in #9897 and makes that PR redundant.
---
 docs/extras/integrations/chat/bedrock.ipynb |  39 +++++-
 docs/extras/integrations/llms/bedrock.ipynb |  40 ++++++
 docs/package-lock.json                      |   6 -
 .../langchain/chat_models/bedrock.py        |  39 ++++--
 libs/langchain/langchain/llms/bedrock.py    | 128 +++++++++++++++++-
 5 files changed, 224 insertions(+), 28 deletions(-)
 delete mode 100644 docs/package-lock.json

diff --git a/docs/extras/integrations/chat/bedrock.ipynb b/docs/extras/integrations/chat/bedrock.ipynb
index 7669fd915e..856fdf2045 100644
--- a/docs/extras/integrations/chat/bedrock.ipynb
+++ b/docs/extras/integrations/chat/bedrock.ipynb
@@ -22,7 +22,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
    "metadata": {
     "tags": []
@@ -73,13 +73,46 @@
     "chat(messages)"
    ]
   },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "a4a4f4d4",
+   "metadata": {},
+   "source": [
+    "### BedrockChat with Streaming"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "c253883f",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
+    "\n",
+    "chat = BedrockChat(\n",
+    "    model_id=\"anthropic.claude-v2\",\n",
+    "    streaming=True,\n",
+    "    callbacks=[StreamingStdOutCallbackHandler()],\n",
+    "    model_kwargs={\"temperature\": 0.1},\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d9e52838",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "messages = [\n",
+    "    HumanMessage(\n",
+    "        content=\"Translate this sentence from English to French. I love programming.\"\n",
+    "    )\n",
+    "]\n",
+    "chat(messages)"
+   ]
   }
  ],
  "metadata": {
@@ -98,7 +131,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.10.9"
   }
  },
  "nbformat": 4,
diff --git a/docs/extras/integrations/llms/bedrock.ipynb b/docs/extras/integrations/llms/bedrock.ipynb
index 06ae9f4cee..8f1cbfb4d3 100644
--- a/docs/extras/integrations/llms/bedrock.ipynb
+++ b/docs/extras/integrations/llms/bedrock.ipynb
@@ -61,6 +61,46 @@
     "\n",
     "conversation.predict(input=\"Hi there!\")"
    ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Conversation Chain With Streaming"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.llms import Bedrock\n",
+    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
+    "\n",
+    "\n",
+    "llm = Bedrock(\n",
+    "    credentials_profile_name=\"bedrock-admin\",\n",
+    "    model_id=\"amazon.titan-tg1-large\",\n",
+    "    streaming=True,\n",
+    "    callbacks=[StreamingStdOutCallbackHandler()],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "conversation = ConversationChain(\n",
+    "    llm=llm, verbose=True, memory=ConversationBufferMemory()\n",
+    ")\n",
+    "\n",
+    "conversation.predict(input=\"Hi there!\")"
+   ]
   }
  ],
  "metadata": {
diff --git a/docs/package-lock.json b/docs/package-lock.json
deleted file mode 100644
index 6ab682576c..0000000000
--- a/docs/package-lock.json
+++ /dev/null
@@ -1,6 +0,0 @@
-{
-  "name": "docs",
-  "lockfileVersion": 3,
-  "requires": true,
-  "packages": {}
-}
diff --git a/libs/langchain/langchain/chat_models/bedrock.py b/libs/langchain/langchain/chat_models/bedrock.py
index a539d6058b..139f1f55b8 100644
--- a/libs/langchain/langchain/chat_models/bedrock.py
+++ b/libs/langchain/langchain/chat_models/bedrock.py
@@ -8,7 +8,7 @@
 from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.bedrock import BedrockBase
 from langchain.pydantic_v1 import Extra
-from langchain.schema.messages import AIMessage, BaseMessage
+from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
 
@@ -48,10 +48,17 @@ class BedrockChat(BaseChatModel, BedrockBase):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        raise NotImplementedError(
-            """Bedrock doesn't support stream requests at the moment."""
+        provider = self._get_provider()
+        prompt = ChatPromptAdapter.convert_messages_to_prompt(
+            provider=provider, messages=messages
         )
 
+        for chunk in self._prepare_input_and_invoke_stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        ):
+            delta = chunk.text
+            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+
     def _astream(
         self,
         messages: List[BaseMessage],
@@ -70,18 +77,24 @@
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        provider = self._get_provider()
-        prompt = ChatPromptAdapter.convert_messages_to_prompt(
-            provider=provider, messages=messages
-        )
+        completion = ""
 
-        params: Dict[str, Any] = {**kwargs}
-        if stop:
-            params["stop_sequences"] = stop
+        if self.streaming:
+            for chunk in self._stream(messages, stop, run_manager, **kwargs):
+                completion += chunk.text
+        else:
+            provider = self._get_provider()
+            prompt = ChatPromptAdapter.convert_messages_to_prompt(
+                provider=provider, messages=messages
+            )
 
-        completion = self._prepare_input_and_invoke(
-            prompt=prompt, stop=stop, run_manager=run_manager, **params
-        )
+            params: Dict[str, Any] = {**kwargs}
+            if stop:
+                params["stop_sequences"] = stop
+
+            completion = self._prepare_input_and_invoke(
+                prompt=prompt, stop=stop, run_manager=run_manager, **params
+            )
 
         message = AIMessage(content=completion)
         return ChatResult(generations=[ChatGeneration(message=message)])
diff --git a/libs/langchain/langchain/llms/bedrock.py b/libs/langchain/langchain/llms/bedrock.py
index 1805e89281..7d07f92662 100644
--- a/libs/langchain/langchain/llms/bedrock.py
+++ b/libs/langchain/langchain/llms/bedrock.py
@@ -1,11 +1,12 @@
 import json
 from abc import ABC
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain.schema.output import GenerationChunk
 
 
 class LLMInputOutputAdapter:
@@ -15,6 +16,11 @@ class LLMInputOutputAdapter:
     It also provides helper function to extract the generated text
     from the model response."""
 
+    provider_to_output_key_map = {
+        "anthropic": "completion",
+        "amazon": "outputText",
+    }
+
     @classmethod
     def prepare_input(
         cls, provider: str, prompt: str, model_kwargs: Dict[str, Any]
     ) -> Dict[str, Any]:
@@ -30,7 +36,7 @@
         input_body["inputText"] = prompt
 
         if provider == "anthropic" and "max_tokens_to_sample" not in input_body:
-            input_body["max_tokens_to_sample"] = 50
+            input_body["max_tokens_to_sample"] = 256
 
         return input_body
 
@@ -47,6 +53,30 @@
         else:
             return response_body.get("results")[0].get("outputText")
 
+    @classmethod
+    def prepare_output_stream(
+        cls, provider: str, response: Any, stop: Optional[List[str]] = None
+    ) -> Iterator[GenerationChunk]:
+        stream = response.get("body")
+
+        if not stream:
+            return
+
+        if provider not in cls.provider_to_output_key_map:
+            raise ValueError(
+                f"Unknown streaming response output key for provider: {provider}"
+            )
+
+        for event in stream:
+            chunk = event.get("chunk")
+            if chunk:
+                chunk_obj = json.loads(chunk.get("bytes").decode())
+
+                # chunk obj format varies with provider
+                yield GenerationChunk(
+                    text=chunk_obj[cls.provider_to_output_key_map[provider]]
+                )
+
 
 class BedrockBase(BaseModel, ABC):
     client: Any  #: :meta private:
@@ -74,6 +104,15 @@
     endpoint_url: Optional[str] = None
     """Needed if you don't want to default to us-east-1 endpoint"""
 
+    streaming: bool = False
+    """Whether to stream the results."""
+
+    provider_stop_sequence_key_name_map: Mapping[str, str] = {
+        "anthropic": "stop_sequences",
+        "amazon": "stopSequences",
+        "ai21": "stop_sequences",
+    }
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that AWS credentials to and python package exists in environment."""
@@ -154,6 +193,49 @@
         return text
 
+    def _prepare_input_and_invoke_stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        _model_kwargs = self.model_kwargs or {}
+        provider = self._get_provider()
+
+        if stop:
+            if provider not in self.provider_stop_sequence_key_name_map:
+                raise ValueError(
+                    f"Stop sequence key name for {provider} is not supported."
+                )
+
+            # stop sequence from _generate() overrides
+            # stop sequences in the class attribute
+            _model_kwargs[
+                self.provider_stop_sequence_key_name_map.get(provider)
+            ] = stop
+
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        body = json.dumps(input_body)
+
+        try:
+            response = self.client.invoke_model_with_response_stream(
+                body=body,
+                modelId=self.model_id,
+                accept="application/json",
+                contentType="application/json",
+            )
+        except Exception as e:
+            raise ValueError(f"Error raised by bedrock service: {e}")
+
+        for chunk in LLMInputOutputAdapter.prepare_output_stream(
+            provider, response, stop
+        ):
+            yield chunk
+            if run_manager is not None:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
 
 class Bedrock(LLM, BedrockBase):
     """Bedrock models.
 
@@ -177,7 +259,8 @@
 
             llm = BedrockLLM(
                 credentials_profile_name="default",
-                model_id="amazon.titan-tg1-large"
+                model_id="amazon.titan-tg1-large",
+                streaming=True
             )
 
     """
@@ -192,6 +275,33 @@
 
         extra = Extra.forbid
 
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        """Call out to Bedrock service with streaming.
+
+        Args:
+            prompt (str): The prompt to pass into the model
+            stop (Optional[List[str]], optional): Stop sequences. These will
+                override any stop sequences in the `model_kwargs` attribute.
+                Defaults to None.
+            run_manager (Optional[CallbackManagerForLLMRun], optional): Callback
+                run managers used to process the output. Defaults to None.
+
+        Returns:
+            Iterator[GenerationChunk]: Generator that yields the streamed responses.
+
+        Yields:
+            Iterator[GenerationChunk]: Responses from the model.
+        """
+        return self._prepare_input_and_invoke_stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        )
+
     def _call(
         self,
         prompt: str,
@@ -211,9 +321,15 @@
             Example:
                 .. code-block:: python
 
-                response = se("Tell me a joke.")
+                response = llm("Tell me a joke.")
         """
 
-        text = self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
+        if self.streaming:
+            completion = ""
+            for chunk in self._stream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                completion += chunk.text
+            return completion
 
-        return text
+        return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)