import re
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra

from langchain_community.chat_models.anthropic import (
    convert_messages_to_prompt_anthropic,
)
from langchain_community.chat_models.meta import convert_messages_to_prompt_llama
from langchain_community.llms.bedrock import BedrockBase
from langchain_community.utilities.anthropic import (
    get_num_tokens_anthropic,
    get_token_ids_anthropic,
)


def _convert_one_message_to_text_mistral(message: BaseMessage) -> str:
    """Convert a single message to a Mistral-style text segment."""
    if isinstance(message, ChatMessage):
        message_text = f"\n\n{message.role.capitalize()}: {message.content}"
    elif isinstance(message, HumanMessage):
        message_text = f"[INST] {message.content} [/INST]"
    elif isinstance(message, AIMessage):
        message_text = f"{message.content}"
    elif isinstance(message, SystemMessage):
        message_text = f"<<SYS>> {message.content} <</SYS>>"
    else:
        raise ValueError(f"Got unknown type {message}")
    return message_text


def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str:
    """Convert a list of messages to a prompt for Mistral."""
    return "\n".join(
        [_convert_one_message_to_text_mistral(message) for message in messages]
    )
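

# A sketch added for illustration (not in the original module): what the
# Mistral conversion above produces for a short exchange. The message
# contents and helper name are made up for the example.
def _example_mistral_prompt() -> str:
    messages: List[BaseMessage] = [
        SystemMessage(content="You are terse."),
        HumanMessage(content="Name a prime number."),
        AIMessage(content="7"),
    ]
    # Produces:
    #   <<SYS>> You are terse. <</SYS>>
    #   [INST] Name a prime number. [/INST]
    #   7
    return convert_messages_to_prompt_mistral(messages)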


def _format_image(image_url: str) -> Dict:
    """Format an image data URL for the Anthropic API.

    Converts a URL of the form ``data:image/jpeg;base64,{b64_string}`` into

        {
            "type": "base64",
            "media_type": "image/jpeg",
            "data": "/9j/4AAQSkZJRg...",
        }

    and raises a ValueError if the URL is not a base64 data URL.
    """
    regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
    match = re.match(regex, image_url)
    if match is None:
        raise ValueError(
            "Anthropic only supports base64-encoded images currently."
            " Example: data:image/png;base64,'/9j/4AAQSk'..."
        )
    return {
        "type": "base64",
        "media_type": match.group("media_type"),
        "data": match.group("data"),
    }
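

# An illustrative sketch (not in the original module) of the data-URL
# contract enforced by _format_image; the base64 payload is a placeholder.
def _example_format_image() -> Dict:
    url = "data:image/png;base64,iVBORw0KGgo="
    # Returns:
    #   {"type": "base64", "media_type": "image/png", "data": "iVBORw0KGgo="}
    return _format_image(url)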


def _format_anthropic_messages(
    messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict]]:
    """Format messages for the Anthropic messages API.

    Returns the system prompt (if any) plus a list of dicts of the shape

        [
            {
                "role": _message_type_lookups[m.type],
                "content": [{"type": "text", "text": m.content}],
            }
            for m in messages
        ]
    """
    system: Optional[str] = None
    formatted_messages: List[Dict] = []
    for i, message in enumerate(messages):
        if message.type == "system":
            if i != 0:
                raise ValueError("System message must be at beginning of message list.")
            if not isinstance(message.content, str):
                raise ValueError(
                    "System message must be a string, "
                    f"instead was: {type(message.content)}"
                )
            system = message.content
            continue

        role = _message_type_lookups[message.type]
        content: Union[str, List[Dict]]

        if not isinstance(message.content, str):
            # parse as a list of content-block dicts
            assert isinstance(
                message.content, list
            ), "Anthropic message content must be str or list of dicts"

            # populate content
            content = []
            for item in message.content:
                if isinstance(item, str):
                    content.append(
                        {
                            "type": "text",
                            "text": item,
                        }
                    )
                elif isinstance(item, dict):
                    if "type" not in item:
                        raise ValueError("Dict content item must have a type key")
                    if item["type"] == "image_url":
                        # convert image_url blocks (base64 data URLs) to
                        # Anthropic's "image" source format
                        source = _format_image(item["image_url"]["url"])
                        content.append(
                            {
                                "type": "image",
                                "source": source,
                            }
                        )
                    else:
                        content.append(item)
                else:
                    raise ValueError(
                        f"Content items must be str or dict, instead was: {type(item)}"
                    )
        else:
            content = message.content

        formatted_messages.append(
            {
                "role": role,
                "content": content,
            }
        )
    return system, formatted_messages
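

# An illustrative sketch (not in the original module) of the system/message
# split performed by _format_anthropic_messages; contents are made up.
def _example_format_anthropic_messages() -> Tuple[Optional[str], List[Dict]]:
    messages: List[BaseMessage] = [
        SystemMessage(content="Answer in French."),
        HumanMessage(content="Hello"),
        AIMessage(content="Bonjour"),
    ]
    # Returns:
    #   ("Answer in French.",
    #    [{"role": "user", "content": "Hello"},
    #     {"role": "assistant", "content": "Bonjour"}])
    return _format_anthropic_messages(messages)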


class ChatPromptAdapter:
    """Adapter class to prepare the inputs from LangChain to a prompt format
    that the underlying chat model expects.
    """

    @classmethod
    def convert_messages_to_prompt(
        cls, provider: str, messages: List[BaseMessage]
    ) -> str:
        if provider == "anthropic":
            prompt = convert_messages_to_prompt_anthropic(messages=messages)
        elif provider == "meta":
            prompt = convert_messages_to_prompt_llama(messages=messages)
        elif provider == "mistral":
            prompt = convert_messages_to_prompt_mistral(messages=messages)
        elif provider == "amazon":
            prompt = convert_messages_to_prompt_anthropic(
                messages=messages,
                human_prompt="\n\nUser:",
                ai_prompt="\n\nBot:",
            )
        else:
            raise NotImplementedError(
                f"Provider {provider} model does not support chat."
            )
        return prompt

    @classmethod
    def format_messages(
        cls, provider: str, messages: List[BaseMessage]
    ) -> Tuple[Optional[str], List[Dict]]:
        if provider == "anthropic":
            return _format_anthropic_messages(messages)

        raise NotImplementedError(
            f"Provider {provider} not supported for format_messages"
        )
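

# An illustrative sketch (not in the original module) of how the adapter is
# dispatched on provider: "meta", "mistral", and "amazon" flatten to a single
# prompt string, while "anthropic" keeps structured messages plus an optional
# system prompt.
def _example_chat_prompt_adapter() -> Tuple[str, Optional[str], List[Dict]]:
    messages: List[BaseMessage] = [HumanMessage(content="Hi")]
    llama_prompt = ChatPromptAdapter.convert_messages_to_prompt(
        provider="meta", messages=messages
    )
    system, formatted = ChatPromptAdapter.format_messages("anthropic", messages)
    return llama_prompt, system, formatted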


# Map LangChain message types to Anthropic API roles.
_message_type_lookups = {"human": "user", "ai": "assistant"}


class BedrockChat(BaseChatModel, BedrockBase):
    """A chat model that uses the Bedrock API."""

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "amazon_bedrock_chat"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by LangChain."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "bedrock"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        attributes: Dict[str, Any] = {}

        if self.region_name:
            attributes["region_name"] = self.region_name

        return attributes

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        provider = self._get_provider()
        prompt, system, formatted_messages = None, None, None

        # Anthropic models take a separate system prompt plus structured
        # messages; all other providers get a single flattened prompt string.
        if provider == "anthropic":
            system, formatted_messages = ChatPromptAdapter.format_messages(
                provider, messages
            )
        else:
            prompt = ChatPromptAdapter.convert_messages_to_prompt(
                provider=provider, messages=messages
            )

        for chunk in self._prepare_input_and_invoke_stream(
            prompt=prompt,
            system=system,
            messages=formatted_messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        ):
            delta = chunk.text
            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        completion = ""
        llm_output: Dict[str, Any] = {"model_id": self.model_id}

        if self.streaming:
            for chunk in self._stream(messages, stop, run_manager, **kwargs):
                completion += chunk.text
        else:
            provider = self._get_provider()
            prompt, system, formatted_messages = None, None, None
            params: Dict[str, Any] = {**kwargs}

            if provider == "anthropic":
                system, formatted_messages = ChatPromptAdapter.format_messages(
                    provider, messages
                )
            else:
                prompt = ChatPromptAdapter.convert_messages_to_prompt(
                    provider=provider, messages=messages
                )

            if stop:
                params["stop_sequences"] = stop

            completion, usage_info = self._prepare_input_and_invoke(
                prompt=prompt,
                stop=stop,
                run_manager=run_manager,
                system=system,
                messages=formatted_messages,
                **params,
            )

            llm_output["usage"] = usage_info

        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=completion))],
            llm_output=llm_output,
        )

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        final_usage: Dict[str, int] = defaultdict(int)
        final_output = {}
        for output in llm_outputs:
            output = output or {}
            usage = output.get("usage", {})
            for token_type, token_count in usage.items():
                final_usage[token_type] += token_count
            final_output.update(output)
        final_output["usage"] = final_usage
        return final_output
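
    # An illustrative sketch of the merge above, with made-up usage dicts
    # (the token key names are assumptions, not part of this module):
    #
    #     outputs = [
    #         {"model_id": "m", "usage": {"input_tokens": 3, "output_tokens": 5}},
    #         {"model_id": "m", "usage": {"input_tokens": 2, "output_tokens": 4}},
    #     ]
    #     chat._combine_llm_outputs(outputs)["usage"]
    #     # -> {"input_tokens": 5, "output_tokens": 9}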

    def get_num_tokens(self, text: str) -> int:
        if self._model_is_anthropic:
            return get_num_tokens_anthropic(text)
        else:
            return super().get_num_tokens(text)

    def get_token_ids(self, text: str) -> List[int]:
        if self._model_is_anthropic:
            return get_token_ids_anthropic(text)
        else:
            return super().get_token_ids(text)
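

# A minimal usage sketch, not part of the original module. It assumes valid
# AWS credentials and that the named model is enabled for your account;
# invoke() and stream() come from the BaseChatModel/Runnable interface.
def _example_bedrock_chat() -> None:
    chat = BedrockChat(model_id="anthropic.claude-v2")
    # Single-shot generation via _generate:
    reply = chat.invoke([HumanMessage(content="Say hi.")])
    print(reply.content)
    # Streaming surfaces AIMessageChunk deltas from _stream as they arrive:
    for chunk in chat.stream([HumanMessage(content="Count to three.")]):
        print(chunk.content, end="")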