diff --git a/models.yaml b/models.yaml
index a2bdfae..522e4cd 100644
--- a/models.yaml
+++ b/models.yaml
@@ -368,8 +368,7 @@
   # docs:
   #   - https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns
   #   - https://aws.amazon.com/bedrock/pricing/
-  # notes:
-  #   - get max_output_tokens info from playground
+  #   - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
   models:
     - name: anthropic.claude-3-5-sonnet-20240620-v1:0
       max_input_tokens: 200000
@@ -405,44 +404,28 @@
       supports_function_calling: true
     - name: meta.llama3-1-405b-instruct-v1:0
       max_input_tokens: 128000
-      max_output_tokens: 2048
-      require_max_tokens: true
+      input_price: 5.32
+      output_price: 16
     - name: meta.llama3-1-70b-instruct-v1:0
       max_input_tokens: 128000
-      max_output_tokens: 2048
-      require_max_tokens: true
       input_price: 2.65
       output_price: 3.5
     - name: meta.llama3-1-8b-instruct-v1:0
       max_input_tokens: 128000
-      max_output_tokens: 2048
-      require_max_tokens: true
       input_price: 0.3
       output_price: 0.6
     - name: meta.llama3-70b-instruct-v1:0
       max_input_tokens: 8192
-      max_output_tokens: 2048
-      require_max_tokens: true
      input_price: 2.65
       output_price: 3.5
     - name: meta.llama3-8b-instruct-v1:0
       max_input_tokens: 8192
-      max_output_tokens: 2048
-      require_max_tokens: true
       input_price: 0.3
       output_price: 0.6
-    - name: mistral.mistral-large-2402-v1:0
-      max_input_tokens: 32000
-      max_output_tokens: 8192
-      require_max_tokens: true
-      input_price: 8
-      output_price: 24
-    - name: mistral.mixtral-8x7b-instruct-v0:1
-      max_input_tokens: 32000
-      max_output_tokens: 8192
-      require_max_tokens: true
-      input_price: 0.45
-      output_price: 0.7
+    - name: mistral.mistral-large-2407-v1:0
+      max_input_tokens: 128000
+      input_price: 3
+      output_price: 9
 
 - platform: cloudflare
   # docs:
diff --git a/src/client/bedrock.rs b/src/client/bedrock.rs
index 5d59ffe..fca76b3 100644
--- a/src/client/bedrock.rs
+++ b/src/client/bedrock.rs
@@ -1,10 +1,8 @@
-use super::claude::*;
-use super::prompt_format::*;
 use super::*;
 
 use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
 
-use anyhow::{anyhow, bail, Result};
+use anyhow::{bail, Context, Result};
 use async_trait::async_trait;
 use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
 use aws_smithy_eventstream::smithy::parse_response_headers;
@@ -41,9 +39,8 @@ impl Client for BedrockClient {
         client: &ReqwestClient,
         data: ChatCompletionsData,
     ) -> Result<ChatCompletionsOutput> {
-        let model_category = ModelCategory::from_str(self.model.name())?;
-        let builder = self.chat_completions_builder(client, data, &model_category)?;
-        chat_completions(builder, &model_category).await
+        let builder = self.chat_completions_builder(client, data)?;
+        chat_completions(builder).await
     }
 
     async fn chat_completions_streaming_inner(
@@ -52,9 +49,8 @@ impl Client for BedrockClient {
         handler: &mut SseHandler,
         data: ChatCompletionsData,
     ) -> Result<()> {
-        let model_category = ModelCategory::from_str(self.model.name())?;
-        let builder = self.chat_completions_builder(client, data, &model_category)?;
-        chat_completions_streaming(builder, handler, &model_category).await
+        let builder = self.chat_completions_builder(client, data)?;
+        chat_completions_streaming(builder, handler).await
     }
 }
@@ -83,7 +79,6 @@ impl BedrockClient {
         &self,
         client: &ReqwestClient,
         data: ChatCompletionsData,
-        model_category: &ModelCategory,
     ) -> Result<RequestBuilder> {
         let access_key_id = self.get_access_key_id()?;
         let secret_access_key = self.get_secret_access_key()?;
@@ -91,15 +86,15 @@ impl BedrockClient {
         let region = self.get_region()?;
         let model_name = &self.model.name();
         let uri = if data.stream {
-            format!("/model/{model_name}/invoke-with-response-stream")
+            format!("/model/{model_name}/converse-stream")
         } else {
-            format!("/model/{model_name}/invoke")
+            format!("/model/{model_name}/converse")
         };
         let host = format!("bedrock-runtime.{region}.amazonaws.com");
 
         let headers = IndexMap::new();
 
-        let mut body = build_chat_completions_body(data, &self.model, model_category)?;
+        let mut body = build_chat_completions_body(data, &self.model)?;
         self.patch_chat_completions_body(&mut body);
 
         let builder = aws_fetch(
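Both `/converse` and `/converse-stream` accept the same model-agnostic body, which is why the per-family `ModelCategory` plumbing can be dropped here. A minimal sketch of the kind of payload `build_chat_completions_body` now produces (field names from the Converse API; the prompt and values are invented):

```rust
use serde_json::json;

fn main() {
    // Illustrative Converse request body: the same `system` / `messages` /
    // `inferenceConfig` shape is sent for Anthropic, Meta, and Mistral models.
    let body = json!({
        "system": [{ "text": "You are a helpful assistant." }],
        "messages": [
            { "role": "user", "content": [{ "text": "Hello!" }] }
        ],
        "inferenceConfig": { "maxTokens": 1024, "temperature": 0.7, "topP": 0.9 }
    });
    println!("{body:#}");
}
```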
@@ -124,10 +119,7 @@ impl BedrockClient {
     }
 }
 
-async fn chat_completions(
-    builder: RequestBuilder,
-    model_category: &ModelCategory,
-) -> Result<ChatCompletionsOutput> {
+async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
     let res = builder.send().await?;
     let status = res.status();
     let data: Value = res.json().await?;
@@ -137,17 +129,12 @@ async fn chat_completions(
     }
 
     debug!("non-stream-data: {data}");
-    match model_category {
-        ModelCategory::Anthropic => claude_extract_chat_completions(&data),
-        ModelCategory::MetaLlama3 => llama_extract_chat_completions(&data),
-        ModelCategory::Mistral => mistral_extract_chat_completions(&data),
-    }
+    extract_chat_completions(&data)
 }
 
 async fn chat_completions_streaming(
     builder: RequestBuilder,
     handler: &mut SseHandler,
-    model_category: &ModelCategory,
 ) -> Result<()> {
     let res = builder.send().await?;
     let status = res.status();
@@ -156,6 +143,11 @@ async fn chat_completions_streaming(
         catch_error(&data, status.as_u16())?;
         bail!("Invalid response data: {data}");
     }
+
+    let mut function_name = String::new();
+    let mut function_arguments = String::new();
+    let mut function_id = String::new();
+
     let mut stream = res.bytes_stream();
     let mut buffer = BytesMut::new();
     let mut decoder = MessageFrameDecoder::new();
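The three accumulator strings added above exist because ConverseStream delivers a tool call in pieces: `contentBlockStart` carries the `toolUseId` and `name`, each `contentBlockDelta` carries a fragment of the JSON `input`, and only the full concatenation parses. A self-contained sketch of that reassembly, with hypothetical event payloads:

```rust
use serde_json::{json, Value};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical ConverseStream events for a single tool-use content block.
    let events = [
        ("contentBlockStart", json!({"start": {"toolUse": {"toolUseId": "id-1", "name": "get_weather"}}})),
        ("contentBlockDelta", json!({"delta": {"toolUse": {"input": "{\"city\":"}}})),
        ("contentBlockDelta", json!({"delta": {"toolUse": {"input": "\"Paris\"}"}}})),
        ("contentBlockStop", json!({})),
    ];

    let mut name = String::new();
    let mut arguments = String::new();
    for (smithy_type, data) in events {
        match smithy_type {
            "contentBlockStart" => {
                if let Some(n) = data["start"]["toolUse"]["name"].as_str() {
                    name = n.to_string();
                    arguments.clear();
                }
            }
            // Argument fragments are raw JSON text, not complete values.
            "contentBlockDelta" => {
                if let Some(input) = data["delta"]["toolUse"]["input"].as_str() {
                    arguments.push_str(input);
                }
            }
            // Only when the block closes is the buffer parseable JSON.
            "contentBlockStop" => {
                if !name.is_empty() {
                    let parsed: Value = arguments.parse()?;
                    println!("tool call {name}: {parsed}");
                    name.clear();
                }
            }
            _ => {}
        }
    }
    Ok(())
}
```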
Some(text) = data["outputs"][0]["text"].as_str() { - handler.text(text)?; + "contentBlockStop" => { + if !function_name.is_empty() { + let arguments: Value = function_arguments.parse().with_context(|| { + format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format") + })?; + handler.tool_call(ToolCall::new( + function_name.clone(), + arguments, + Some(function_id.clone()), + ))?; } } + _ => {} } } ("exception", _) => { @@ -209,124 +223,195 @@ async fn chat_completions_streaming( Ok(()) } -fn build_chat_completions_body( - data: ChatCompletionsData, - model: &Model, - model_category: &ModelCategory, -) -> Result { - match model_category { - ModelCategory::Anthropic => { - let mut body = claude_build_chat_completions_body(data, model)?; - if let Some(body_obj) = body.as_object_mut() { - body_obj.remove("model"); - body_obj.remove("stream"); - } - body["anthropic_version"] = "bedrock-2023-05-31".into(); - Ok(body) - } - ModelCategory::MetaLlama3 => { - meta_llama_build_chat_completions_body(data, model, LLAMA3_PROMPT_FORMAT) - } - ModelCategory::Mistral => mistral_build_chat_completions_body(data, model), - } -} - -fn meta_llama_build_chat_completions_body( - data: ChatCompletionsData, - model: &Model, - pt: PromptFormat, -) -> Result { +fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result { let ChatCompletionsData { - messages, + mut messages, temperature, top_p, - functions: _, + functions, stream: _, } = data; - let prompt = generate_prompt(&messages, pt)?; - let mut body = json!({ "prompt": prompt }); - if let Some(v) = model.max_tokens_param() { - body["max_gen_len"] = v.into(); - } - if let Some(v) = temperature { - body["temperature"] = v.into(); - } - if let Some(v) = top_p { - body["top_p"] = v.into(); - } + let system_message = extract_system_message(&mut messages); - Ok(body) -} + let mut network_image_urls = vec![]; -fn mistral_build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result { - let ChatCompletionsData { - messages, - temperature, - top_p, - functions: _, - stream: _, - } = data; - let prompt = generate_prompt(&messages, MISTRAL_PROMPT_FORMAT)?; - let mut body = json!({ "prompt": prompt }); + let messages: Vec = messages + .into_iter() + .flat_map(|message| { + let Message { role, content } = message; + match content { + MessageContent::Text(text) => vec![json!({ + "role": role, + "content": [ + { + "text": text, + } + ], + })], + MessageContent::Array(list) => { + let content: Vec<_> = list + .into_iter() + .map(|item| match item { + MessageContentPart::Text { text } => { + json!({"text": text}) + } + MessageContentPart::ImageUrl { + image_url: ImageUrl { url }, + } => { + if let Some((mime_type, data)) = url + .strip_prefix("data:") + .and_then(|v| v.split_once(";base64,")) + { + json!({ + "image": { + "format": mime_type.replace("image/", ""), + "source": { + "bytes": data, + } + } + }) + } else { + network_image_urls.push(url.clone()); + json!({ "url": url }) + } + } + }) + .collect(); + vec![json!({ + "role": role, + "content": content, + })] + } + MessageContent::ToolResults((tool_results, text)) => { + let mut assistant_parts = vec![]; + let mut user_parts = vec![]; + if !text.is_empty() { + assistant_parts.push(json!({ + "text": text, + })) + } + for tool_result in tool_results { + assistant_parts.push(json!({ + "toolUse": { + "toolUseId": tool_result.call.id, + "name": tool_result.call.name, + "input": tool_result.call.arguments, + } + })); + user_parts.push(json!({ + "toolResult": { 
+ "toolUseId": tool_result.call.id, + "content": [ + { + "json": tool_result.output, + } + ] + } + })); + } + vec![ + json!({ + "role": "assistant", + "content": assistant_parts, + }), + json!({ + "role": "user", + "content": user_parts, + }), + ] + } + } + }) + .collect(); + + if !network_image_urls.is_empty() { + bail!( + "The model does not support network images: {:?}", + network_image_urls + ); + } + + let mut body = json!({ + "inferenceConfig": {}, + "messages": messages, + }); + if let Some(v) = system_message { + body["system"] = json!([ + { + "text": v, + } + ]) + } if let Some(v) = model.max_tokens_param() { - body["max_tokens"] = v.into(); + body["inferenceConfig"]["maxTokens"] = v.into(); } if let Some(v) = temperature { - body["temperature"] = v.into(); + body["inferenceConfig"]["temperature"] = v.into(); } if let Some(v) = top_p { - body["top_p"] = v.into(); + body["inferenceConfig"]["topP"] = v.into(); + } + if let Some(functions) = functions { + let tools: Vec<_> = functions + .iter() + .map(|v| { + json!({ + "toolSpec": { + "name": v.name, + "description": v.description, + "inputSchema": { + "json": v.parameters, + }, + } + }) + }) + .collect(); + body["toolConfig"] = json!({ + "tools": tools, + }) } - Ok(body) } -fn llama_extract_chat_completions(data: &Value) -> Result { - let text = data["generation"] - .as_str() - .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; +fn extract_chat_completions(data: &Value) -> Result { + let mut texts = vec![]; + let mut tool_calls = vec![]; + if let Some(array) = data["output"]["message"]["content"].as_array() { + for item in array { + if let Some(text) = item["text"].as_str() { + texts.push(text); + } else if let Some(tool_use) = item["toolUse"].as_object() { + if let (Some(id), Some(name), Some(input)) = ( + json_str_from_map(tool_use, "toolUseId"), + json_str_from_map(tool_use, "name"), + tool_use.get("input"), + ) { + tool_calls.push(ToolCall::new( + name.to_string(), + input.clone(), + Some(id.to_string()), + )) + } + } + } + } + + if texts.is_empty() && tool_calls.is_empty() { + bail!("Invalid response data: {data}"); + } + let output = ChatCompletionsOutput { - text: text.to_string(), - tool_calls: vec![], + text: texts.join("\n\n"), + tool_calls, id: None, - input_tokens: data["prompt_token_count"].as_u64(), - output_tokens: data["generation_token_count"].as_u64(), + input_tokens: data["usage"]["inputTokens"].as_u64(), + output_tokens: data["usage"]["outputTokens"].as_u64(), }; Ok(output) } -fn mistral_extract_chat_completions(data: &Value) -> Result { - let text = data["outputs"][0]["text"] - .as_str() - .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; - Ok(ChatCompletionsOutput::new(text)) -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum ModelCategory { - Anthropic, - MetaLlama3, - Mistral, -} - -impl FromStr for ModelCategory { - type Err = anyhow::Error; - - fn from_str(s: &str) -> std::result::Result { - if s.starts_with("anthropic.") { - Ok(ModelCategory::Anthropic) - } else if s.starts_with("meta.llama3") { - Ok(ModelCategory::MetaLlama3) - } else if s.starts_with("mistral") { - Ok(ModelCategory::Mistral) - } else { - unsupported_model!(s) - } - } -} - #[derive(Debug)] struct AwsCredentials { access_key_id: String, @@ -439,10 +524,3 @@ fn gen_signing_key(key: &str, date_stamp: &str, region: &str, service: &str) -> let k_service = hmac_sha256(&k_region, service); hmac_sha256(&k_service, "aws4_request") } - -fn decode_chunk(data: &[u8]) -> Option { - let data = 
@@ -439,10 +527,3 @@ fn gen_signing_key(key: &str, date_stamp: &str, region: &str, service: &str) -> Vec<u8> {
     let k_service = hmac_sha256(&k_region, service);
     hmac_sha256(&k_service, "aws4_request")
 }
-
-fn decode_chunk(data: &[u8]) -> Option<Value> {
-    let data = serde_json::from_slice::<Value>(data).ok()?;
-    let data = data["bytes"].as_str()?;
-    let data = base64_decode(data).ok()?;
-    serde_json::from_slice(&data).ok()
-}
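The deleted `decode_chunk` reflects the old invoke-with-response-stream framing, where each event wrapped the model's JSON in a base64 `bytes` field. Converse stream frames carry the event JSON directly, so decoding collapses to a single parse. A sketch (payload invented):

```rust
use serde_json::Value;

// Converse event payloads are plain JSON, no base64 unwrapping required.
fn parse_converse_frame(payload: &[u8]) -> serde_json::Result<Value> {
    serde_json::from_slice(payload)
}

fn main() {
    let payload = br#"{"delta":{"text":"Hello"}}"#;
    let data = parse_converse_frame(payload).unwrap();
    assert_eq!(data["delta"]["text"], "Hello");
}
```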