feat: add bedrock chat model (#8017)

- Description: Add Bedrock implementation of Anthropic Claude for Chat
- Tag maintainer: @hwchase17, @baskaryan
- Twitter handle: @bwmatson

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>

@@ -0,0 +1,106 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "bf733a38-db84-4363-89e2-de6735c37230",
"metadata": {},
"source": [
"# Bedrock Chat\n",
"\n",
"[Amazon Bedrock](https://aws.amazon.com/bedrock/) is a fully managed service that makes FMs from leading AI startups and Amazon available via an API, so you can choose from a wide range of FMs to find the model that is best suited for your use case"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d51edc81",
"metadata": {},
"outputs": [],
"source": [
"%pip install boto3"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.chat_models import BedrockChat\n",
"from langchain.schema import HumanMessage"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"chat = BedrockChat(model_id=\"anthropic.claude-v2\", model_kwargs={\"temperature\":0.1})"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\" Voici la traduction en français : J'adore programmer.\", additional_kwargs={}, example=False)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" HumanMessage(\n",
" content=\"Translate this sentence from English to French. I love programming.\"\n",
" )\n",
"]\n",
"chat(messages)"
]
},
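{
"cell_type": "markdown",
"id": "1d9b3a74-2a64-4bca-9a5c-0a3d8f6b1e20",
"metadata": {},
"source": [
"If your AWS credentials live in a named profile, or you want a non-default region, `BedrockChat` also accepts the `credentials_profile_name` and `region_name` fields it inherits from `BedrockBase`. A minimal sketch (the profile name below is hypothetical):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c1f2b8e-4f6a-4d0c-9f3e-7b2a5d8c1e43",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical profile from ~/.aws/credentials; adjust to your own setup\n",
"chat = BedrockChat(\n",
"    credentials_profile_name=\"bedrock-admin\",\n",
"    region_name=\"us-east-1\",\n",
"    model_id=\"anthropic.claude-v2\",\n",
"    model_kwargs={\"temperature\": 0.1},\n",
")"
]
},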
{
"cell_type": "code",
"execution_count": null,
"id": "c253883f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -20,6 +20,7 @@ an interface where "chat messages" are the inputs and outputs.
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.chat_models.anyscale import ChatAnyscale
from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.chat_models.bedrock import BedrockChat
from langchain.chat_models.ernie import ErnieBotChat
from langchain.chat_models.fake import FakeListChatModel
from langchain.chat_models.google_palm import ChatGooglePalm
@@ -35,6 +36,7 @@ from langchain.chat_models.vertexai import ChatVertexAI
__all__ = [
"ChatOpenAI",
"AzureChatOpenAI",
"BedrockChat",
"FakeListChatModel",
"PromptLayerChatOpenAI",
"ChatAnthropic",

@@ -6,10 +6,6 @@ from langchain.callbacks.manager import (
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.anthropic import _AnthropicCommon
from langchain.schema import (
ChatGeneration,
ChatResult,
)
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
@@ -18,7 +14,54 @@ from langchain.schema.messages import (
HumanMessage,
SystemMessage,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
def _convert_one_message_to_text(
message: BaseMessage,
human_prompt: str,
ai_prompt: str,
) -> str:
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
message_text = f"{human_prompt} {message.content}"
elif isinstance(message, AIMessage):
message_text = f"{ai_prompt} {message.content}"
elif isinstance(message, SystemMessage):
message_text = f"{human_prompt} <admin>{message.content}</admin>"
else:
raise ValueError(f"Got unknown type {message}")
return message_text
def convert_messages_to_prompt_anthropic(
messages: List[BaseMessage],
*,
human_prompt: str = "\n\nHuman:",
ai_prompt: str = "\n\nAssistant:",
) -> str:
"""Format a list of messages into a full prompt for the Anthropic model
Args:
messages (List[BaseMessage]): List of BaseMessage to combine.
human_prompt (str, optional): Human prompt tag. Defaults to "\n\nHuman:".
ai_prompt (str, optional): AI prompt tag. Defaults to "\n\nAssistant:".
Returns:
str: Combined string with necessary human_prompt and ai_prompt tags.
"""
messages = messages.copy() # don't mutate the original list
if not isinstance(messages[-1], AIMessage):
messages.append(AIMessage(content=""))
text = "".join(
_convert_one_message_to_text(message, human_prompt, ai_prompt)
for message in messages
)
# trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip()
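# Worked example of the helper above (values mirror the unit tests at the
# bottom of this diff); an empty AIMessage is appended so the prompt always
# ends with the assistant tag:
#
#   convert_messages_to_prompt_anthropic([HumanMessage(content="Hello")])
#   -> "\n\nHuman: Hello\n\nAssistant:"
#
#   convert_messages_to_prompt_anthropic(
#       [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
#   )
#   -> "\n\nHuman: Hello\n\nAssistant: Answer:"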
class ChatAnthropic(BaseChatModel, _AnthropicCommon):
@@ -55,52 +98,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
def lc_serializable(self) -> bool:
return True
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
message_text = f"{self.HUMAN_PROMPT} {message.content}"
elif isinstance(message, AIMessage):
message_text = f"{self.AI_PROMPT} {message.content}"
elif isinstance(message, SystemMessage):
message_text = f"{self.HUMAN_PROMPT} <admin>{message.content}</admin>"
else:
raise ValueError(f"Got unknown type {message}")
return message_text
def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
"""Format a list of strings into a single string with necessary newlines.
Args:
messages (List[BaseMessage]): List of BaseMessage to combine.
Returns:
str: Combined string with necessary newlines.
"""
return "".join(
self._convert_one_message_to_text(message) for message in messages
)
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
"""Format a list of messages into a full prompt for the Anthropic model
Args:
messages (List[BaseMessage]): List of BaseMessage to combine.
Returns:
str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
"""
messages = messages.copy() # don't mutate the original list
if not self.AI_PROMPT:
raise NameError("Please ensure the anthropic package is loaded")
if not isinstance(messages[-1], AIMessage):
messages.append(AIMessage(content=""))
text = self._convert_messages_to_text(messages)
return (
text.rstrip()
) # trim off the trailing ' ' that might come from the "Assistant: "
prompt_params = {}
if self.HUMAN_PROMPT:
prompt_params["human_prompt"] = self.HUMAN_PROMPT
if self.AI_PROMPT:
prompt_params["ai_prompt"] = self.AI_PROMPT
return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)
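# Note: when the anthropic SDK is installed, _AnthropicCommon populates
# HUMAN_PROMPT / AI_PROMPT from the package's own constants, so they override
# the "\n\nHuman:" / "\n\nAssistant:" defaults of the module-level helper.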
def _stream(
self,
@@ -152,7 +162,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
else:
prompt = self._convert_messages_to_prompt(messages)
prompt = self._convert_messages_to_prompt(
messages,
)
params: Dict[str, Any] = {
"prompt": prompt,
**self._default_params,
@@ -177,7 +189,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
async for chunk in self._astream(messages, stop, run_manager, **kwargs):
completion += chunk.text
else:
prompt = self._convert_messages_to_prompt(messages)
prompt = self._convert_messages_to_prompt(
messages,
)
params: Dict[str, Any] = {
"prompt": prompt,
**self._default_params,

@@ -0,0 +1,98 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
from langchain.chat_models.base import BaseChatModel
from langchain.llms.bedrock import BedrockBase
from langchain.pydantic_v1 import Extra
from langchain.schema.messages import AIMessage, BaseMessage
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
class ChatPromptAdapter:
"""Adapter class to prepare the inputs from Langchain to prompt format
that Chat model expects.
"""
@classmethod
def convert_messages_to_prompt(
cls, provider: str, messages: List[BaseMessage]
) -> str:
if provider == "anthropic":
prompt = convert_messages_to_prompt_anthropic(messages=messages)
else:
raise NotImplementedError(
f"Provider {provider} model does not support chat."
)
return prompt
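# Dispatch sketch: only the "anthropic" provider is wired up in this commit.
#
#   ChatPromptAdapter.convert_messages_to_prompt(
#       provider="anthropic", messages=[HumanMessage(content="Hello")]
#   )
#   -> "\n\nHuman: Hello\n\nAssistant:"
#
# Any other provider prefix (e.g. "amazon") raises NotImplementedError.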
class BedrockChat(BaseChatModel, BedrockBase):
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "amazon_bedrock_chat"
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
raise NotImplementedError(
"""Bedrock doesn't support stream requests at the moment."""
)
def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
raise NotImplementedError(
"""Bedrock doesn't support async requests at the moment."""
)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
provider = self._get_provider()
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages
)
params: Dict[str, Any] = {**kwargs}
if stop:
params["stop_sequences"] = stop
completion = self._prepare_input_and_invoke(
prompt=prompt, stop=stop, run_manager=run_manager, **params
)
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])
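# Usage sketch (hypothetical stop word): stop words passed to the chat model
# are forwarded to Bedrock as Anthropic-style stop sequences.
#
#   chat = BedrockChat(model_id="anthropic.claude-v2")
#   chat(messages, stop=["\n\nHuman:"])  # -> params["stop_sequences"]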
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
raise NotImplementedError(
"""Bedrock doesn't support async stream requests at the moment."""
)

@@ -1,10 +1,11 @@
import json
from abc import ABC
from typing import Any, Dict, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import Extra, root_validator
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
class LLMInputOutputAdapter:
@@ -47,33 +48,7 @@ class LLMInputOutputAdapter:
return response_body.get("results")[0].get("outputText")
class Bedrock(LLM):
"""Bedrock models.
To authenticate, the AWS client uses the following methods to
automatically load credentials:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
If a specific credential profile should be used, you must pass
the name of the profile from the ~/.aws/credentials file that is to be used.
Make sure the credentials / roles used have the required policies to
access the Bedrock service.
"""
"""
Example:
.. code-block:: python
from bedrock_langchain.bedrock_llm import BedrockLLM
llm = BedrockLLM(
credentials_profile_name="default",
model_id="amazon.titan-tg1-large"
)
"""
class BedrockBase(BaseModel, ABC):
client: Any #: :meta private:
region_name: Optional[str] = None
@@ -99,11 +74,6 @@ class Bedrock(LLM):
endpoint_url: Optional[str] = None
"""Needed if you don't want to default to us-east-1 endpoint"""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that AWS credentials to and python package exists in environment."""
@@ -151,11 +121,77 @@ class Bedrock(LLM):
**{"model_kwargs": _model_kwargs},
}
def _get_provider(self) -> str:
return self.model_id.split(".")[0]
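# e.g. "anthropic.claude-v2" -> "anthropic"; "amazon.titan-tg1-large" -> "amazon"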
def _prepare_input_and_invoke(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
_model_kwargs = self.model_kwargs or {}
provider = self._get_provider()
params = {**_model_kwargs, **kwargs}
input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
body = json.dumps(input_body)
accept = "application/json"
contentType = "application/json"
try:
response = self.client.invoke_model(
body=body, modelId=self.model_id, accept=accept, contentType=contentType
)
text = LLMInputOutputAdapter.prepare_output(provider, response)
except Exception as e:
raise ValueError(f"Error raised by bedrock service: {e}")
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text
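# Rough equivalent of the call above for an Anthropic model (a sketch only:
# the exact request body comes from LLMInputOutputAdapter.prepare_input,
# which is not shown in this diff, and the boto3 service name is assumed):
#
#   import json
#   import boto3
#   client = boto3.client("bedrock")
#   response = client.invoke_model(
#       body=json.dumps({"prompt": "\n\nHuman: Hi\n\nAssistant:"}),
#       modelId="anthropic.claude-v2",
#       accept="application/json",
#       contentType="application/json",
#   )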
class Bedrock(LLM, BedrockBase):
"""Bedrock models.
To authenticate, the AWS client uses the following methods to
automatically load credentials:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
If a specific credential profile should be used, you must pass
the name of the profile from the ~/.aws/credentials file that is to be used.
Make sure the credentials / roles used have the required policies to
access the Bedrock service.
"""
"""
Example:
.. code-block:: python
from langchain.llms import Bedrock
llm = Bedrock(
credentials_profile_name="default",
model_id="amazon.titan-tg1-large"
)
"""
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "amazon_bedrock"
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def _call(
self,
prompt: str,
@@ -177,25 +213,7 @@ class Bedrock(LLM):
response = llm("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
provider = self.model_id.split(".")[0]
params = {**_model_kwargs, **kwargs}
input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
body = json.dumps(input_body)
accept = "application/json"
contentType = "application/json"
try:
response = self.client.invoke_model(
body=body, modelId=self.model_id, accept=accept, contentType=contentType
)
text = LLMInputOutputAdapter.prepare_output(provider, response)
except Exception as e:
raise ValueError(f"Error raised by bedrock service: {e}")
if stop is not None:
text = enforce_stop_tokens(text, stop)
text = self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
return text

@@ -4,11 +4,11 @@ from typing import List
import pytest
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.schema import (
ChatGeneration,
LLMResult,
from langchain.chat_models.anthropic import (
ChatAnthropic,
convert_messages_to_prompt_anthropic,
)
from langchain.schema import ChatGeneration, LLMResult
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -86,14 +86,12 @@ async def test_anthropic_async_streaming_callback() -> None:
def test_formatting() -> None:
chat = ChatAnthropic()
chat_messages: List[BaseMessage] = [HumanMessage(content="Hello")]
result = chat._convert_messages_to_prompt(chat_messages)
messages: List[BaseMessage] = [HumanMessage(content="Hello")]
result = convert_messages_to_prompt_anthropic(messages)
assert result == "\n\nHuman: Hello\n\nAssistant:"
chat_messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
result = chat._convert_messages_to_prompt(chat_messages)
messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
result = convert_messages_to_prompt_anthropic(messages)
assert result == "\n\nHuman: Hello\n\nAssistant: Answer:"
