|
|
@ -1,11 +1,8 @@
|
|
|
|
use super::claude::*;
|
|
|
|
|
|
|
|
use super::prompt_format::*;
|
|
|
|
|
|
|
|
use super::*;
|
|
|
|
use super::*;
|
|
|
|
|
|
|
|
|
|
|
|
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
|
|
|
|
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
|
|
|
|
|
|
|
|
|
|
|
|
use anyhow::{anyhow, bail, Result};
|
|
|
|
use anyhow::{bail, Context, Result};
|
|
|
|
use async_trait::async_trait;
|
|
|
|
|
|
|
|
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
|
|
|
|
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
|
|
|
|
use aws_smithy_eventstream::smithy::parse_response_headers;
|
|
|
|
use aws_smithy_eventstream::smithy::parse_response_headers;
|
|
|
|
use bytes::BytesMut;
|
|
|
|
use bytes::BytesMut;
|
|
|
@ -32,32 +29,6 @@ pub struct BedrockConfig {
|
|
|
|
pub extra: Option<ExtraConfig>,
|
|
|
|
pub extra: Option<ExtraConfig>,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[async_trait]
|
|
|
|
|
|
|
|
impl Client for BedrockClient {
|
|
|
|
|
|
|
|
client_common_fns!();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async fn chat_completions_inner(
|
|
|
|
|
|
|
|
&self,
|
|
|
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
|
|
|
|
data: ChatCompletionsData,
|
|
|
|
|
|
|
|
) -> Result<ChatCompletionsOutput> {
|
|
|
|
|
|
|
|
let model_category = ModelCategory::from_str(self.model.name())?;
|
|
|
|
|
|
|
|
let builder = self.chat_completions_builder(client, data, &model_category)?;
|
|
|
|
|
|
|
|
chat_completions(builder, &model_category).await
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async fn chat_completions_streaming_inner(
|
|
|
|
|
|
|
|
&self,
|
|
|
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
|
|
|
|
handler: &mut SseHandler,
|
|
|
|
|
|
|
|
data: ChatCompletionsData,
|
|
|
|
|
|
|
|
) -> Result<()> {
|
|
|
|
|
|
|
|
let model_category = ModelCategory::from_str(self.model.name())?;
|
|
|
|
|
|
|
|
let builder = self.chat_completions_builder(client, data, &model_category)?;
|
|
|
|
|
|
|
|
chat_completions_streaming(builder, handler, &model_category).await
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
impl BedrockClient {
|
|
|
|
impl BedrockClient {
|
|
|
|
config_get_fn!(access_key_id, get_access_key_id);
|
|
|
|
config_get_fn!(access_key_id, get_access_key_id);
|
|
|
|
config_get_fn!(secret_access_key, get_secret_access_key);
|
|
|
|
config_get_fn!(secret_access_key, get_secret_access_key);
|
|
|
@ -83,23 +54,22 @@ impl BedrockClient {
|
|
|
|
&self,
|
|
|
|
&self,
|
|
|
|
client: &ReqwestClient,
|
|
|
|
client: &ReqwestClient,
|
|
|
|
data: ChatCompletionsData,
|
|
|
|
data: ChatCompletionsData,
|
|
|
|
model_category: &ModelCategory,
|
|
|
|
|
|
|
|
) -> Result<RequestBuilder> {
|
|
|
|
) -> Result<RequestBuilder> {
|
|
|
|
let access_key_id = self.get_access_key_id()?;
|
|
|
|
let access_key_id = self.get_access_key_id()?;
|
|
|
|
let secret_access_key = self.get_secret_access_key()?;
|
|
|
|
let secret_access_key = self.get_secret_access_key()?;
|
|
|
|
let region = self.get_region()?;
|
|
|
|
let region = self.get_region()?;
|
|
|
|
|
|
|
|
let host = format!("bedrock-runtime.{region}.amazonaws.com");
|
|
|
|
|
|
|
|
|
|
|
|
let model_name = &self.model.name();
|
|
|
|
let model_name = &self.model.name();
|
|
|
|
let uri = if data.stream {
|
|
|
|
let uri = if data.stream {
|
|
|
|
format!("/model/{model_name}/invoke-with-response-stream")
|
|
|
|
format!("/model/{model_name}/converse-stream")
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
format!("/model/{model_name}/invoke")
|
|
|
|
format!("/model/{model_name}/converse")
|
|
|
|
};
|
|
|
|
};
|
|
|
|
let host = format!("bedrock-runtime.{region}.amazonaws.com");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let headers = IndexMap::new();
|
|
|
|
let headers = IndexMap::new();
|
|
|
|
|
|
|
|
|
|
|
|
let mut body = build_chat_completions_body(data, &self.model, model_category)?;
|
|
|
|
let mut body = build_chat_completions_body(data, &self.model)?;
|
|
|
|
self.patch_chat_completions_body(&mut body);
|
|
|
|
self.patch_chat_completions_body(&mut body);
|
|
|
|
|
|
|
|
|
|
|
|
let builder = aws_fetch(
|
|
|
|
let builder = aws_fetch(
|
|
|
@ -122,12 +92,61 @@ impl BedrockClient {
|
|
|
|
|
|
|
|
|
|
|
|
Ok(builder)
|
|
|
|
Ok(builder)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn embeddings_builder(
|
|
|
|
|
|
|
|
&self,
|
|
|
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
|
|
|
|
data: EmbeddingsData,
|
|
|
|
|
|
|
|
) -> Result<RequestBuilder> {
|
|
|
|
|
|
|
|
let access_key_id = self.get_access_key_id()?;
|
|
|
|
|
|
|
|
let secret_access_key = self.get_secret_access_key()?;
|
|
|
|
|
|
|
|
let region = self.get_region()?;
|
|
|
|
|
|
|
|
let host = format!("bedrock-runtime.{region}.amazonaws.com");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let uri = format!("/model/{}/invoke", self.model.name());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let headers = IndexMap::new();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let input_type = match data.query {
|
|
|
|
|
|
|
|
true => "search_query",
|
|
|
|
|
|
|
|
false => "search_document",
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let body = json!({
|
|
|
|
|
|
|
|
"texts": data.texts,
|
|
|
|
|
|
|
|
"input_type": input_type,
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let builder = aws_fetch(
|
|
|
|
|
|
|
|
client,
|
|
|
|
|
|
|
|
&AwsCredentials {
|
|
|
|
|
|
|
|
access_key_id,
|
|
|
|
|
|
|
|
secret_access_key,
|
|
|
|
|
|
|
|
region,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
AwsRequest {
|
|
|
|
|
|
|
|
method: Method::POST,
|
|
|
|
|
|
|
|
host,
|
|
|
|
|
|
|
|
service: "bedrock".into(),
|
|
|
|
|
|
|
|
uri,
|
|
|
|
|
|
|
|
querystring: "".into(),
|
|
|
|
|
|
|
|
headers,
|
|
|
|
|
|
|
|
body: body.to_string(),
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ok(builder)
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
async fn chat_completions(
|
|
|
|
impl_client_trait!(
|
|
|
|
builder: RequestBuilder,
|
|
|
|
BedrockClient,
|
|
|
|
model_category: &ModelCategory,
|
|
|
|
chat_completions,
|
|
|
|
) -> Result<ChatCompletionsOutput> {
|
|
|
|
chat_completions_streaming,
|
|
|
|
|
|
|
|
embeddings
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
|
|
|
|
let res = builder.send().await?;
|
|
|
|
let res = builder.send().await?;
|
|
|
|
let status = res.status();
|
|
|
|
let status = res.status();
|
|
|
|
let data: Value = res.json().await?;
|
|
|
|
let data: Value = res.json().await?;
|
|
|
@ -137,17 +156,12 @@ async fn chat_completions(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
debug!("non-stream-data: {data}");
|
|
|
|
debug!("non-stream-data: {data}");
|
|
|
|
match model_category {
|
|
|
|
extract_chat_completions(&data)
|
|
|
|
ModelCategory::Anthropic => claude_extract_chat_completions(&data),
|
|
|
|
|
|
|
|
ModelCategory::MetaLlama3 => llama_extract_chat_completions(&data),
|
|
|
|
|
|
|
|
ModelCategory::Mistral => mistral_extract_chat_completions(&data),
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
async fn chat_completions_streaming(
|
|
|
|
async fn chat_completions_streaming(
|
|
|
|
builder: RequestBuilder,
|
|
|
|
builder: RequestBuilder,
|
|
|
|
handler: &mut SseHandler,
|
|
|
|
handler: &mut SseHandler,
|
|
|
|
model_category: &ModelCategory,
|
|
|
|
|
|
|
|
) -> Result<()> {
|
|
|
|
) -> Result<()> {
|
|
|
|
let res = builder.send().await?;
|
|
|
|
let res = builder.send().await?;
|
|
|
|
let status = res.status();
|
|
|
|
let status = res.status();
|
|
|
@ -156,6 +170,11 @@ async fn chat_completions_streaming(
|
|
|
|
catch_error(&data, status.as_u16())?;
|
|
|
|
catch_error(&data, status.as_u16())?;
|
|
|
|
bail!("Invalid response data: {data}");
|
|
|
|
bail!("Invalid response data: {data}");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let mut function_name = String::new();
|
|
|
|
|
|
|
|
let mut function_arguments = String::new();
|
|
|
|
|
|
|
|
let mut function_id = String::new();
|
|
|
|
|
|
|
|
|
|
|
|
let mut stream = res.bytes_stream();
|
|
|
|
let mut stream = res.bytes_stream();
|
|
|
|
let mut buffer = BytesMut::new();
|
|
|
|
let mut buffer = BytesMut::new();
|
|
|
|
let mut decoder = MessageFrameDecoder::new();
|
|
|
|
let mut decoder = MessageFrameDecoder::new();
|
|
|
@ -167,31 +186,53 @@ async fn chat_completions_streaming(
|
|
|
|
let message_type = response_headers.message_type.as_str();
|
|
|
|
let message_type = response_headers.message_type.as_str();
|
|
|
|
let smithy_type = response_headers.smithy_type.as_str();
|
|
|
|
let smithy_type = response_headers.smithy_type.as_str();
|
|
|
|
match (message_type, smithy_type) {
|
|
|
|
match (message_type, smithy_type) {
|
|
|
|
("event", "chunk") => {
|
|
|
|
("event", _) => {
|
|
|
|
let data: Value = decode_chunk(message.payload()).ok_or_else(|| {
|
|
|
|
let data: Value = serde_json::from_slice(message.payload())?;
|
|
|
|
anyhow!("Invalid chunk data: {}", hex_encode(message.payload()))
|
|
|
|
debug!("stream-data: {smithy_type} {data}");
|
|
|
|
})?;
|
|
|
|
match smithy_type {
|
|
|
|
debug!("stream-data: {data}");
|
|
|
|
"contentBlockStart" => {
|
|
|
|
match model_category {
|
|
|
|
if let Some(tool_use) = data["start"]["toolUse"].as_object() {
|
|
|
|
ModelCategory::Anthropic => {
|
|
|
|
if let (Some(id), Some(name)) = (
|
|
|
|
if let Some(typ) = data["type"].as_str() {
|
|
|
|
json_str_from_map(tool_use, "toolUseId"),
|
|
|
|
if typ == "content_block_delta" {
|
|
|
|
json_str_from_map(tool_use, "name"),
|
|
|
|
if let Some(text) = data["delta"]["text"].as_str() {
|
|
|
|
) {
|
|
|
|
handler.text(text)?;
|
|
|
|
if !function_name.is_empty() {
|
|
|
|
|
|
|
|
let arguments: Value =
|
|
|
|
|
|
|
|
function_arguments.parse().with_context(|| {
|
|
|
|
|
|
|
|
format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format")
|
|
|
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
handler.tool_call(ToolCall::new(
|
|
|
|
|
|
|
|
function_name.clone(),
|
|
|
|
|
|
|
|
arguments,
|
|
|
|
|
|
|
|
Some(function_id.clone()),
|
|
|
|
|
|
|
|
))?;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
function_arguments.clear();
|
|
|
|
|
|
|
|
function_name = name.into();
|
|
|
|
|
|
|
|
function_id = id.into();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ModelCategory::MetaLlama3 => {
|
|
|
|
"contentBlockDelta" => {
|
|
|
|
if let Some(text) = data["generation"].as_str() {
|
|
|
|
if let Some(text) = data["delta"]["text"].as_str() {
|
|
|
|
handler.text(text)?;
|
|
|
|
handler.text(text)?;
|
|
|
|
|
|
|
|
} else if let Some(input) = data["delta"]["toolUse"]["input"].as_str() {
|
|
|
|
|
|
|
|
function_arguments.push_str(input);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ModelCategory::Mistral => {
|
|
|
|
"contentBlockStop" => {
|
|
|
|
if let Some(text) = data["outputs"][0]["text"].as_str() {
|
|
|
|
if !function_name.is_empty() {
|
|
|
|
handler.text(text)?;
|
|
|
|
let arguments: Value = function_arguments.parse().with_context(|| {
|
|
|
|
|
|
|
|
format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format")
|
|
|
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
handler.tool_call(ToolCall::new(
|
|
|
|
|
|
|
|
function_name.clone(),
|
|
|
|
|
|
|
|
arguments,
|
|
|
|
|
|
|
|
Some(function_id.clone()),
|
|
|
|
|
|
|
|
))?;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
_ => {}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
("exception", _) => {
|
|
|
|
("exception", _) => {
|
|
|
@ -209,124 +250,214 @@ async fn chat_completions_streaming(
|
|
|
|
Ok(())
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fn build_chat_completions_body(
|
|
|
|
async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
|
|
|
|
data: ChatCompletionsData,
|
|
|
|
let res = builder.send().await?;
|
|
|
|
model: &Model,
|
|
|
|
let status = res.status();
|
|
|
|
model_category: &ModelCategory,
|
|
|
|
let data: Value = res.json().await?;
|
|
|
|
) -> Result<Value> {
|
|
|
|
|
|
|
|
match model_category {
|
|
|
|
if !status.is_success() {
|
|
|
|
ModelCategory::Anthropic => {
|
|
|
|
catch_error(&data, status.as_u16())?;
|
|
|
|
let mut body = claude_build_chat_completions_body(data, model)?;
|
|
|
|
|
|
|
|
if let Some(body_obj) = body.as_object_mut() {
|
|
|
|
|
|
|
|
body_obj.remove("model");
|
|
|
|
|
|
|
|
body_obj.remove("stream");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
body["anthropic_version"] = "bedrock-2023-05-31".into();
|
|
|
|
|
|
|
|
Ok(body)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
ModelCategory::MetaLlama3 => {
|
|
|
|
|
|
|
|
meta_llama_build_chat_completions_body(data, model, LLAMA3_PROMPT_FORMAT)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
ModelCategory::Mistral => mistral_build_chat_completions_body(data, model),
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let res_body: EmbeddingsResBody =
|
|
|
|
|
|
|
|
serde_json::from_value(data).context("Invalid embeddings data")?;
|
|
|
|
|
|
|
|
Ok(res_body.embeddings)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Deserialize)]
|
|
|
|
|
|
|
|
struct EmbeddingsResBody {
|
|
|
|
|
|
|
|
embeddings: Vec<Vec<f32>>,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fn meta_llama_build_chat_completions_body(
|
|
|
|
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
|
|
|
|
data: ChatCompletionsData,
|
|
|
|
|
|
|
|
model: &Model,
|
|
|
|
|
|
|
|
pt: PromptFormat,
|
|
|
|
|
|
|
|
) -> Result<Value> {
|
|
|
|
|
|
|
|
let ChatCompletionsData {
|
|
|
|
let ChatCompletionsData {
|
|
|
|
messages,
|
|
|
|
mut messages,
|
|
|
|
temperature,
|
|
|
|
temperature,
|
|
|
|
top_p,
|
|
|
|
top_p,
|
|
|
|
functions: _,
|
|
|
|
functions,
|
|
|
|
stream: _,
|
|
|
|
stream: _,
|
|
|
|
} = data;
|
|
|
|
} = data;
|
|
|
|
let prompt = generate_prompt(&messages, pt)?;
|
|
|
|
|
|
|
|
let mut body = json!({ "prompt": prompt });
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if let Some(v) = model.max_tokens_param() {
|
|
|
|
let system_message = extract_system_message(&mut messages);
|
|
|
|
body["max_gen_len"] = v.into();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if let Some(v) = temperature {
|
|
|
|
|
|
|
|
body["temperature"] = v.into();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if let Some(v) = top_p {
|
|
|
|
|
|
|
|
body["top_p"] = v.into();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ok(body)
|
|
|
|
let mut network_image_urls = vec![];
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn mistral_build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
|
|
|
|
let messages: Vec<Value> = messages
|
|
|
|
let ChatCompletionsData {
|
|
|
|
.into_iter()
|
|
|
|
messages,
|
|
|
|
.flat_map(|message| {
|
|
|
|
temperature,
|
|
|
|
let Message { role, content } = message;
|
|
|
|
top_p,
|
|
|
|
match content {
|
|
|
|
functions: _,
|
|
|
|
MessageContent::Text(text) => vec![json!({
|
|
|
|
stream: _,
|
|
|
|
"role": role,
|
|
|
|
} = data;
|
|
|
|
"content": [
|
|
|
|
let prompt = generate_prompt(&messages, MISTRAL_PROMPT_FORMAT)?;
|
|
|
|
{
|
|
|
|
let mut body = json!({ "prompt": prompt });
|
|
|
|
"text": text,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
})],
|
|
|
|
|
|
|
|
MessageContent::Array(list) => {
|
|
|
|
|
|
|
|
let content: Vec<_> = list
|
|
|
|
|
|
|
|
.into_iter()
|
|
|
|
|
|
|
|
.map(|item| match item {
|
|
|
|
|
|
|
|
MessageContentPart::Text { text } => {
|
|
|
|
|
|
|
|
json!({"text": text})
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
MessageContentPart::ImageUrl {
|
|
|
|
|
|
|
|
image_url: ImageUrl { url },
|
|
|
|
|
|
|
|
} => {
|
|
|
|
|
|
|
|
if let Some((mime_type, data)) = url
|
|
|
|
|
|
|
|
.strip_prefix("data:")
|
|
|
|
|
|
|
|
.and_then(|v| v.split_once(";base64,"))
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
json!({
|
|
|
|
|
|
|
|
"image": {
|
|
|
|
|
|
|
|
"format": mime_type.replace("image/", ""),
|
|
|
|
|
|
|
|
"source": {
|
|
|
|
|
|
|
|
"bytes": data,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
network_image_urls.push(url.clone());
|
|
|
|
|
|
|
|
json!({ "url": url })
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
vec![json!({
|
|
|
|
|
|
|
|
"role": role,
|
|
|
|
|
|
|
|
"content": content,
|
|
|
|
|
|
|
|
})]
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
MessageContent::ToolResults((tool_results, text)) => {
|
|
|
|
|
|
|
|
let mut assistant_parts = vec![];
|
|
|
|
|
|
|
|
let mut user_parts = vec![];
|
|
|
|
|
|
|
|
if !text.is_empty() {
|
|
|
|
|
|
|
|
assistant_parts.push(json!({
|
|
|
|
|
|
|
|
"text": text,
|
|
|
|
|
|
|
|
}))
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for tool_result in tool_results {
|
|
|
|
|
|
|
|
assistant_parts.push(json!({
|
|
|
|
|
|
|
|
"toolUse": {
|
|
|
|
|
|
|
|
"toolUseId": tool_result.call.id,
|
|
|
|
|
|
|
|
"name": tool_result.call.name,
|
|
|
|
|
|
|
|
"input": tool_result.call.arguments,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}));
|
|
|
|
|
|
|
|
user_parts.push(json!({
|
|
|
|
|
|
|
|
"toolResult": {
|
|
|
|
|
|
|
|
"toolUseId": tool_result.call.id,
|
|
|
|
|
|
|
|
"content": [
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
"json": tool_result.output,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
vec![
|
|
|
|
|
|
|
|
json!({
|
|
|
|
|
|
|
|
"role": "assistant",
|
|
|
|
|
|
|
|
"content": assistant_parts,
|
|
|
|
|
|
|
|
}),
|
|
|
|
|
|
|
|
json!({
|
|
|
|
|
|
|
|
"role": "user",
|
|
|
|
|
|
|
|
"content": user_parts,
|
|
|
|
|
|
|
|
}),
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if !network_image_urls.is_empty() {
|
|
|
|
|
|
|
|
bail!(
|
|
|
|
|
|
|
|
"The model does not support network images: {:?}",
|
|
|
|
|
|
|
|
network_image_urls
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let mut body = json!({
|
|
|
|
|
|
|
|
"inferenceConfig": {},
|
|
|
|
|
|
|
|
"messages": messages,
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
if let Some(v) = system_message {
|
|
|
|
|
|
|
|
body["system"] = json!([
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
"text": v,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
])
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if let Some(v) = model.max_tokens_param() {
|
|
|
|
if let Some(v) = model.max_tokens_param() {
|
|
|
|
body["max_tokens"] = v.into();
|
|
|
|
body["inferenceConfig"]["maxTokens"] = v.into();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if let Some(v) = temperature {
|
|
|
|
if let Some(v) = temperature {
|
|
|
|
body["temperature"] = v.into();
|
|
|
|
body["inferenceConfig"]["temperature"] = v.into();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if let Some(v) = top_p {
|
|
|
|
if let Some(v) = top_p {
|
|
|
|
body["top_p"] = v.into();
|
|
|
|
body["inferenceConfig"]["topP"] = v.into();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if let Some(functions) = functions {
|
|
|
|
|
|
|
|
let tools: Vec<_> = functions
|
|
|
|
|
|
|
|
.iter()
|
|
|
|
|
|
|
|
.map(|v| {
|
|
|
|
|
|
|
|
json!({
|
|
|
|
|
|
|
|
"toolSpec": {
|
|
|
|
|
|
|
|
"name": v.name,
|
|
|
|
|
|
|
|
"description": v.description,
|
|
|
|
|
|
|
|
"inputSchema": {
|
|
|
|
|
|
|
|
"json": v.parameters,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
body["toolConfig"] = json!({
|
|
|
|
|
|
|
|
"tools": tools,
|
|
|
|
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Ok(body)
|
|
|
|
Ok(body)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fn llama_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
|
|
|
fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
|
|
|
let text = data["generation"]
|
|
|
|
let mut texts = vec![];
|
|
|
|
.as_str()
|
|
|
|
let mut tool_calls = vec![];
|
|
|
|
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
|
|
|
|
if let Some(array) = data["output"]["message"]["content"].as_array() {
|
|
|
|
|
|
|
|
for item in array {
|
|
|
|
|
|
|
|
if let Some(text) = item["text"].as_str() {
|
|
|
|
|
|
|
|
texts.push(text);
|
|
|
|
|
|
|
|
} else if let Some(tool_use) = item["toolUse"].as_object() {
|
|
|
|
|
|
|
|
if let (Some(id), Some(name), Some(input)) = (
|
|
|
|
|
|
|
|
json_str_from_map(tool_use, "toolUseId"),
|
|
|
|
|
|
|
|
json_str_from_map(tool_use, "name"),
|
|
|
|
|
|
|
|
tool_use.get("input"),
|
|
|
|
|
|
|
|
) {
|
|
|
|
|
|
|
|
tool_calls.push(ToolCall::new(
|
|
|
|
|
|
|
|
name.to_string(),
|
|
|
|
|
|
|
|
input.clone(),
|
|
|
|
|
|
|
|
Some(id.to_string()),
|
|
|
|
|
|
|
|
))
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if texts.is_empty() && tool_calls.is_empty() {
|
|
|
|
|
|
|
|
bail!("Invalid response data: {data}");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
let output = ChatCompletionsOutput {
|
|
|
|
let output = ChatCompletionsOutput {
|
|
|
|
text: text.to_string(),
|
|
|
|
text: texts.join("\n\n"),
|
|
|
|
tool_calls: vec![],
|
|
|
|
tool_calls,
|
|
|
|
id: None,
|
|
|
|
id: None,
|
|
|
|
input_tokens: data["prompt_token_count"].as_u64(),
|
|
|
|
input_tokens: data["usage"]["inputTokens"].as_u64(),
|
|
|
|
output_tokens: data["generation_token_count"].as_u64(),
|
|
|
|
output_tokens: data["usage"]["outputTokens"].as_u64(),
|
|
|
|
};
|
|
|
|
};
|
|
|
|
Ok(output)
|
|
|
|
Ok(output)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fn mistral_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
|
|
|
|
|
|
|
let text = data["outputs"][0]["text"]
|
|
|
|
|
|
|
|
.as_str()
|
|
|
|
|
|
|
|
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
|
|
|
|
|
|
|
|
Ok(ChatCompletionsOutput::new(text))
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
|
|
|
|
|
|
enum ModelCategory {
|
|
|
|
|
|
|
|
Anthropic,
|
|
|
|
|
|
|
|
MetaLlama3,
|
|
|
|
|
|
|
|
Mistral,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
impl FromStr for ModelCategory {
|
|
|
|
|
|
|
|
type Err = anyhow::Error;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
|
|
|
|
|
|
|
if s.starts_with("anthropic.") {
|
|
|
|
|
|
|
|
Ok(ModelCategory::Anthropic)
|
|
|
|
|
|
|
|
} else if s.starts_with("meta.llama3") {
|
|
|
|
|
|
|
|
Ok(ModelCategory::MetaLlama3)
|
|
|
|
|
|
|
|
} else if s.starts_with("mistral") {
|
|
|
|
|
|
|
|
Ok(ModelCategory::Mistral)
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
unsupported_model!(s)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Debug)]
|
|
|
|
#[derive(Debug)]
|
|
|
|
struct AwsCredentials {
|
|
|
|
struct AwsCredentials {
|
|
|
|
access_key_id: String,
|
|
|
|
access_key_id: String,
|
|
|
@ -439,10 +570,3 @@ fn gen_signing_key(key: &str, date_stamp: &str, region: &str, service: &str) ->
|
|
|
|
let k_service = hmac_sha256(&k_region, service);
|
|
|
|
let k_service = hmac_sha256(&k_region, service);
|
|
|
|
hmac_sha256(&k_service, "aws4_request")
|
|
|
|
hmac_sha256(&k_service, "aws4_request")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fn decode_chunk(data: &[u8]) -> Option<Value> {
|
|
|
|
|
|
|
|
let data = serde_json::from_slice::<Value>(data).ok()?;
|
|
|
|
|
|
|
|
let data = data["bytes"].as_str()?;
|
|
|
|
|
|
|
|
let data = base64_decode(data).ok()?;
|
|
|
|
|
|
|
|
serde_json::from_slice(&data).ok()
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|