diff --git a/docs/extras/integrations/chat/bedrock.ipynb b/docs/extras/integrations/chat/bedrock.ipynb
new file mode 100644
index 0000000000..7669fd915e
--- /dev/null
+++ b/docs/extras/integrations/chat/bedrock.ipynb
@@ -0,0 +1,106 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "bf733a38-db84-4363-89e2-de6735c37230",
+   "metadata": {},
+   "source": [
+    "# Bedrock Chat\n",
+    "\n",
+    "[Amazon Bedrock](https://aws.amazon.com/bedrock/) is a fully managed service that makes foundation models (FMs) from leading AI startups and Amazon available via an API, so you can choose from a wide range of FMs to find the model that is best suited for your use case."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d51edc81",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install boto3"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "from langchain.chat_models import BedrockChat\n",
+    "from langchain.schema import HumanMessage"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "chat = BedrockChat(model_id=\"anthropic.claude-v2\", model_kwargs={\"temperature\":0.1})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content=\" Voici la traduction en français : J'adore programmer.\", additional_kwargs={}, example=False)"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "messages = [\n",
+    "    HumanMessage(\n",
+    "        content=\"Translate this sentence from English to French. I love programming.\"\n",
+    "    )\n",
+    "]\n",
+    "chat(messages)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c253883f",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py
index ee21a2377e..b03cb77710 100644
--- a/libs/langchain/langchain/chat_models/__init__.py
+++ b/libs/langchain/langchain/chat_models/__init__.py
@@ -20,6 +20,7 @@ an interface where "chat messages" are the inputs and outputs.
 from langchain.chat_models.anthropic import ChatAnthropic
 from langchain.chat_models.anyscale import ChatAnyscale
 from langchain.chat_models.azure_openai import AzureChatOpenAI
+from langchain.chat_models.bedrock import BedrockChat
 from langchain.chat_models.ernie import ErnieBotChat
 from langchain.chat_models.fake import FakeListChatModel
 from langchain.chat_models.google_palm import ChatGooglePalm
@@ -35,6 +36,7 @@ from langchain.chat_models.vertexai import ChatVertexAI
 __all__ = [
     "ChatOpenAI",
     "AzureChatOpenAI",
+    "BedrockChat",
    "FakeListChatModel",
     "PromptLayerChatOpenAI",
     "ChatAnthropic",
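Reviewer note: with this export in place, the notebook flow above can be exercised directly from the package root. A minimal sketch, assuming boto3 can resolve AWS credentials and the account has Bedrock access to the Anthropic model:

```python
# Minimal smoke test for the new export; assumes AWS credentials are
# configured for boto3 and the account can invoke anthropic.claude-v2.
from langchain.chat_models import BedrockChat
from langchain.schema import HumanMessage

chat = BedrockChat(model_id="anthropic.claude-v2", model_kwargs={"temperature": 0.1})
print(chat([HumanMessage(content="Translate this sentence from English to French. I love programming.")]))
```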
+ """ + + messages = messages.copy() # don't mutate the original list + + if not isinstance(messages[-1], AIMessage): + messages.append(AIMessage(content="")) + + text = "".join( + _convert_one_message_to_text(message, human_prompt, ai_prompt) + for message in messages + ) + + # trim off the trailing ' ' that might come from the "Assistant: " + return text.rstrip() class ChatAnthropic(BaseChatModel, _AnthropicCommon): @@ -55,52 +98,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): def lc_serializable(self) -> bool: return True - def _convert_one_message_to_text(self, message: BaseMessage) -> str: - if isinstance(message, ChatMessage): - message_text = f"\n\n{message.role.capitalize()}: {message.content}" - elif isinstance(message, HumanMessage): - message_text = f"{self.HUMAN_PROMPT} {message.content}" - elif isinstance(message, AIMessage): - message_text = f"{self.AI_PROMPT} {message.content}" - elif isinstance(message, SystemMessage): - message_text = f"{self.HUMAN_PROMPT} {message.content}" - else: - raise ValueError(f"Got unknown type {message}") - return message_text - - def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str: - """Format a list of strings into a single string with necessary newlines. - - Args: - messages (List[BaseMessage]): List of BaseMessage to combine. - - Returns: - str: Combined string with necessary newlines. - """ - return "".join( - self._convert_one_message_to_text(message) for message in messages - ) - def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str: """Format a list of messages into a full prompt for the Anthropic model - Args: messages (List[BaseMessage]): List of BaseMessage to combine. - Returns: str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags. 
""" - messages = messages.copy() # don't mutate the original list - - if not self.AI_PROMPT: - raise NameError("Please ensure the anthropic package is loaded") - - if not isinstance(messages[-1], AIMessage): - messages.append(AIMessage(content="")) - text = self._convert_messages_to_text(messages) - return ( - text.rstrip() - ) # trim off the trailing ' ' that might come from the "Assistant: " + prompt_params = {} + if self.HUMAN_PROMPT: + prompt_params["human_prompt"] = self.HUMAN_PROMPT + if self.AI_PROMPT: + prompt_params["ai_prompt"] = self.AI_PROMPT + return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params) def _stream( self, @@ -152,7 +162,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): for chunk in self._stream(messages, stop, run_manager, **kwargs): completion += chunk.text else: - prompt = self._convert_messages_to_prompt(messages) + prompt = self._convert_messages_to_prompt( + messages, + ) params: Dict[str, Any] = { "prompt": prompt, **self._default_params, @@ -177,7 +189,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): async for chunk in self._astream(messages, stop, run_manager, **kwargs): completion += chunk.text else: - prompt = self._convert_messages_to_prompt(messages) + prompt = self._convert_messages_to_prompt( + messages, + ) params: Dict[str, Any] = { "prompt": prompt, **self._default_params, diff --git a/libs/langchain/langchain/chat_models/bedrock.py b/libs/langchain/langchain/chat_models/bedrock.py new file mode 100644 index 0000000000..a539d6058b --- /dev/null +++ b/libs/langchain/langchain/chat_models/bedrock.py @@ -0,0 +1,98 @@ +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic +from langchain.chat_models.base import BaseChatModel +from langchain.llms.bedrock import BedrockBase +from langchain.pydantic_v1 import Extra +from langchain.schema.messages import AIMessage, BaseMessage +from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult + + +class ChatPromptAdapter: + """Adapter class to prepare the inputs from Langchain to prompt format + that Chat model expects. + """ + + @classmethod + def convert_messages_to_prompt( + cls, provider: str, messages: List[BaseMessage] + ) -> str: + if provider == "anthropic": + prompt = convert_messages_to_prompt_anthropic(messages=messages) + else: + raise NotImplementedError( + f"Provider {provider} model does not support chat." 
diff --git a/libs/langchain/langchain/chat_models/bedrock.py b/libs/langchain/langchain/chat_models/bedrock.py
new file mode 100644
index 0000000000..a539d6058b
--- /dev/null
+++ b/libs/langchain/langchain/chat_models/bedrock.py
@@ -0,0 +1,98 @@
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
+from langchain.chat_models.base import BaseChatModel
+from langchain.llms.bedrock import BedrockBase
+from langchain.pydantic_v1 import Extra
+from langchain.schema.messages import AIMessage, BaseMessage
+from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
+
+
+class ChatPromptAdapter:
+    """Adapter class to prepare the inputs from LangChain to a prompt format
+    that the chat model expects.
+    """
+
+    @classmethod
+    def convert_messages_to_prompt(
+        cls, provider: str, messages: List[BaseMessage]
+    ) -> str:
+        if provider == "anthropic":
+            prompt = convert_messages_to_prompt_anthropic(messages=messages)
+        else:
+            raise NotImplementedError(
+                f"Provider {provider} model does not support chat."
+            )
+        return prompt
+
+
+class BedrockChat(BaseChatModel, BedrockBase):
+    @property
+    def _llm_type(self) -> str:
+        """Return type of chat model."""
+        return "amazon_bedrock_chat"
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        raise NotImplementedError(
+            """Bedrock doesn't support stream requests at the moment."""
+        )
+
+    def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        raise NotImplementedError(
+            """Bedrock doesn't support async streaming requests at the moment."""
+        )
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        provider = self._get_provider()
+        prompt = ChatPromptAdapter.convert_messages_to_prompt(
+            provider=provider, messages=messages
+        )
+
+        params: Dict[str, Any] = {**kwargs}
+        if stop:
+            params["stop_sequences"] = stop
+
+        completion = self._prepare_input_and_invoke(
+            prompt=prompt, stop=stop, run_manager=run_manager, **params
+        )
+
+        message = AIMessage(content=completion)
+        return ChatResult(generations=[ChatGeneration(message=message)])
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        raise NotImplementedError(
+            """Bedrock doesn't support async requests at the moment."""
+        )
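Reviewer note: provider routing is driven by the `model_id` prefix. `BedrockBase._get_provider()` (added in the `llms/bedrock.py` diff below) returns the segment before the first `.`, and `ChatPromptAdapter` only knows how to build chat prompts for the `anthropic` provider so far. A small sketch that involves no AWS call:

```python
# Shows the provider -> prompt routing without touching AWS.
from langchain.chat_models.bedrock import ChatPromptAdapter
from langchain.schema.messages import HumanMessage

# "anthropic.claude-v2" -> provider "anthropic" (the part before the first ".")
prompt = ChatPromptAdapter.convert_messages_to_prompt(
    provider="anthropic", messages=[HumanMessage(content="Hello")]
)
print(repr(prompt))  # '\n\nHuman: Hello\n\nAssistant:'

# Any other provider, e.g. "amazon" from "amazon.titan-tg1-large",
# currently raises NotImplementedError.
```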
diff --git a/libs/langchain/langchain/llms/bedrock.py b/libs/langchain/langchain/llms/bedrock.py
index 75af53a6a5..1805e89281 100644
--- a/libs/langchain/langchain/llms/bedrock.py
+++ b/libs/langchain/langchain/llms/bedrock.py
@@ -1,10 +1,11 @@
 import json
+from abc import ABC
 from typing import Any, Dict, List, Mapping, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.pydantic_v1 import Extra, root_validator
+from langchain.pydantic_v1 import BaseModel, Extra, root_validator
 
 
 class LLMInputOutputAdapter:
@@ -47,33 +48,7 @@ class LLMInputOutputAdapter:
         return response_body.get("results")[0].get("outputText")
 
 
-class Bedrock(LLM):
-    """Bedrock models.
-
-    To authenticate, the AWS client uses the following methods to
-    automatically load credentials:
-    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
-
-    If a specific credential profile should be used, you must pass
-    the name of the profile from the ~/.aws/credentials file that is to be used.
-
-    Make sure the credentials / roles used have the required policies to
-    access the Bedrock service.
-    """
-
-    """
-    Example:
-        .. code-block:: python
-
-            from bedrock_langchain.bedrock_llm import BedrockLLM
-
-            llm = BedrockLLM(
-                credentials_profile_name="default",
-                model_id="amazon.titan-tg1-large"
-            )
-
-    """
-
+class BedrockBase(BaseModel, ABC):
     client: Any  #: :meta private:
 
     region_name: Optional[str] = None
@@ -99,11 +74,6 @@
     endpoint_url: Optional[str] = None
     """Needed if you don't want to default to us-east-1 endpoint"""
 
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that AWS credentials to and python package exists in environment."""
@@ -151,11 +121,77 @@
             **{"model_kwargs": _model_kwargs},
         }
 
+    def _get_provider(self) -> str:
+        return self.model_id.split(".")[0]
+
+    def _prepare_input_and_invoke(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        _model_kwargs = self.model_kwargs or {}
+
+        provider = self._get_provider()
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        body = json.dumps(input_body)
+        accept = "application/json"
+        contentType = "application/json"
+
+        try:
+            response = self.client.invoke_model(
+                body=body, modelId=self.model_id, accept=accept, contentType=contentType
+            )
+            text = LLMInputOutputAdapter.prepare_output(provider, response)
+
+        except Exception as e:
+            raise ValueError(f"Error raised by bedrock service: {e}")
+
+        if stop is not None:
+            text = enforce_stop_tokens(text, stop)
+
+        return text
+
+
+class Bedrock(LLM, BedrockBase):
+    """Bedrock models.
+
+    To authenticate, the AWS client uses the following methods to
+    automatically load credentials:
+    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+
+    If a specific credential profile should be used, you must pass
+    the name of the profile from the ~/.aws/credentials file that is to be used.
+
+    Make sure the credentials / roles used have the required policies to
+    access the Bedrock service.
+    """
+
+    """
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import Bedrock
+
+            llm = Bedrock(
+                credentials_profile_name="default",
+                model_id="amazon.titan-tg1-large"
+            )
+
+    """
+
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
         return "amazon_bedrock"
 
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
     def _call(
         self,
         prompt: str,
@@ -177,25 +213,7 @@
             response = se("Tell me a joke.")
         """
 
-        _model_kwargs = self.model_kwargs or {}
-
-        provider = self.model_id.split(".")[0]
-        params = {**_model_kwargs, **kwargs}
-        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
-        body = json.dumps(input_body)
-        accept = "application/json"
-        contentType = "application/json"
-
-        try:
-            response = self.client.invoke_model(
-                body=body, modelId=self.model_id, accept=accept, contentType=contentType
-            )
-            text = LLMInputOutputAdapter.prepare_output(provider, response)
-
-        except Exception as e:
-            raise ValueError(f"Error raised by bedrock service: {e}")
-
-        if stop is not None:
-            text = enforce_stop_tokens(text, stop)
+        text = self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
 
         return text
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py b/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py
index 0237b055d1..5e3848d382 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py
+++ b/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py
@@ -4,11 +4,11 @@ from typing import List
 import pytest
 
 from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models.anthropic import ChatAnthropic
-from langchain.schema import (
-    ChatGeneration,
-    LLMResult,
+from langchain.chat_models.anthropic import (
+    ChatAnthropic,
+    convert_messages_to_prompt_anthropic,
 )
+from langchain.schema import ChatGeneration, LLMResult
 from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
@@ -86,14 +86,12 @@ async def test_anthropic_async_streaming_callback() -> None:
 
 
 def test_formatting() -> None:
-    chat = ChatAnthropic()
-
-    chat_messages: List[BaseMessage] = [HumanMessage(content="Hello")]
-    result = chat._convert_messages_to_prompt(chat_messages)
+    messages: List[BaseMessage] = [HumanMessage(content="Hello")]
+    result = convert_messages_to_prompt_anthropic(messages)
     assert result == "\n\nHuman: Hello\n\nAssistant:"
 
-    chat_messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
-    result = chat._convert_messages_to_prompt(chat_messages)
+    messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
+    result = convert_messages_to_prompt_anthropic(messages)
     assert result == "\n\nHuman: Hello\n\nAssistant: Answer:"
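Reviewer note: after this refactor, `Bedrock` and `BedrockChat` share `BedrockBase._prepare_input_and_invoke()` for the actual `invoke_model` call; the completion LLM and the chat model differ only in how they build the prompt and wrap the output. A usage sketch, assuming a `default` AWS profile with access to the referenced Bedrock models:

```python
# Usage sketch for the refactored classes; credentials_profile_name and
# the model IDs are assumptions that depend on your AWS account setup.
from langchain.chat_models import BedrockChat
from langchain.llms import Bedrock
from langchain.schema import HumanMessage

# Completion-style LLM: prompt in, string out.
llm = Bedrock(credentials_profile_name="default", model_id="amazon.titan-tg1-large")
print(llm("Tell me a joke."))

# Chat model: messages in, AIMessage out, same underlying invocation path.
chat = BedrockChat(credentials_profile_name="default", model_id="anthropic.claude-v2")
print(chat([HumanMessage(content="Tell me a joke.")]))
```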