mirror of https://github.com/sigoden/aichat
feat: support bedrock client (#450)
parent
615bab215b
commit
1f2b626703
@ -0,0 +1,443 @@
|
||||
use super::claude::claude_build_body;
|
||||
use super::{
|
||||
catch_error, generate_prompt, BedrockClient, Client, ExtraConfig, Model, ModelConfig,
|
||||
PromptFormat, PromptType, ReplyHandler, SendData, LLAMA2_PROMPT_FORMAT, LLAMA3_PROMPT_FORMAT,
|
||||
};
|
||||
|
||||
use crate::utils::PromptKind;
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use async_trait::async_trait;
|
||||
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
|
||||
use aws_smithy_eventstream::smithy::parse_response_headers;
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
use bytes::BytesMut;
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures_util::StreamExt;
|
||||
use hmac::{Hmac, Mac};
|
||||
use indexmap::IndexMap;
|
||||
use reqwest::{
|
||||
header::{HeaderMap, HeaderName, HeaderValue},
|
||||
Client as ReqwestClient, Method, RequestBuilder,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::str::FromStr;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct BedrockConfig {
|
||||
pub name: Option<String>,
|
||||
pub access_key_id: Option<String>,
|
||||
pub secret_access_key: Option<String>,
|
||||
pub region: Option<String>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelConfig>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Client for BedrockClient {
|
||||
client_common_fns!();
|
||||
|
||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
||||
let model_category = ModelCategory::from_str(&self.model.name)?;
|
||||
let builder = self.request_builder(client, data, &model_category)?;
|
||||
send_message(builder, &model_category).await
|
||||
}
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut ReplyHandler,
|
||||
data: SendData,
|
||||
) -> Result<()> {
|
||||
let model_category = ModelCategory::from_str(&self.model.name)?;
|
||||
let builder = self.request_builder(client, data, &model_category)?;
|
||||
send_message_streaming(builder, handler, &model_category).await
|
||||
}
|
||||
}
|
||||
|
||||
impl BedrockClient {
|
||||
config_get_fn!(access_key_id, get_access_key_id);
|
||||
config_get_fn!(secret_access_key, get_secret_access_key);
|
||||
config_get_fn!(region, get_region);
|
||||
|
||||
pub const PROMPTS: [PromptType<'static>; 3] = [
|
||||
(
|
||||
"access_key_id",
|
||||
"AWS Access Key ID",
|
||||
true,
|
||||
PromptKind::String,
|
||||
),
|
||||
(
|
||||
"secret_access_key",
|
||||
"AWS Secret Access Key",
|
||||
true,
|
||||
PromptKind::String,
|
||||
),
|
||||
("region", "AWS Region", true, PromptKind::String),
|
||||
];
|
||||
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: SendData,
|
||||
model_category: &ModelCategory,
|
||||
) -> Result<RequestBuilder> {
|
||||
let access_key_id = self.get_access_key_id()?;
|
||||
let secret_access_key = self.get_secret_access_key()?;
|
||||
let region = self.get_region()?;
|
||||
|
||||
let model_name = &self.model.name;
|
||||
let uri = if data.stream {
|
||||
format!("/model/{model_name}/invoke-with-response-stream")
|
||||
} else {
|
||||
format!("/model/{model_name}/invoke")
|
||||
};
|
||||
let host = format!("bedrock-runtime.{region}.amazonaws.com");
|
||||
|
||||
let headers = IndexMap::new();
|
||||
|
||||
let mut body = build_body(data, &self.model, model_category)?;
|
||||
self.model.merge_extra_fields(&mut body);
|
||||
|
||||
let builder = aws_fetch(
|
||||
client,
|
||||
&AwsCredentials {
|
||||
access_key_id,
|
||||
secret_access_key,
|
||||
region,
|
||||
},
|
||||
AwsRequest {
|
||||
method: Method::POST,
|
||||
host,
|
||||
service: "bedrock".into(),
|
||||
uri,
|
||||
querystring: "".into(),
|
||||
headers,
|
||||
body: body.to_string(),
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(builder)
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_message(builder: RequestBuilder, model_category: &ModelCategory) -> Result<String> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
|
||||
if status != 200 {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
|
||||
let output = match model_category {
|
||||
ModelCategory::Anthropic => data["content"][0]["text"].as_str(),
|
||||
ModelCategory::MetaLlama2 | ModelCategory::MetaLlama3 => data["generation"].as_str(),
|
||||
ModelCategory::Mistral => data["outputs"][0]["text"].as_str(),
|
||||
};
|
||||
|
||||
let output = output.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
|
||||
|
||||
Ok(output.to_string())
|
||||
}
|
||||
|
||||
async fn send_message_streaming(
|
||||
builder: RequestBuilder,
|
||||
handler: &mut ReplyHandler,
|
||||
model_category: &ModelCategory,
|
||||
) -> Result<()> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
if status != 200 {
|
||||
let data: Value = res.json().await?;
|
||||
catch_error(&data, status.as_u16())?;
|
||||
bail!("Invalid response data: {data}");
|
||||
}
|
||||
let mut stream = res.bytes_stream();
|
||||
let mut buffer = BytesMut::new();
|
||||
let mut decoder = MessageFrameDecoder::new();
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk?;
|
||||
buffer.extend_from_slice(&chunk);
|
||||
while let DecodedFrame::Complete(message) = decoder.decode_frame(&mut buffer)? {
|
||||
let response_headers = parse_response_headers(&message)?;
|
||||
let message_type = response_headers.message_type.as_str();
|
||||
let smithy_type = response_headers.smithy_type.as_str();
|
||||
match (message_type, smithy_type) {
|
||||
("event", "chunk") => {
|
||||
let data: Value = decode_chunk(message.payload()).ok_or_else(|| {
|
||||
anyhow!("Invalid chunk data: {}", hex_encode(message.payload()))
|
||||
})?;
|
||||
debug!("bedrock chunk: {data}");
|
||||
match model_category {
|
||||
ModelCategory::Anthropic => {
|
||||
if let Some(typ) = data["type"].as_str() {
|
||||
if typ == "content_block_delta" {
|
||||
if let Some(text) = data["delta"]["text"].as_str() {
|
||||
handler.text(text)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ModelCategory::MetaLlama2 | ModelCategory::MetaLlama3 => {
|
||||
if let Some(text) = data["generation"].as_str() {
|
||||
handler.text(text)?;
|
||||
}
|
||||
}
|
||||
ModelCategory::Mistral => {
|
||||
if let Some(text) = data["outputs"][0]["text"].as_str() {
|
||||
handler.text(text)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
("exception", _) => {
|
||||
let payload = STANDARD.decode(message.payload())?;
|
||||
let data = String::from_utf8_lossy(&payload);
|
||||
|
||||
bail!("Invalid response data: {data} (smithy_type: {smithy_type})")
|
||||
}
|
||||
_ => {
|
||||
bail!("Unrecognized message, message_type: {message_type}, smithy_type: {smithy_type}",);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_body(data: SendData, model: &Model, model_category: &ModelCategory) -> Result<Value> {
|
||||
match model_category {
|
||||
ModelCategory::Anthropic => {
|
||||
let mut body = claude_build_body(data, model)?;
|
||||
if let Some(body_obj) = body.as_object_mut() {
|
||||
body_obj.remove("model");
|
||||
}
|
||||
body["anthropic_version"] = "bedrock-2023-05-31".into();
|
||||
Ok(body)
|
||||
}
|
||||
ModelCategory::MetaLlama2 => meta_llama_build_body(data, model, LLAMA2_PROMPT_FORMAT),
|
||||
ModelCategory::MetaLlama3 => meta_llama_build_body(data, model, LLAMA3_PROMPT_FORMAT),
|
||||
ModelCategory::Mistral => mistral_build_body(data, model),
|
||||
}
|
||||
}
|
||||
|
||||
fn meta_llama_build_body(data: SendData, model: &Model, pt: PromptFormat) -> Result<Value> {
|
||||
let SendData {
|
||||
messages,
|
||||
temperature,
|
||||
top_p,
|
||||
stream: _,
|
||||
} = data;
|
||||
let prompt = generate_prompt(&messages, pt)?;
|
||||
let mut body = json!({ "prompt": prompt });
|
||||
|
||||
if let Some(v) = model.max_output_tokens {
|
||||
body["max_gen_len"] = v.into();
|
||||
}
|
||||
if let Some(v) = temperature {
|
||||
body["temperature"] = v.into();
|
||||
}
|
||||
if let Some(v) = top_p {
|
||||
body["top_p"] = v.into();
|
||||
}
|
||||
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
fn mistral_build_body(data: SendData, model: &Model) -> Result<Value> {
|
||||
let SendData {
|
||||
messages,
|
||||
temperature,
|
||||
top_p,
|
||||
stream: _,
|
||||
} = data;
|
||||
let prompt = generate_prompt(&messages, LLAMA2_PROMPT_FORMAT)?;
|
||||
let mut body = json!({ "prompt": prompt });
|
||||
|
||||
if let Some(v) = model.max_output_tokens {
|
||||
body["max_tokens"] = v.into();
|
||||
}
|
||||
if let Some(v) = temperature {
|
||||
body["temperature"] = v.into();
|
||||
}
|
||||
if let Some(v) = top_p {
|
||||
body["top_p"] = v.into();
|
||||
}
|
||||
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
/// Model families handled by this client. The family decides both the request
/// body layout and where the reply text is found in the response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelCategory {
    /// Model names starting with `anthropic.`.
    Anthropic,
    /// Model names starting with `meta.llama2`.
    MetaLlama2,
    /// Model names starting with `meta.llama3`.
    MetaLlama3,
    /// Model names starting with `mistral`.
    Mistral,
}
|
||||
|
||||
impl FromStr for ModelCategory {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
if s.starts_with("anthropic.") {
|
||||
Ok(ModelCategory::Anthropic)
|
||||
} else if s.starts_with("meta.llama2") {
|
||||
Ok(ModelCategory::MetaLlama2)
|
||||
} else if s.starts_with("meta.llama3") {
|
||||
Ok(ModelCategory::MetaLlama3)
|
||||
} else if s.starts_with("mistral") {
|
||||
Ok(ModelCategory::Mistral)
|
||||
} else {
|
||||
unsupported_model!(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Credentials and region used by `aws_fetch` to sign a request.
#[derive(Debug)]
struct AwsCredentials {
    /// Placed in the `Credential=` part of the authorization header.
    access_key_id: String,
    /// Secret from which the signing key is derived.
    secret_access_key: String,
    /// Region component of the credential scope.
    region: String,
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AwsRequest {
|
||||
method: Method,
|
||||
host: String,
|
||||
service: String,
|
||||
uri: String,
|
||||
querystring: String,
|
||||
headers: IndexMap<String, String>,
|
||||
body: String,
|
||||
}
|
||||
|
||||
fn aws_fetch(
|
||||
client: &ReqwestClient,
|
||||
credentials: &AwsCredentials,
|
||||
request: AwsRequest,
|
||||
) -> Result<RequestBuilder> {
|
||||
let AwsRequest {
|
||||
method,
|
||||
host,
|
||||
service,
|
||||
uri,
|
||||
querystring,
|
||||
mut headers,
|
||||
body,
|
||||
} = request;
|
||||
let region = &credentials.region;
|
||||
|
||||
let endpoint = format!("https://{}{}", host, uri);
|
||||
|
||||
let now: DateTime<Utc> = Utc::now();
|
||||
let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
|
||||
let date_stamp = amz_date[0..8].to_string();
|
||||
headers.insert("host".into(), host.clone());
|
||||
headers.insert("x-amz-date".into(), amz_date.clone());
|
||||
|
||||
let canonical_headers = headers
|
||||
.iter()
|
||||
.map(|(key, value)| format!("{}:{}\n", key, value))
|
||||
.collect::<Vec<_>>()
|
||||
.join("");
|
||||
|
||||
let signed_headers = headers
|
||||
.iter()
|
||||
.map(|(key, _)| key.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join(";");
|
||||
|
||||
let payload_hash = sha256(&body);
|
||||
|
||||
let canonical_request = format!(
|
||||
"{}\n{}\n{}\n{}\n{}\n{}",
|
||||
method,
|
||||
encode_uri(&uri),
|
||||
querystring,
|
||||
canonical_headers,
|
||||
signed_headers,
|
||||
payload_hash
|
||||
);
|
||||
|
||||
let algorithm = "AWS4-HMAC-SHA256";
|
||||
let credential_scope = format!("{}/{}/{}/aws4_request", date_stamp, region, service);
|
||||
let string_to_sign = format!(
|
||||
"{}\n{}\n{}\n{}",
|
||||
algorithm,
|
||||
amz_date,
|
||||
credential_scope,
|
||||
sha256(&canonical_request)
|
||||
);
|
||||
|
||||
let signing_key = gen_signing_key(
|
||||
&credentials.secret_access_key,
|
||||
&date_stamp,
|
||||
region,
|
||||
&service,
|
||||
);
|
||||
let signature = sign(&signing_key, &string_to_sign);
|
||||
let signature = hex_encode(&signature);
|
||||
|
||||
let authorization_header = format!(
|
||||
"{} Credential={}/{}, SignedHeaders={}, Signature={}",
|
||||
algorithm, credentials.access_key_id, credential_scope, signed_headers, signature
|
||||
);
|
||||
|
||||
headers.insert("authorization".into(), authorization_header);
|
||||
|
||||
let mut req_headers = HeaderMap::new();
|
||||
for (k, v) in &headers {
|
||||
req_headers.insert(HeaderName::from_str(k)?, HeaderValue::from_str(v)?);
|
||||
}
|
||||
|
||||
debug!("Bedrock Request: {endpoint} {body}");
|
||||
|
||||
let requst_builder = client
|
||||
.request(method, endpoint)
|
||||
.headers(req_headers)
|
||||
.body(body);
|
||||
Ok(requst_builder)
|
||||
}
|
||||
|
||||
fn sha256(data: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(data.as_bytes());
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
fn sign(key: &[u8], msg: &str) -> Vec<u8> {
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC can take key of any size");
|
||||
mac.update(msg.as_bytes());
|
||||
mac.finalize().into_bytes().to_vec()
|
||||
}
|
||||
|
||||
fn gen_signing_key(key: &str, date_stamp: &str, region: &str, service: &str) -> Vec<u8> {
|
||||
let k_date = sign(format!("AWS4{}", key).as_bytes(), date_stamp);
|
||||
let k_region = sign(&k_date, region);
|
||||
let k_service = sign(&k_region, service);
|
||||
sign(&k_service, "aws4_request")
|
||||
}
|
||||
|
||||
/// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte.
///
/// An empty slice yields an empty string.
fn hex_encode(bytes: &[u8]) -> String {
    let mut hex = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        hex.push_str(&format!("{:02x}", byte));
    }
    hex
}
|
||||
|
||||
fn encode_uri(uri: &str) -> String {
|
||||
uri.split('/')
|
||||
.map(|v| urlencoding::encode(v))
|
||||
.collect::<Vec<_>>()
|
||||
.join("/")
|
||||
}
|
||||
|
||||
fn decode_chunk(data: &[u8]) -> Option<Value> {
|
||||
let data = serde_json::from_slice::<Value>(data).ok()?;
|
||||
let data = data["bytes"].as_str()?;
|
||||
let data = STANDARD.decode(data).ok()?;
|
||||
serde_json::from_slice(&data).ok()
|
||||
}
|
@ -0,0 +1,81 @@
|
||||
use super::message::*;
|
||||
|
||||
/// Markers used to flatten a chat conversation into a single prompt string.
///
/// Each message is emitted as `<role>_pre_message + content +
/// <role>_post_message`, with `bos_token` placed once at the very start.
pub struct PromptFormat<'a> {
    /// Written once at the start of the prompt.
    pub bos_token: &'a str,
    /// Written before each system message.
    pub system_pre_message: &'a str,
    /// Written after each system message.
    pub system_post_message: &'a str,
    /// Written before each user message.
    pub user_pre_message: &'a str,
    /// Written after each user message.
    pub user_post_message: &'a str,
    /// Written before each assistant message, and once more at the end of
    /// the prompt.
    pub assistant_pre_message: &'a str,
    /// Written after each assistant message.
    pub assistant_post_message: &'a str,
}
|
||||
|
||||
/// Prompt markers in the Llama 2 `[INST]` / `<<SYS>>` style.
pub const LLAMA2_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    bos_token: "<s>",
    // System text is wrapped in its own [INST] block.
    system_pre_message: "[INST] <<SYS>>",
    system_post_message: "<</SYS>> [/INST]",
    user_pre_message: "[INST]",
    user_post_message: "[/INST]",
    // Assistant turns have no opening marker and are closed with </s>.
    assistant_pre_message: "",
    assistant_post_message: "</s>",
};
|
||||
|
||||
/// Prompt markers in the Llama 3 header style: every turn opens with a
/// `<|start_header_id|>ROLE<|end_header_id|>` header followed by a blank
/// line, and closes with `<|eot_id|>`.
pub const LLAMA3_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    bos_token: "<|begin_of_text|>",
    system_pre_message: "<|start_header_id|>system<|end_header_id|>\n\n",
    system_post_message: "<|eot_id|>",
    user_pre_message: "<|start_header_id|>user<|end_header_id|>\n\n",
    user_post_message: "<|eot_id|>",
    assistant_pre_message: "<|start_header_id|>assistant<|end_header_id|>\n\n",
    assistant_post_message: "<|eot_id|>",
};
|
||||
|
||||
pub fn generate_prompt(messages: &[Message], format: PromptFormat) -> anyhow::Result<String> {
|
||||
let PromptFormat {
|
||||
bos_token,
|
||||
system_pre_message,
|
||||
system_post_message,
|
||||
user_pre_message,
|
||||
user_post_message,
|
||||
assistant_pre_message,
|
||||
assistant_post_message,
|
||||
} = format;
|
||||
let mut prompt = bos_token.to_string();
|
||||
let mut image_urls = vec![];
|
||||
for message in messages {
|
||||
let role = &message.role;
|
||||
let content = match &message.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Array(list) => {
|
||||
let mut parts = vec![];
|
||||
for item in list {
|
||||
match item {
|
||||
MessageContentPart::Text { text } => parts.push(text.clone()),
|
||||
MessageContentPart::ImageUrl {
|
||||
image_url: ImageUrl { url },
|
||||
} => {
|
||||
image_urls.push(url.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
parts.join("\n\n")
|
||||
}
|
||||
};
|
||||
match role {
|
||||
MessageRole::System => prompt.push_str(&format!(
|
||||
"{system_pre_message}{content}{system_post_message}"
|
||||
)),
|
||||
MessageRole::Assistant => prompt.push_str(&format!(
|
||||
"{assistant_pre_message}{content}{assistant_post_message}"
|
||||
)),
|
||||
MessageRole::User => {
|
||||
prompt.push_str(&format!("{user_pre_message}{content}{user_post_message}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
if !image_urls.is_empty() {
|
||||
anyhow::bail!("The model does not support images: {:?}", image_urls);
|
||||
}
|
||||
prompt.push_str(assistant_pre_message);
|
||||
Ok(prompt)
|
||||
}
|
Loading…
Reference in New Issue