diff --git a/docs/extras/integrations/chat/bedrock.ipynb b/docs/extras/integrations/chat/bedrock.ipynb
new file mode 100644
index 0000000000..7669fd915e
--- /dev/null
+++ b/docs/extras/integrations/chat/bedrock.ipynb
@@ -0,0 +1,106 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "bf733a38-db84-4363-89e2-de6735c37230",
+ "metadata": {},
+ "source": [
+ "# Bedrock Chat\n",
+ "\n",
+ "[Amazon Bedrock](https://aws.amazon.com/bedrock/) is a fully managed service that makes FMs from leading AI startups and Amazon available via an API, so you can choose from a wide range of FMs to find the model that is best suited for your use case"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d51edc81",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install boto3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from langchain.chat_models import BedrockChat\n",
+ "from langchain.schema import HumanMessage"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "chat = BedrockChat(model_id=\"anthropic.claude-v2\", model_kwargs={\"temperature\":0.1})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "AIMessage(content=\" Voici la traduction en français : J'adore programmer.\", additional_kwargs={}, example=False)"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "messages = [\n",
+ " HumanMessage(\n",
+ " content=\"Translate this sentence from English to French. I love programming.\"\n",
+ " )\n",
+ "]\n",
+ "chat(messages)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c253883f",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
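
For quick reference, the notebook above reduces to the sketch below. It assumes AWS credentials are already available via the standard boto3 credential chain and that the account has been granted access to `anthropic.claude-v2` on Bedrock; the exact `model_kwargs` accepted depend on the provider.

```python
# Minimal sketch of the new BedrockChat chat model (illustrative, not canonical).
from langchain.chat_models import BedrockChat
from langchain.schema import HumanMessage

chat = BedrockChat(
    model_id="anthropic.claude-v2",     # provider is inferred from the "anthropic." prefix
    model_kwargs={"temperature": 0.1},  # passed through to the Bedrock invoke_model call
)

messages = [
    HumanMessage(
        content="Translate this sentence from English to French. I love programming."
    )
]
print(chat(messages).content)
```
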
diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py
index ee21a2377e..b03cb77710 100644
--- a/libs/langchain/langchain/chat_models/__init__.py
+++ b/libs/langchain/langchain/chat_models/__init__.py
@@ -20,6 +20,7 @@ an interface where "chat messages" are the inputs and outputs.
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.chat_models.anyscale import ChatAnyscale
from langchain.chat_models.azure_openai import AzureChatOpenAI
+from langchain.chat_models.bedrock import BedrockChat
from langchain.chat_models.ernie import ErnieBotChat
from langchain.chat_models.fake import FakeListChatModel
from langchain.chat_models.google_palm import ChatGooglePalm
@@ -35,6 +36,7 @@ from langchain.chat_models.vertexai import ChatVertexAI
__all__ = [
"ChatOpenAI",
"AzureChatOpenAI",
+ "BedrockChat",
"FakeListChatModel",
"PromptLayerChatOpenAI",
"ChatAnthropic",
diff --git a/libs/langchain/langchain/chat_models/anthropic.py b/libs/langchain/langchain/chat_models/anthropic.py
index 4d00eae4df..a944ab1708 100644
--- a/libs/langchain/langchain/chat_models/anthropic.py
+++ b/libs/langchain/langchain/chat_models/anthropic.py
@@ -6,10 +6,6 @@ from langchain.callbacks.manager import (
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.anthropic import _AnthropicCommon
-from langchain.schema import (
- ChatGeneration,
- ChatResult,
-)
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
@@ -18,7 +14,54 @@ from langchain.schema.messages import (
HumanMessage,
SystemMessage,
)
-from langchain.schema.output import ChatGenerationChunk
+from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
+
+
+def _convert_one_message_to_text(
+ message: BaseMessage,
+ human_prompt: str,
+ ai_prompt: str,
+) -> str:
+ if isinstance(message, ChatMessage):
+ message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+ elif isinstance(message, HumanMessage):
+ message_text = f"{human_prompt} {message.content}"
+ elif isinstance(message, AIMessage):
+ message_text = f"{ai_prompt} {message.content}"
+ elif isinstance(message, SystemMessage):
+ message_text = f"{human_prompt} {message.content}"
+ else:
+ raise ValueError(f"Got unknown type {message}")
+ return message_text
+
+
+def convert_messages_to_prompt_anthropic(
+ messages: List[BaseMessage],
+ *,
+ human_prompt: str = "\n\nHuman:",
+ ai_prompt: str = "\n\nAssistant:",
+) -> str:
+ """Format a list of messages into a full prompt for the Anthropic model
+ Args:
+ messages (List[BaseMessage]): List of BaseMessage to combine.
+ human_prompt (str, optional): Human prompt tag. Defaults to "\n\nHuman:".
+ ai_prompt (str, optional): AI prompt tag. Defaults to "\n\nAssistant:".
+ Returns:
+ str: Combined string with necessary human_prompt and ai_prompt tags.
+ """
+
+ messages = messages.copy() # don't mutate the original list
+
+ if not isinstance(messages[-1], AIMessage):
+ messages.append(AIMessage(content=""))
+
+ text = "".join(
+ _convert_one_message_to_text(message, human_prompt, ai_prompt)
+ for message in messages
+ )
+
+ # trim off the trailing ' ' that might come from the "Assistant: "
+ return text.rstrip()
class ChatAnthropic(BaseChatModel, _AnthropicCommon):
@@ -55,52 +98,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
def lc_serializable(self) -> bool:
return True
- def _convert_one_message_to_text(self, message: BaseMessage) -> str:
- if isinstance(message, ChatMessage):
- message_text = f"\n\n{message.role.capitalize()}: {message.content}"
- elif isinstance(message, HumanMessage):
- message_text = f"{self.HUMAN_PROMPT} {message.content}"
- elif isinstance(message, AIMessage):
- message_text = f"{self.AI_PROMPT} {message.content}"
- elif isinstance(message, SystemMessage):
- message_text = f"{self.HUMAN_PROMPT} {message.content}"
- else:
- raise ValueError(f"Got unknown type {message}")
- return message_text
-
- def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
- """Format a list of strings into a single string with necessary newlines.
-
- Args:
- messages (List[BaseMessage]): List of BaseMessage to combine.
-
- Returns:
- str: Combined string with necessary newlines.
- """
- return "".join(
- self._convert_one_message_to_text(message) for message in messages
- )
-
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
"""Format a list of messages into a full prompt for the Anthropic model
-
Args:
messages (List[BaseMessage]): List of BaseMessage to combine.
-
Returns:
str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
"""
- messages = messages.copy() # don't mutate the original list
-
- if not self.AI_PROMPT:
- raise NameError("Please ensure the anthropic package is loaded")
-
- if not isinstance(messages[-1], AIMessage):
- messages.append(AIMessage(content=""))
- text = self._convert_messages_to_text(messages)
- return (
- text.rstrip()
- ) # trim off the trailing ' ' that might come from the "Assistant: "
+ prompt_params = {}
+ if self.HUMAN_PROMPT:
+ prompt_params["human_prompt"] = self.HUMAN_PROMPT
+ if self.AI_PROMPT:
+ prompt_params["ai_prompt"] = self.AI_PROMPT
+ return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)
def _stream(
self,
@@ -152,7 +162,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
else:
- prompt = self._convert_messages_to_prompt(messages)
+ prompt = self._convert_messages_to_prompt(
+ messages,
+ )
params: Dict[str, Any] = {
"prompt": prompt,
**self._default_params,
@@ -177,7 +189,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
async for chunk in self._astream(messages, stop, run_manager, **kwargs):
completion += chunk.text
else:
- prompt = self._convert_messages_to_prompt(messages)
+ prompt = self._convert_messages_to_prompt(
+ messages,
+ )
params: Dict[str, Any] = {
"prompt": prompt,
**self._default_params,
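
The net effect of the `anthropic.py` changes: message-to-prompt conversion is now a module-level function whose Human/Assistant tags are keyword-only and overridable, so `BedrockChat` can reuse it without instantiating `ChatAnthropic`. A sketch of both call styles; the default-tag output matches the assertions in the updated test at the end of this diff.

```python
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
from langchain.schema.messages import AIMessage, HumanMessage

# Default tags: a trailing empty AIMessage is appended, then whitespace trimmed.
prompt = convert_messages_to_prompt_anthropic([HumanMessage(content="Hello")])
assert prompt == "\n\nHuman: Hello\n\nAssistant:"

# The tags are keyword-only, so other chat formats can reuse the same logic.
prompt = convert_messages_to_prompt_anthropic(
    [HumanMessage(content="Hello"), AIMessage(content="Answer:")],
    human_prompt="\n\nUser:",
    ai_prompt="\n\nBot:",
)
assert prompt == "\n\nUser: Hello\n\nBot: Answer:"
```
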
diff --git a/libs/langchain/langchain/chat_models/bedrock.py b/libs/langchain/langchain/chat_models/bedrock.py
new file mode 100644
index 0000000000..a539d6058b
--- /dev/null
+++ b/libs/langchain/langchain/chat_models/bedrock.py
@@ -0,0 +1,98 @@
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
+
+from langchain.callbacks.manager import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
+from langchain.chat_models.base import BaseChatModel
+from langchain.llms.bedrock import BedrockBase
+from langchain.pydantic_v1 import Extra
+from langchain.schema.messages import AIMessage, BaseMessage
+from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
+
+
+class ChatPromptAdapter:
+ """Adapter class to prepare the inputs from Langchain to prompt format
+ that Chat model expects.
+ """
+
+ @classmethod
+ def convert_messages_to_prompt(
+ cls, provider: str, messages: List[BaseMessage]
+ ) -> str:
+ if provider == "anthropic":
+ prompt = convert_messages_to_prompt_anthropic(messages=messages)
+ else:
+ raise NotImplementedError(
+ f"Provider {provider} model does not support chat."
+ )
+ return prompt
+
+
+class BedrockChat(BaseChatModel, BedrockBase):
+ @property
+ def _llm_type(self) -> str:
+ """Return type of chat model."""
+ return "amazon_bedrock_chat"
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ raise NotImplementedError(
+ """Bedrock doesn't support stream requests at the moment."""
+ )
+
+ def _astream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[ChatGenerationChunk]:
+ raise NotImplementedError(
+ """Bedrock doesn't support async requests at the moment."""
+ )
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ provider = self._get_provider()
+ prompt = ChatPromptAdapter.convert_messages_to_prompt(
+ provider=provider, messages=messages
+ )
+
+ params: Dict[str, Any] = {**kwargs}
+ if stop:
+ params["stop_sequences"] = stop
+
+ completion = self._prepare_input_and_invoke(
+ prompt=prompt, stop=stop, run_manager=run_manager, **params
+ )
+
+ message = AIMessage(content=completion)
+ return ChatResult(generations=[ChatGeneration(message=message)])
+
+ async def _agenerate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ raise NotImplementedError(
+ """Bedrock doesn't support async stream requests at the moment."""
+ )
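
`BedrockChat` dispatches on the provider prefix of `model_id`, and only the Anthropic prompt format is wired up for chat so far. A small illustration of the adapter, assuming the module path introduced in this diff:

```python
from langchain.chat_models.bedrock import ChatPromptAdapter
from langchain.schema.messages import HumanMessage

# Bedrock model ids are "<provider>.<model>", so _get_provider() is just
# a split on the first dot: "anthropic.claude-v2" -> "anthropic".
provider = "anthropic.claude-v2".split(".")[0]

prompt = ChatPromptAdapter.convert_messages_to_prompt(
    provider=provider, messages=[HumanMessage(content="Hi there")]
)
assert prompt == "\n\nHuman: Hi there\n\nAssistant:"

# Any other provider prefix (e.g. "amazon", "ai21") raises NotImplementedError.
```
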
diff --git a/libs/langchain/langchain/llms/bedrock.py b/libs/langchain/langchain/llms/bedrock.py
index 75af53a6a5..1805e89281 100644
--- a/libs/langchain/langchain/llms/bedrock.py
+++ b/libs/langchain/langchain/llms/bedrock.py
@@ -1,10 +1,11 @@
import json
+from abc import ABC
from typing import Any, Dict, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
-from langchain.pydantic_v1 import Extra, root_validator
+from langchain.pydantic_v1 import BaseModel, Extra, root_validator
class LLMInputOutputAdapter:
@@ -47,33 +48,7 @@ class LLMInputOutputAdapter:
return response_body.get("results")[0].get("outputText")
-class Bedrock(LLM):
- """Bedrock models.
-
- To authenticate, the AWS client uses the following methods to
- automatically load credentials:
- https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
-
- If a specific credential profile should be used, you must pass
- the name of the profile from the ~/.aws/credentials file that is to be used.
-
- Make sure the credentials / roles used have the required policies to
- access the Bedrock service.
- """
-
- """
- Example:
- .. code-block:: python
-
- from bedrock_langchain.bedrock_llm import BedrockLLM
-
- llm = BedrockLLM(
- credentials_profile_name="default",
- model_id="amazon.titan-tg1-large"
- )
-
- """
-
+class BedrockBase(BaseModel, ABC):
client: Any #: :meta private:
region_name: Optional[str] = None
@@ -99,11 +74,6 @@ class Bedrock(LLM):
endpoint_url: Optional[str] = None
"""Needed if you don't want to default to us-east-1 endpoint"""
- class Config:
- """Configuration for this pydantic object."""
-
- extra = Extra.forbid
-
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that AWS credentials to and python package exists in environment."""
@@ -151,11 +121,77 @@ class Bedrock(LLM):
**{"model_kwargs": _model_kwargs},
}
+ def _get_provider(self) -> str:
+ return self.model_id.split(".")[0]
+
+ def _prepare_input_and_invoke(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ _model_kwargs = self.model_kwargs or {}
+
+ provider = self._get_provider()
+ params = {**_model_kwargs, **kwargs}
+ input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+ body = json.dumps(input_body)
+ accept = "application/json"
+ contentType = "application/json"
+
+ try:
+ response = self.client.invoke_model(
+ body=body, modelId=self.model_id, accept=accept, contentType=contentType
+ )
+ text = LLMInputOutputAdapter.prepare_output(provider, response)
+
+ except Exception as e:
+ raise ValueError(f"Error raised by bedrock service: {e}")
+
+ if stop is not None:
+ text = enforce_stop_tokens(text, stop)
+
+ return text
+
+
+class Bedrock(LLM, BedrockBase):
+ """Bedrock models.
+
+ To authenticate, the AWS client uses the following methods to
+ automatically load credentials:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+
+ If a specific credential profile should be used, you must pass
+ the name of the profile from the ~/.aws/credentials file that is to be used.
+
+ Make sure the credentials / roles used have the required policies to
+ access the Bedrock service.
+
+ Example:
+ .. code-block:: python
+
+ from langchain.llms import Bedrock
+
+ llm = Bedrock(
+ credentials_profile_name="default",
+ model_id="amazon.titan-tg1-large"
+ )
+ """
+
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "amazon_bedrock"
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
def _call(
self,
prompt: str,
@@ -177,25 +213,7 @@ class Bedrock(LLM):
response = llm("Tell me a joke.")
"""
- _model_kwargs = self.model_kwargs or {}
- provider = self.model_id.split(".")[0]
- params = {**_model_kwargs, **kwargs}
- input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
- body = json.dumps(input_body)
- accept = "application/json"
- contentType = "application/json"
-
- try:
- response = self.client.invoke_model(
- body=body, modelId=self.model_id, accept=accept, contentType=contentType
- )
- text = LLMInputOutputAdapter.prepare_output(provider, response)
-
- except Exception as e:
- raise ValueError(f"Error raised by bedrock service: {e}")
-
- if stop is not None:
- text = enforce_stop_tokens(text, stop)
+ text = self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
return text
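
With the request logic hoisted into `BedrockBase._prepare_input_and_invoke`, the plain LLM and the chat model now build, send, and post-process the `invoke_model` call identically. A sketch of the LLM path, assuming a `default` profile in `~/.aws/credentials` with access to the Titan model:

```python
from langchain.llms import Bedrock

llm = Bedrock(
    credentials_profile_name="default",
    model_id="amazon.titan-tg1-large",
)

# LLM.__call__ -> Bedrock._call -> BedrockBase._prepare_input_and_invoke,
# which JSON-encodes the body, calls client.invoke_model, extracts the
# provider-specific output field, and applies any stop sequences.
print(llm("Tell me a joke."))
```
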
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py b/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py
index 0237b055d1..5e3848d382 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py
+++ b/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py
@@ -4,11 +4,11 @@ from typing import List
import pytest
from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models.anthropic import ChatAnthropic
-from langchain.schema import (
- ChatGeneration,
- LLMResult,
+from langchain.chat_models.anthropic import (
+ ChatAnthropic,
+ convert_messages_to_prompt_anthropic,
)
+from langchain.schema import ChatGeneration, LLMResult
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -86,14 +86,12 @@ async def test_anthropic_async_streaming_callback() -> None:
def test_formatting() -> None:
- chat = ChatAnthropic()
-
- chat_messages: List[BaseMessage] = [HumanMessage(content="Hello")]
- result = chat._convert_messages_to_prompt(chat_messages)
+ messages: List[BaseMessage] = [HumanMessage(content="Hello")]
+ result = convert_messages_to_prompt_anthropic(messages)
assert result == "\n\nHuman: Hello\n\nAssistant:"
- chat_messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
- result = chat._convert_messages_to_prompt(chat_messages)
+ messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
+ result = convert_messages_to_prompt_anthropic(messages)
assert result == "\n\nHuman: Hello\n\nAssistant: Answer:"
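
The reworked `test_formatting` no longer needs a `ChatAnthropic` instance (and therefore no Anthropic API key) to exercise prompt formatting. Hypothetical follow-up cases, not part of this diff, could cover the `SystemMessage` and `ChatMessage` branches of `_convert_one_message_to_text` as well:

```python
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
from langchain.schema.messages import ChatMessage, HumanMessage, SystemMessage


def test_formatting_system_and_chat() -> None:
    # SystemMessage is folded into the Human tag by design.
    result = convert_messages_to_prompt_anthropic(
        [SystemMessage(content="Be terse."), HumanMessage(content="Hello")]
    )
    assert result == "\n\nHuman: Be terse.\n\nHuman: Hello\n\nAssistant:"

    # ChatMessage uses its capitalized role as the tag.
    result = convert_messages_to_prompt_anthropic(
        [ChatMessage(role="user", content="Hello")]
    )
    assert result == "\n\nUser: Hello\n\nAssistant:"
```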