import re
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra

from langchain_community.chat_models.anthropic import (
    convert_messages_to_prompt_anthropic,
)
from langchain_community.chat_models.meta import convert_messages_to_prompt_llama
from langchain_community.llms.bedrock import BedrockBase
from langchain_community.utilities.anthropic import (
    get_num_tokens_anthropic,
    get_token_ids_anthropic,
)


def _format_image(image_url: str) -> Dict:
    """Convert a base64 ``data:`` image URL into an Anthropic image source dict.

    The input is expected to look like ``data:image/jpeg;base64,{b64_string}``
    and is turned into::

        {
            "type": "base64",
            "media_type": "image/jpeg",
            "data": "/9j/4AAQSkZJRg...",
        }

    Args:
        image_url: A ``data:`` URL carrying a base64-encoded image.

    Returns:
        A dict with the keys ``type``, ``media_type`` and ``data``.

    Raises:
        ValueError: If ``image_url`` is not a base64 ``data:image/...`` URL.
    """
    # The group names must match the match.group(...) lookups below.
    regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
    match = re.match(regex, image_url)
    if match is None:
        raise ValueError(
            "Anthropic only supports base64-encoded images currently."
            " Example: data:image/png;base64,'/9j/4AAQSk'..."
        )
    return {
        "type": "base64",
        "media_type": match.group("media_type"),
        "data": match.group("data"),
    }


def _format_anthropic_messages(
    messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict]]:
    """Split messages into a system prompt and a list of role/content dicts.

    Args:
        messages: The conversation. A system message, if present, must be the
            first element and must have string content.

    Returns:
        A ``(system, formatted_messages)`` tuple. ``system`` is the content of
        the leading system message, or ``None`` when there is none. Each entry
        of ``formatted_messages`` is ``{"role": ..., "content": ...}`` where
        the content is either the original string or a list of content dicts.

    Raises:
        ValueError: If a system message is not first, if its content is not a
            string, if a dict content item has no ``"type"`` key, if a content
            item is neither ``str`` nor ``dict``, or if an image URL is not a
            base64 ``data:`` URL.
        KeyError: If a non-system message has a type that is missing from
            ``_message_type_lookups``.
        AssertionError: If a message's content is neither ``str`` nor ``list``.
    """
    system: Optional[str] = None
    formatted_messages: List[Dict] = []
    for i, message in enumerate(messages):
        if message.type == "system":
            if i != 0:
                raise ValueError("System message must be at beginning of message list.")
            if not isinstance(message.content, str):
                raise ValueError(
                    "System message must be a string, "
                    f"instead was: {type(message.content)}"
                )
            # The system prompt is returned separately, not as a chat turn.
            system = message.content
            continue

        role = _message_type_lookups[message.type]
        content: Union[str, List[Dict]]

        if not isinstance(message.content, str):
            # parse as dict
            assert isinstance(
                message.content, list
            ), "Anthropic message content must be str or list of dicts"

            # populate content
            content = []
            for item in message.content:
                if isinstance(item, str):
                    # Bare strings become text content items.
                    content.append(
                        {
                            "type": "text",
                            "text": item,
                        }
                    )
                elif isinstance(item, dict):
                    if "type" not in item:
                        raise ValueError("Dict content item must have a type key")
                    if item["type"] == "image_url":
                        # convert format
                        source = _format_image(item["image_url"]["url"])
                        content.append(
                            {
                                "type": "image",
                                "source": source,
                            }
                        )
                    else:
                        # Any other typed dict is passed through unchanged.
                        content.append(item)
                else:
                    raise ValueError(
                        f"Content items must be str or dict, instead was: {type(item)}"
                    )
        else:
            content = message.content

        formatted_messages.append(
            {
                "role": role,
                "content": content,
            }
        )
    return system, formatted_messages


class ChatPromptAdapter:
    """Adapter that turns LangChain messages into the input a chat model expects.

    ``convert_messages_to_prompt`` produces a single prompt string, while
    ``format_messages`` produces a system prompt plus structured messages.
    """

    @classmethod
    def convert_messages_to_prompt(
        cls, provider: str, messages: List[BaseMessage]
    ) -> str:
        """Render ``messages`` as one prompt string for ``provider``.

        Args:
            provider: One of ``"anthropic"``, ``"meta"`` or ``"amazon"``.
            messages: The conversation to render.

        Returns:
            The prompt produced by the provider-specific converter.

        Raises:
            NotImplementedError: If ``provider`` is not one of the above.
        """
        if provider == "anthropic":
            prompt = convert_messages_to_prompt_anthropic(messages=messages)
        elif provider == "meta":
            prompt = convert_messages_to_prompt_llama(messages=messages)
        elif provider == "amazon":
            # Reuses the Anthropic converter with different turn markers.
            prompt = convert_messages_to_prompt_anthropic(
                messages=messages,
                human_prompt="\n\nUser:",
                ai_prompt="\n\nBot:",
            )
        else:
            raise NotImplementedError(
                f"Provider {provider} model does not support chat."
            )
        return prompt

    @classmethod
    def format_messages(
        cls, provider: str, messages: List[BaseMessage]
    ) -> Tuple[Optional[str], List[Dict]]:
        """Return ``(system, formatted_messages)`` for ``provider``.

        Only ``"anthropic"`` is handled; it delegates to
        ``_format_anthropic_messages``.

        Raises:
            NotImplementedError: If ``provider`` is not ``"anthropic"``.
        """
        if provider == "anthropic":
            return _format_anthropic_messages(messages)

        raise NotImplementedError(
            f"Provider {provider} not supported for format_messages"
        )


# Maps a LangChain message type to the role name used in formatted messages.
_message_type_lookups = {"human": "user", "ai": "assistant"}


class BedrockChat(BaseChatModel, BedrockBase):
    """A chat model that uses the Bedrock API."""

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "amazon_bedrock_chat"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by Langchain."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "bedrock"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        """Return extra attributes; ``region_name`` is included only if set."""
        attributes: Dict[str, Any] = {}

        if self.region_name:
            attributes["region_name"] = self.region_name

        return attributes

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield the model response as a stream of chat generation chunks.

        For the ``"anthropic"`` provider the messages are sent as a system
        prompt plus structured messages and ``prompt`` is ``None``; for any
        other provider a single prompt string is sent and ``system`` and
        ``messages`` are ``None``.

        Args:
            messages: The conversation to send.
            stop: Optional stop sequences, forwarded unchanged.
            run_manager: Optional callback manager, forwarded unchanged.
            **kwargs: Extra arguments forwarded to
                ``_prepare_input_and_invoke_stream``.

        Yields:
            One ``ChatGenerationChunk`` per streamed chunk, wrapping the
            chunk's text in an ``AIMessageChunk``.
        """
        # NOTE(review): _get_provider and _prepare_input_and_invoke_stream are
        # not defined in this file; presumably inherited from BedrockBase.
        provider = self._get_provider()
        system = None
        formatted_messages = None
        if provider == "anthropic":
            prompt = None
            system, formatted_messages = ChatPromptAdapter.format_messages(
                provider, messages
            )
        else:
            prompt = ChatPromptAdapter.convert_messages_to_prompt(
                provider=provider, messages=messages
            )

        for chunk in self._prepare_input_and_invoke_stream(
            prompt=prompt,
            system=system,
            messages=formatted_messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        ):
            delta = chunk.text
            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate one chat completion for ``messages``.

        When ``self.streaming`` is truthy the completion is assembled from the
        chunks produced by ``_stream``. Otherwise the model is invoked once
        through ``_prepare_input_and_invoke``.

        Args:
            messages: The conversation to send.
            stop: Optional stop sequences.
            run_manager: Optional callback manager, forwarded unchanged.
            **kwargs: Extra arguments forwarded to the underlying call.

        Returns:
            A ``ChatResult`` holding a single ``ChatGeneration`` whose message
            is an ``AIMessage`` with the completion text.
        """
        completion = ""

        if self.streaming:
            # Join the streamed pieces instead of repeated string concatenation.
            completion = "".join(
                chunk.text
                for chunk in self._stream(messages, stop, run_manager, **kwargs)
            )
        else:
            provider = self._get_provider()
            system = None
            formatted_messages = None
            params: Dict[str, Any] = {**kwargs}
            if provider == "anthropic":
                prompt = None
                system, formatted_messages = ChatPromptAdapter.format_messages(
                    provider, messages
                )
            else:
                prompt = ChatPromptAdapter.convert_messages_to_prompt(
                    provider=provider, messages=messages
                )

            # Stop sequences are passed both as ``stop`` and, when non-empty,
            # as the ``stop_sequences`` entry of the extra parameters.
            if stop:
                params["stop_sequences"] = stop

            completion = self._prepare_input_and_invoke(
                prompt=prompt,
                stop=stop,
                run_manager=run_manager,
                system=system,
                messages=formatted_messages,
                **params,
            )

        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=completion))]
        )

    def get_num_tokens(self, text: str) -> int:
        """Return the token count of ``text``.

        Uses ``get_num_tokens_anthropic`` when ``self._model_is_anthropic`` is
        truthy and the parent implementation otherwise.
        """
        if self._model_is_anthropic:
            return get_num_tokens_anthropic(text)
        else:
            return super().get_num_tokens(text)

    def get_token_ids(self, text: str) -> List[int]:
        """Return the token ids of ``text``.

        Uses ``get_token_ids_anthropic`` when ``self._model_is_anthropic`` is
        truthy and the parent implementation otherwise.
        """
        if self._model_is_anthropic:
            return get_token_ids_anthropic(text)
        else:
            return super().get_token_ids(text)