feat: add bedrock chat model (#8017)

- Description: Add a Bedrock implementation of Anthropic Claude for chat models
- Tag maintainer: @hwchase17, @baskaryan
- Twitter handle: @bwmatson

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Benjamin Matson, committed by GitHub
parent a7c9bd30d4
commit 58d7d86e51

@@ -0,0 +1,106 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bf733a38-db84-4363-89e2-de6735c37230",
   "metadata": {},
   "source": [
    "# Bedrock Chat\n",
    "\n",
    "[Amazon Bedrock](https://aws.amazon.com/bedrock/) is a fully managed service that makes FMs from leading AI startups and Amazon available via an API, so you can choose from a wide range of FMs to find the model that is best suited for your use case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d51edc81",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install boto3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.chat_models import BedrockChat\n",
    "from langchain.schema import HumanMessage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "chat = BedrockChat(model_id=\"anthropic.claude-v2\", model_kwargs={\"temperature\": 0.1})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\" Voici la traduction en français : J'adore programmer.\", additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "messages = [\n",
    "    HumanMessage(\n",
    "        content=\"Translate this sentence from English to French. I love programming.\"\n",
    "    )\n",
    "]\n",
    "chat(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c253883f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
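
For readers of the notebook above, a minimal sketch of authenticating with a named AWS profile rather than the default boto3 credential chain. The `credentials_profile_name` and `region_name` fields come from the `BedrockBase` class further down in this diff; the profile and region values here are assumptions for illustration:

from langchain.chat_models import BedrockChat
from langchain.schema import HumanMessage

# Uses the "default" profile from ~/.aws/credentials; the associated
# role must have the required policies to access the Bedrock service.
chat = BedrockChat(
    credentials_profile_name="default",
    region_name="us-east-1",
    model_id="anthropic.claude-v2",
    model_kwargs={"temperature": 0.1},
)
print(chat([HumanMessage(content="Hello!")]).content)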

@@ -20,6 +20,7 @@ an interface where "chat messages" are the inputs and outputs.
 from langchain.chat_models.anthropic import ChatAnthropic
 from langchain.chat_models.anyscale import ChatAnyscale
 from langchain.chat_models.azure_openai import AzureChatOpenAI
+from langchain.chat_models.bedrock import BedrockChat
 from langchain.chat_models.ernie import ErnieBotChat
 from langchain.chat_models.fake import FakeListChatModel
 from langchain.chat_models.google_palm import ChatGooglePalm
@@ -35,6 +36,7 @@ from langchain.chat_models.vertexai import ChatVertexAI
 __all__ = [
     "ChatOpenAI",
     "AzureChatOpenAI",
+    "BedrockChat",
     "FakeListChatModel",
     "PromptLayerChatOpenAI",
     "ChatAnthropic",

@@ -6,10 +6,6 @@ from langchain.callbacks.manager import (
 )
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.anthropic import _AnthropicCommon
-from langchain.schema import (
-    ChatGeneration,
-    ChatResult,
-)
 from langchain.schema.messages import (
     AIMessage,
     AIMessageChunk,
@@ -18,7 +14,54 @@ from langchain.schema.messages import (
     HumanMessage,
     SystemMessage,
 )
-from langchain.schema.output import ChatGenerationChunk
+from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
+
+
+def _convert_one_message_to_text(
+    message: BaseMessage,
+    human_prompt: str,
+    ai_prompt: str,
+) -> str:
+    if isinstance(message, ChatMessage):
+        message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+    elif isinstance(message, HumanMessage):
+        message_text = f"{human_prompt} {message.content}"
+    elif isinstance(message, AIMessage):
+        message_text = f"{ai_prompt} {message.content}"
+    elif isinstance(message, SystemMessage):
+        message_text = f"{human_prompt} <admin>{message.content}</admin>"
+    else:
+        raise ValueError(f"Got unknown type {message}")
+    return message_text
+
+
+def convert_messages_to_prompt_anthropic(
+    messages: List[BaseMessage],
+    *,
+    human_prompt: str = "\n\nHuman:",
+    ai_prompt: str = "\n\nAssistant:",
+) -> str:
+    """Format a list of messages into a full prompt for the Anthropic model
+
+    Args:
+        messages (List[BaseMessage]): List of BaseMessage to combine.
+        human_prompt (str, optional): Human prompt tag. Defaults to "\n\nHuman:".
+        ai_prompt (str, optional): AI prompt tag. Defaults to "\n\nAssistant:".
+
+    Returns:
+        str: Combined string with necessary human_prompt and ai_prompt tags.
+    """
+    messages = messages.copy()  # don't mutate the original list
+
+    if not isinstance(messages[-1], AIMessage):
+        messages.append(AIMessage(content=""))
+
+    text = "".join(
+        _convert_one_message_to_text(message, human_prompt, ai_prompt)
+        for message in messages
+    )
+
+    # trim off the trailing ' ' that might come from the "Assistant: "
+    return text.rstrip()


 class ChatAnthropic(BaseChatModel, _AnthropicCommon):
@@ -55,52 +98,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
     def lc_serializable(self) -> bool:
         return True

-    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
-        if isinstance(message, ChatMessage):
-            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
-        elif isinstance(message, HumanMessage):
-            message_text = f"{self.HUMAN_PROMPT} {message.content}"
-        elif isinstance(message, AIMessage):
-            message_text = f"{self.AI_PROMPT} {message.content}"
-        elif isinstance(message, SystemMessage):
-            message_text = f"{self.HUMAN_PROMPT} <admin>{message.content}</admin>"
-        else:
-            raise ValueError(f"Got unknown type {message}")
-        return message_text
-
-    def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
-        """Format a list of strings into a single string with necessary newlines.
-
-        Args:
-            messages (List[BaseMessage]): List of BaseMessage to combine.
-
-        Returns:
-            str: Combined string with necessary newlines.
-        """
-        return "".join(
-            self._convert_one_message_to_text(message) for message in messages
-        )
-
     def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
         """Format a list of messages into a full prompt for the Anthropic model

         Args:
             messages (List[BaseMessage]): List of BaseMessage to combine.

         Returns:
             str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
         """
-        messages = messages.copy()  # don't mutate the original list
-
-        if not self.AI_PROMPT:
-            raise NameError("Please ensure the anthropic package is loaded")
-
-        if not isinstance(messages[-1], AIMessage):
-            messages.append(AIMessage(content=""))
-
-        text = self._convert_messages_to_text(messages)
-        return (
-            text.rstrip()
-        )  # trim off the trailing ' ' that might come from the "Assistant: "
+        prompt_params = {}
+        if self.HUMAN_PROMPT:
+            prompt_params["human_prompt"] = self.HUMAN_PROMPT
+        if self.AI_PROMPT:
+            prompt_params["ai_prompt"] = self.AI_PROMPT
+        return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)

     def _stream(
         self,
@@ -152,7 +162,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
             for chunk in self._stream(messages, stop, run_manager, **kwargs):
                 completion += chunk.text
         else:
-            prompt = self._convert_messages_to_prompt(messages)
+            prompt = self._convert_messages_to_prompt(
+                messages,
+            )
             params: Dict[str, Any] = {
                 "prompt": prompt,
                 **self._default_params,
@@ -177,7 +189,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
             async for chunk in self._astream(messages, stop, run_manager, **kwargs):
                 completion += chunk.text
         else:
-            prompt = self._convert_messages_to_prompt(messages)
+            prompt = self._convert_messages_to_prompt(
+                messages,
+            )
             params: Dict[str, Any] = {
                 "prompt": prompt,
                 **self._default_params,
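
Before moving on to the new chat model: a quick illustration of the prompt string the module-level helper produces, consistent with the unit test at the end of this diff:

from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
from langchain.schema.messages import HumanMessage

# A trailing empty AIMessage is appended automatically, which yields the
# "\n\nAssistant:" suffix that Claude-style models expect.
assert (
    convert_messages_to_prompt_anthropic([HumanMessage(content="Hello")])
    == "\n\nHuman: Hello\n\nAssistant:"
)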

@@ -0,0 +1,98 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
from langchain.chat_models.base import BaseChatModel
from langchain.llms.bedrock import BedrockBase
from langchain.pydantic_v1 import Extra
from langchain.schema.messages import AIMessage, BaseMessage
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult


class ChatPromptAdapter:
    """Adapter class to prepare the inputs from Langchain to prompt format
    that Chat model expects.
    """

    @classmethod
    def convert_messages_to_prompt(
        cls, provider: str, messages: List[BaseMessage]
    ) -> str:
        if provider == "anthropic":
            prompt = convert_messages_to_prompt_anthropic(messages=messages)
        else:
            raise NotImplementedError(
                f"Provider {provider} model does not support chat."
            )
        return prompt


class BedrockChat(BaseChatModel, BedrockBase):
    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "amazon_bedrock_chat"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        raise NotImplementedError(
            """Bedrock doesn't support stream requests at the moment."""
        )

    def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        raise NotImplementedError(
            """Bedrock doesn't support async requests at the moment."""
        )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        provider = self._get_provider()
        prompt = ChatPromptAdapter.convert_messages_to_prompt(
            provider=provider, messages=messages
        )

        params: Dict[str, Any] = {**kwargs}
        if stop:
            params["stop_sequences"] = stop

        completion = self._prepare_input_and_invoke(
            prompt=prompt, stop=stop, run_manager=run_manager, **params
        )

        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise NotImplementedError(
            """Bedrock doesn't support async stream requests at the moment."""
        )
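
`BedrockChat` derives the provider from the `model_id` prefix (via `_get_provider`, defined in `langchain.llms.bedrock` below) and dispatches to the Anthropic prompt format. A hedged sketch of that path in isolation:

from langchain.chat_models.bedrock import ChatPromptAdapter
from langchain.schema.messages import HumanMessage

# "anthropic.claude-v2".split(".")[0] -> "anthropic"
prompt = ChatPromptAdapter.convert_messages_to_prompt(
    provider="anthropic", messages=[HumanMessage(content="Hi")]
)
print(repr(prompt))  # '\n\nHuman: Hi\n\nAssistant:'

# Any other prefix (e.g. "amazon.titan-tg1-large" -> "amazon") raises
# NotImplementedError, since only Anthropic chat prompts are wired up here.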

@@ -1,10 +1,11 @@
 import json
+from abc import ABC
 from typing import Any, Dict, List, Mapping, Optional

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.pydantic_v1 import Extra, root_validator
+from langchain.pydantic_v1 import BaseModel, Extra, root_validator


 class LLMInputOutputAdapter:
@@ -47,33 +48,7 @@ class LLMInputOutputAdapter:
         return response_body.get("results")[0].get("outputText")


-class Bedrock(LLM):
-    """Bedrock models.
-
-    To authenticate, the AWS client uses the following methods to
-    automatically load credentials:
-    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
-
-    If a specific credential profile should be used, you must pass
-    the name of the profile from the ~/.aws/credentials file that is to be used.
-
-    Make sure the credentials / roles used have the required policies to
-    access the Bedrock service.
-    """
-
-    """
-    Example:
-        .. code-block:: python
-
-            from bedrock_langchain.bedrock_llm import BedrockLLM
-
-            llm = BedrockLLM(
-                credentials_profile_name="default",
-                model_id="amazon.titan-tg1-large"
-            )
-
-    """
+class BedrockBase(BaseModel, ABC):

     client: Any  #: :meta private:

     region_name: Optional[str] = None
@@ -99,11 +74,6 @@ class Bedrock(LLM):
     endpoint_url: Optional[str] = None
     """Needed if you don't want to default to us-east-1 endpoint"""

-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that AWS credentials to and python package exists in environment."""
@@ -151,11 +121,77 @@ class Bedrock(LLM):
             **{"model_kwargs": _model_kwargs},
         }

+    def _get_provider(self) -> str:
+        return self.model_id.split(".")[0]
+
+    def _prepare_input_and_invoke(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        _model_kwargs = self.model_kwargs or {}
+
+        provider = self._get_provider()
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        body = json.dumps(input_body)
+        accept = "application/json"
+        contentType = "application/json"
+
+        try:
+            response = self.client.invoke_model(
+                body=body, modelId=self.model_id, accept=accept, contentType=contentType
+            )
+            text = LLMInputOutputAdapter.prepare_output(provider, response)
+        except Exception as e:
+            raise ValueError(f"Error raised by bedrock service: {e}")
+
+        if stop is not None:
+            text = enforce_stop_tokens(text, stop)
+
+        return text
+
+
+class Bedrock(LLM, BedrockBase):
+    """Bedrock models.
+
+    To authenticate, the AWS client uses the following methods to
+    automatically load credentials:
+    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+
+    If a specific credential profile should be used, you must pass
+    the name of the profile from the ~/.aws/credentials file that is to be used.
+
+    Make sure the credentials / roles used have the required policies to
+    access the Bedrock service.
+    """
+
+    """
+    Example:
+        .. code-block:: python
+
+            from bedrock_langchain.bedrock_llm import BedrockLLM
+
+            llm = BedrockLLM(
+                credentials_profile_name="default",
+                model_id="amazon.titan-tg1-large"
+            )
+
+    """
+
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
         return "amazon_bedrock"

+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
     def _call(
         self,
         prompt: str,
@@ -177,25 +213,7 @@ class Bedrock(LLM):
                 response = se("Tell me a joke.")
         """
-        _model_kwargs = self.model_kwargs or {}
-
-        provider = self.model_id.split(".")[0]
-        params = {**_model_kwargs, **kwargs}
-        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
-        body = json.dumps(input_body)
-        accept = "application/json"
-        contentType = "application/json"
-
-        try:
-            response = self.client.invoke_model(
-                body=body, modelId=self.model_id, accept=accept, contentType=contentType
-            )
-            text = LLMInputOutputAdapter.prepare_output(provider, response)
-        except Exception as e:
-            raise ValueError(f"Error raised by bedrock service: {e}")
-
-        if stop is not None:
-            text = enforce_stop_tokens(text, stop)
-
+        text = self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
+
         return text
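
For context, `_prepare_input_and_invoke` boils down to a single `invoke_model` call. A rough standalone sketch of the same request follows; the client/service name, body fields, and response shape are assumptions based on the Anthropic-on-Bedrock API, not taken from this diff:

import json

import boto3

# The diff's validate_environment (elided above) builds self.client;
# a hypothetical equivalent:
client = boto3.client("bedrock", region_name="us-east-1")

# Claude-style request body; LLMInputOutputAdapter.prepare_input assembles
# the provider-specific equivalent of this dict.
body = json.dumps(
    {
        "prompt": "\n\nHuman: Tell me a joke.\n\nAssistant:",
        "max_tokens_to_sample": 256,
    }
)
response = client.invoke_model(
    body=body,
    modelId="anthropic.claude-v2",
    accept="application/json",
    contentType="application/json",
)
# prepare_output then extracts the provider-specific completion field.
print(json.loads(response["body"].read())["completion"])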

@@ -4,11 +4,11 @@ from typing import List
 import pytest

 from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models.anthropic import ChatAnthropic
-from langchain.schema import (
-    ChatGeneration,
-    LLMResult,
-)
+from langchain.chat_models.anthropic import (
+    ChatAnthropic,
+    convert_messages_to_prompt_anthropic,
+)
+from langchain.schema import ChatGeneration, LLMResult
 from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -86,14 +86,12 @@ async def test_anthropic_async_streaming_callback() -> None:

 def test_formatting() -> None:
-    chat = ChatAnthropic()
-
-    chat_messages: List[BaseMessage] = [HumanMessage(content="Hello")]
-    result = chat._convert_messages_to_prompt(chat_messages)
+    messages: List[BaseMessage] = [HumanMessage(content="Hello")]
+    result = convert_messages_to_prompt_anthropic(messages)
     assert result == "\n\nHuman: Hello\n\nAssistant:"

-    chat_messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
-    result = chat._convert_messages_to_prompt(chat_messages)
+    messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
+    result = convert_messages_to_prompt_anthropic(messages)
     assert result == "\n\nHuman: Hello\n\nAssistant: Answer:"
