|
|
|
@ -1,7 +1,7 @@
|
|
|
|
|
use super::{
|
|
|
|
|
catch_error, claude::*, prompt_format::*, BedrockClient, Client, CompletionData,
|
|
|
|
|
CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind,
|
|
|
|
|
SseHandler,
|
|
|
|
|
catch_error, claude::*, prompt_format::*, BedrockClient, ChatCompletionsData,
|
|
|
|
|
ChatCompletionsOutput, Client, ExtraConfig, Model, ModelData, ModelPatches, PromptAction,
|
|
|
|
|
PromptKind, SseHandler,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
|
|
|
|
@ -38,25 +38,25 @@ pub struct BedrockConfig {
|
|
|
|
|
impl Client for BedrockClient {
|
|
|
|
|
client_common_fns!();
|
|
|
|
|
|
|
|
|
|
async fn send_message_inner(
|
|
|
|
|
async fn chat_completions_inner(
|
|
|
|
|
&self,
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
|
data: CompletionData,
|
|
|
|
|
) -> Result<CompletionOutput> {
|
|
|
|
|
data: ChatCompletionsData,
|
|
|
|
|
) -> Result<ChatCompletionsOutput> {
|
|
|
|
|
let model_category = ModelCategory::from_str(self.model.name())?;
|
|
|
|
|
let builder = self.request_builder(client, data, &model_category)?;
|
|
|
|
|
send_message(builder, &model_category).await
|
|
|
|
|
let builder = self.chat_completions_builder(client, data, &model_category)?;
|
|
|
|
|
chat_completions(builder, &model_category).await
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn send_message_streaming_inner(
|
|
|
|
|
async fn chat_completions_streaming_inner(
|
|
|
|
|
&self,
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
|
handler: &mut SseHandler,
|
|
|
|
|
data: CompletionData,
|
|
|
|
|
data: ChatCompletionsData,
|
|
|
|
|
) -> Result<()> {
|
|
|
|
|
let model_category = ModelCategory::from_str(self.model.name())?;
|
|
|
|
|
let builder = self.request_builder(client, data, &model_category)?;
|
|
|
|
|
send_message_streaming(builder, handler, &model_category).await
|
|
|
|
|
let builder = self.chat_completions_builder(client, data, &model_category)?;
|
|
|
|
|
chat_completions_streaming(builder, handler, &model_category).await
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -81,10 +81,10 @@ impl BedrockClient {
|
|
|
|
|
("region", "AWS Region", true, PromptKind::String),
|
|
|
|
|
];
|
|
|
|
|
|
|
|
|
|
fn request_builder(
|
|
|
|
|
fn chat_completions_builder(
|
|
|
|
|
&self,
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
|
data: CompletionData,
|
|
|
|
|
data: ChatCompletionsData,
|
|
|
|
|
model_category: &ModelCategory,
|
|
|
|
|
) -> Result<RequestBuilder> {
|
|
|
|
|
let access_key_id = self.get_access_key_id()?;
|
|
|
|
@ -101,7 +101,7 @@ impl BedrockClient {
|
|
|
|
|
|
|
|
|
|
let headers = IndexMap::new();
|
|
|
|
|
|
|
|
|
|
let mut body = build_body(data, &self.model, model_category)?;
|
|
|
|
|
let mut body = build_chat_completions_body(data, &self.model, model_category)?;
|
|
|
|
|
self.patch_request_body(&mut body);
|
|
|
|
|
|
|
|
|
|
let builder = aws_fetch(
|
|
|
|
@ -126,10 +126,10 @@ impl BedrockClient {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn send_message(
|
|
|
|
|
async fn chat_completions(
|
|
|
|
|
builder: RequestBuilder,
|
|
|
|
|
model_category: &ModelCategory,
|
|
|
|
|
) -> Result<CompletionOutput> {
|
|
|
|
|
) -> Result<ChatCompletionsOutput> {
|
|
|
|
|
let res = builder.send().await?;
|
|
|
|
|
let status = res.status();
|
|
|
|
|
let data: Value = res.json().await?;
|
|
|
|
@ -140,13 +140,13 @@ async fn send_message(
|
|
|
|
|
|
|
|
|
|
debug!("non-stream-data: {data}");
|
|
|
|
|
match model_category {
|
|
|
|
|
ModelCategory::Anthropic => claude_extract_completion(&data),
|
|
|
|
|
ModelCategory::MetaLlama3 => llama_extract_completion(&data),
|
|
|
|
|
ModelCategory::Mistral => mistral_extract_completion(&data),
|
|
|
|
|
ModelCategory::Anthropic => claude_extract_chat_completions(&data),
|
|
|
|
|
ModelCategory::MetaLlama3 => llama_extract_chat_completions(&data),
|
|
|
|
|
ModelCategory::Mistral => mistral_extract_chat_completions(&data),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn send_message_streaming(
|
|
|
|
|
async fn chat_completions_streaming(
|
|
|
|
|
builder: RequestBuilder,
|
|
|
|
|
handler: &mut SseHandler,
|
|
|
|
|
model_category: &ModelCategory,
|
|
|
|
@ -211,14 +211,14 @@ async fn send_message_streaming(
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn build_body(
|
|
|
|
|
data: CompletionData,
|
|
|
|
|
fn build_chat_completions_body(
|
|
|
|
|
data: ChatCompletionsData,
|
|
|
|
|
model: &Model,
|
|
|
|
|
model_category: &ModelCategory,
|
|
|
|
|
) -> Result<Value> {
|
|
|
|
|
match model_category {
|
|
|
|
|
ModelCategory::Anthropic => {
|
|
|
|
|
let mut body = claude_build_body(data, model)?;
|
|
|
|
|
let mut body = claude_build_chat_completions_body(data, model)?;
|
|
|
|
|
if let Some(body_obj) = body.as_object_mut() {
|
|
|
|
|
body_obj.remove("model");
|
|
|
|
|
body_obj.remove("stream");
|
|
|
|
@ -226,13 +226,19 @@ fn build_body(
|
|
|
|
|
body["anthropic_version"] = "bedrock-2023-05-31".into();
|
|
|
|
|
Ok(body)
|
|
|
|
|
}
|
|
|
|
|
ModelCategory::MetaLlama3 => meta_llama_build_body(data, model, LLAMA3_PROMPT_FORMAT),
|
|
|
|
|
ModelCategory::Mistral => mistral_build_body(data, model),
|
|
|
|
|
ModelCategory::MetaLlama3 => {
|
|
|
|
|
meta_llama_build_chat_completions_body(data, model, LLAMA3_PROMPT_FORMAT)
|
|
|
|
|
}
|
|
|
|
|
ModelCategory::Mistral => mistral_build_chat_completions_body(data, model),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn meta_llama_build_body(data: CompletionData, model: &Model, pt: PromptFormat) -> Result<Value> {
|
|
|
|
|
let CompletionData {
|
|
|
|
|
fn meta_llama_build_chat_completions_body(
|
|
|
|
|
data: ChatCompletionsData,
|
|
|
|
|
model: &Model,
|
|
|
|
|
pt: PromptFormat,
|
|
|
|
|
) -> Result<Value> {
|
|
|
|
|
let ChatCompletionsData {
|
|
|
|
|
messages,
|
|
|
|
|
temperature,
|
|
|
|
|
top_p,
|
|
|
|
@ -255,8 +261,8 @@ fn meta_llama_build_body(data: CompletionData, model: &Model, pt: PromptFormat)
|
|
|
|
|
Ok(body)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn mistral_build_body(data: CompletionData, model: &Model) -> Result<Value> {
|
|
|
|
|
let CompletionData {
|
|
|
|
|
fn mistral_build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
|
|
|
|
|
let ChatCompletionsData {
|
|
|
|
|
messages,
|
|
|
|
|
temperature,
|
|
|
|
|
top_p,
|
|
|
|
@ -279,11 +285,11 @@ fn mistral_build_body(data: CompletionData, model: &Model) -> Result<Value> {
|
|
|
|
|
Ok(body)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn llama_extract_completion(data: &Value) -> Result<CompletionOutput> {
|
|
|
|
|
fn llama_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
|
|
|
|
let text = data["generation"]
|
|
|
|
|
.as_str()
|
|
|
|
|
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
|
|
|
|
|
let output = CompletionOutput {
|
|
|
|
|
let output = ChatCompletionsOutput {
|
|
|
|
|
text: text.to_string(),
|
|
|
|
|
tool_calls: vec![],
|
|
|
|
|
id: None,
|
|
|
|
@ -293,11 +299,11 @@ fn llama_extract_completion(data: &Value) -> Result<CompletionOutput> {
|
|
|
|
|
Ok(output)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn mistral_extract_completion(data: &Value) -> Result<CompletionOutput> {
|
|
|
|
|
fn mistral_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
|
|
|
|
let text = data["outputs"][0]["text"]
|
|
|
|
|
.as_str()
|
|
|
|
|
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
|
|
|
|
|
Ok(CompletionOutput::new(text))
|
|
|
|
|
Ok(ChatCompletionsOutput::new(text))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
|
|
|