import re
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra

from langchain_community.chat_models.anthropic import (
    convert_messages_to_prompt_anthropic,
)
from langchain_community.chat_models.meta import convert_messages_to_prompt_llama
from langchain_community.llms.bedrock import BedrockBase
from langchain_community.utilities.anthropic import (
    get_num_tokens_anthropic,
    get_token_ids_anthropic,
)


def _convert_one_message_to_text_mistral(message: BaseMessage) -> str:
    if isinstance(message, ChatMessage):
        message_text = f"\n\n{message.role.capitalize()}: {message.content}"
    elif isinstance(message, HumanMessage):
        message_text = f"[INST] {message.content} [/INST]"
    elif isinstance(message, AIMessage):
        message_text = f"{message.content}"
    elif isinstance(message, SystemMessage):
        message_text = f"<<SYS>> {message.content} <</SYS>>"
    else:
        raise ValueError(f"Got unknown type {message}")
    return message_text


def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str:
    """Convert a list of messages to a prompt for mistral."""
    return "\n".join(
        [_convert_one_message_to_text_mistral(message) for message in messages]
    )


def _format_image(image_url: str) -> Dict:
    """
    Formats an image of format data:image/jpeg;base64,{b64_string}
    to a dict for anthropic api

    {
      "type": "base64",
      "media_type": "image/jpeg",
      "data": "/9j/4AAQSkZJRg...",
    }

    And throws an error if it's not a b64 image
    """
    regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
    match = re.match(regex, image_url)
    if match is None:
        raise ValueError(
            "Anthropic only supports base64-encoded images currently."
            " Example: data:image/png;base64,'/9j/4AAQSk'..."
        )
    return {
        "type": "base64",
        "media_type": match.group("media_type"),
        "data": match.group("data"),
    }
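
# Illustrative sketches of the two helpers above (the base64 payload and the
# message contents are hypothetical placeholders, not real data):
#
#     >>> _format_image("data:image/png;base64,iVBORw0KGgo=")
#     {'type': 'base64', 'media_type': 'image/png', 'data': 'iVBORw0KGgo='}
#
#     >>> convert_messages_to_prompt_mistral(
#     ...     [SystemMessage(content="Be brief."), HumanMessage(content="Hi")]
#     ... )
#     '<<SYS>> Be brief. <</SYS>>\n[INST] Hi [/INST]'
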

_message_type_lookups = {"human": "user", "ai": "assistant"}


def _format_anthropic_messages(
    messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict]]:
    """Format messages for anthropic.

    [
        {
            "role": _message_type_lookups[m.type],
            "content": [_AnthropicMessageContent(text=m.content).dict()],
        }
        for m in messages
    ]
    """
    system: Optional[str] = None
    formatted_messages: List[Dict] = []
    for i, message in enumerate(messages):
        if message.type == "system":
            if i != 0:
                raise ValueError(
                    "System message must be at beginning of message list."
                )
            if not isinstance(message.content, str):
                raise ValueError(
                    "System message must be a string, "
                    f"instead was: {type(message.content)}"
                )
            system = message.content
            continue

        role = _message_type_lookups[message.type]
        content: Union[str, List[Dict]]

        if not isinstance(message.content, str):
            # parse as dict
            assert isinstance(
                message.content, list
            ), "Anthropic message content must be str or list of dicts"

            # populate content
            content = []
            for item in message.content:
                if isinstance(item, str):
                    content.append(
                        {
                            "type": "text",
                            "text": item,
                        }
                    )
                elif isinstance(item, dict):
                    if "type" not in item:
                        raise ValueError("Dict content item must have a type key")
                    if item["type"] == "image_url":
                        # convert format
                        source = _format_image(item["image_url"]["url"])
                        content.append(
                            {
                                "type": "image",
                                "source": source,
                            }
                        )
                    else:
                        content.append(item)
                else:
                    raise ValueError(
                        f"Content items must be str or dict, instead was: {type(item)}"
                    )
        else:
            content = message.content

        formatted_messages.append(
            {
                "role": role,
                "content": content,
            }
        )
    return system, formatted_messages


class ChatPromptAdapter:
    """Adapter class to prepare the inputs from Langchain to a prompt format
    that the Chat model expects.
    """

    @classmethod
    def convert_messages_to_prompt(
        cls, provider: str, messages: List[BaseMessage]
    ) -> str:
        if provider == "anthropic":
            prompt = convert_messages_to_prompt_anthropic(messages=messages)
        elif provider == "meta":
            prompt = convert_messages_to_prompt_llama(messages=messages)
        elif provider == "mistral":
            prompt = convert_messages_to_prompt_mistral(messages=messages)
        elif provider == "amazon":
            prompt = convert_messages_to_prompt_anthropic(
                messages=messages,
                human_prompt="\n\nUser:",
                ai_prompt="\n\nBot:",
            )
        else:
            raise NotImplementedError(
                f"Provider {provider} model does not support chat."
            )
        return prompt

    @classmethod
    def format_messages(
        cls, provider: str, messages: List[BaseMessage]
    ) -> Tuple[Optional[str], List[Dict]]:
        if provider == "anthropic":
            return _format_anthropic_messages(messages)

        raise NotImplementedError(
            f"Provider {provider} not supported for format_messages"
        )
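
# Illustrative sketch of the adapter above, assuming an Anthropic model.
# format_messages splits the leading system message out of the list and maps
# the remaining roles onto Anthropic's messages schema:
#
#     >>> ChatPromptAdapter.format_messages(
#     ...     "anthropic",
#     ...     [SystemMessage(content="Be brief."), HumanMessage(content="Hi")],
#     ... )
#     ('Be brief.', [{'role': 'user', 'content': 'Hi'}])
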

class BedrockChat(BaseChatModel, BedrockBase):
    """A chat model that uses the Bedrock API."""

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "amazon_bedrock_chat"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by Langchain."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "bedrock"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        attributes: Dict[str, Any] = {}

        if self.region_name:
            attributes["region_name"] = self.region_name

        return attributes

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        provider = self._get_provider()
        prompt, system, formatted_messages = None, None, None

        if provider == "anthropic":
            system, formatted_messages = ChatPromptAdapter.format_messages(
                provider, messages
            )
        else:
            prompt = ChatPromptAdapter.convert_messages_to_prompt(
                provider=provider, messages=messages
            )

        for chunk in self._prepare_input_and_invoke_stream(
            prompt=prompt,
            system=system,
            messages=formatted_messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        ):
            delta = chunk.text
            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        completion = ""
        llm_output: Dict[str, Any] = {"model_id": self.model_id}

        if self.streaming:
            for chunk in self._stream(messages, stop, run_manager, **kwargs):
                completion += chunk.text
        else:
            provider = self._get_provider()
            prompt, system, formatted_messages = None, None, None
            params: Dict[str, Any] = {**kwargs}

            if provider == "anthropic":
                system, formatted_messages = ChatPromptAdapter.format_messages(
                    provider, messages
                )
            else:
                prompt = ChatPromptAdapter.convert_messages_to_prompt(
                    provider=provider, messages=messages
                )

            if stop:
                params["stop_sequences"] = stop

            completion, usage_info = self._prepare_input_and_invoke(
                prompt=prompt,
                stop=stop,
                run_manager=run_manager,
                system=system,
                messages=formatted_messages,
                **params,
            )

            llm_output["usage"] = usage_info

        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=completion))],
            llm_output=llm_output,
        )

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        final_usage: Dict[str, int] = defaultdict(int)
        final_output = {}
        for output in llm_outputs:
            output = output or {}
            usage = output.get("usage", {})
            for token_type, token_count in usage.items():
                final_usage[token_type] += token_count
            final_output.update(output)
        final_output["usage"] = final_usage
        return final_output

    def get_num_tokens(self, text: str) -> int:
        if self._model_is_anthropic:
            return get_num_tokens_anthropic(text)
        else:
            return super().get_num_tokens(text)

    def get_token_ids(self, text: str) -> List[int]:
        if self._model_is_anthropic:
            return get_token_ids_anthropic(text)
        else:
            return super().get_token_ids(text)
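
# Usage sketch. This assumes valid AWS credentials and access to the named
# Bedrock model; the model_id, region, and model_kwargs below are
# illustrative, not prescriptive:
#
#     >>> chat = BedrockChat(
#     ...     model_id="anthropic.claude-v2",
#     ...     region_name="us-east-1",
#     ...     model_kwargs={"temperature": 0.1},
#     ... )
#     >>> chat.invoke([HumanMessage(content="Hello!")])
#     AIMessage(content="...")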