feat: bedrock client switch to converse api and support cohere models (#747)

sigoden 3 months ago committed by GitHub
parent 2eed63f014
commit adf6716c84
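
Summary: the Bedrock client stops hand-building per-model invoke payloads (Claude, Llama 3 and Mistral prompt formats) and targets the unified Converse API instead, which brings streaming tool calls and the new Cohere chat and embedding models under one schema. As a rough sketch of the request shape the new build_chat_completions_body produces for POST /model/{model-id}/converse (all keys match the code in this commit; the values are invented for the example):

use serde_json::json;

fn main() {
    // Illustrative Converse request body; "system", "messages",
    // "inferenceConfig" and "toolConfig" mirror build_chat_completions_body,
    // the contents are made up.
    let body = json!({
        "system": [{ "text": "You are a helpful assistant." }],
        "messages": [
            { "role": "user", "content": [{ "text": "What's the weather in Paris?" }] }
        ],
        "inferenceConfig": { "maxTokens": 4096, "temperature": 0.7, "topP": 0.9 },
        "toolConfig": {
            "tools": [{
                "toolSpec": {
                    "name": "get_weather",
                    "description": "Look up the current weather for a city",
                    "inputSchema": {
                        "json": {
                            "type": "object",
                            "properties": { "city": { "type": "string" } },
                            "required": ["city"]
                        }
                    }
                }
            }]
        }
    });
    println!("{body}");
}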

@@ -368,8 +368,9 @@
   # docs:
   #   - https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns
   #   - https://aws.amazon.com/bedrock/pricing/
+  #   - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
   # notes:
-  #   - get max_output_tokens info from playground
+  #   - except for Claude, other models do not support streaming function calling
   models:
     - name: anthropic.claude-3-5-sonnet-20240620-v1:0
       max_input_tokens: 200000
@@ -405,44 +406,56 @@
       supports_function_calling: true
     - name: meta.llama3-1-405b-instruct-v1:0
       max_input_tokens: 128000
-      max_output_tokens: 2048
-      require_max_tokens: true
+      input_price: 5.32
+      output_price: 16
+      supports_function_calling: true
     - name: meta.llama3-1-70b-instruct-v1:0
       max_input_tokens: 128000
-      max_output_tokens: 2048
-      require_max_tokens: true
       input_price: 2.65
       output_price: 3.5
+      supports_function_calling: true
     - name: meta.llama3-1-8b-instruct-v1:0
       max_input_tokens: 128000
-      max_output_tokens: 2048
-      require_max_tokens: true
       input_price: 0.3
       output_price: 0.6
+      supports_function_calling: true
     - name: meta.llama3-70b-instruct-v1:0
       max_input_tokens: 8192
-      max_output_tokens: 2048
-      require_max_tokens: true
       input_price: 2.65
       output_price: 3.5
     - name: meta.llama3-8b-instruct-v1:0
       max_input_tokens: 8192
-      max_output_tokens: 2048
-      require_max_tokens: true
       input_price: 0.3
       output_price: 0.6
-    - name: mistral.mistral-large-2402-v1:0
-      max_input_tokens: 32000
-      max_output_tokens: 8192
-      require_max_tokens: true
-      input_price: 8
-      output_price: 2.4
-    - name: mistral.mixtral-8x7b-instruct-v0:1
-      max_input_tokens: 32000
-      max_output_tokens: 8192
-      require_max_tokens: true
-      input_price: 0.45
-      output_price: 0.7
+    - name: mistral.mistral-large-2407-v1:0
+      max_input_tokens: 128000
+      input_price: 3
+      output_price: 9
+      supports_function_calling: true
+    - name: cohere.command-r-plus-v1:0
+      max_input_tokens: 128000
+      input_price: 3
+      output_price: 15
+      supports_function_calling: true
+    - name: cohere.command-r-v1:0
+      max_input_tokens: 128000
+      input_price: 0.5
+      output_price: 1.5
+      supports_function_calling: true
+    - name: cohere.embed-english-v3
+      type: embedding
+      max_input_tokens: 512
+      input_price: 0.1
+      output_vector_size: 1024
+      default_chunk_size: 1000
+      max_batch_size: 96
+    - name: cohere.embed-multilingual-v3
+      type: embedding
+      max_input_tokens: 512
+      input_price: 0.1
+      output_vector_size: 1024
+      default_chunk_size: 1000
+      max_batch_size: 96

 - platform: cloudflare
   # docs:

@@ -1,11 +1,8 @@
-use super::claude::*;
-use super::prompt_format::*;
 use super::*;
 use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
-use anyhow::{anyhow, bail, Result};
-use async_trait::async_trait;
+use anyhow::{bail, Context, Result};
 use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
 use aws_smithy_eventstream::smithy::parse_response_headers;
 use bytes::BytesMut;
@@ -32,32 +29,6 @@ pub struct BedrockConfig {
     pub extra: Option<ExtraConfig>,
 }

-#[async_trait]
-impl Client for BedrockClient {
-    client_common_fns!();
-
-    async fn chat_completions_inner(
-        &self,
-        client: &ReqwestClient,
-        data: ChatCompletionsData,
-    ) -> Result<ChatCompletionsOutput> {
-        let model_category = ModelCategory::from_str(self.model.name())?;
-        let builder = self.chat_completions_builder(client, data, &model_category)?;
-        chat_completions(builder, &model_category).await
-    }
-
-    async fn chat_completions_streaming_inner(
-        &self,
-        client: &ReqwestClient,
-        handler: &mut SseHandler,
-        data: ChatCompletionsData,
-    ) -> Result<()> {
-        let model_category = ModelCategory::from_str(self.model.name())?;
-        let builder = self.chat_completions_builder(client, data, &model_category)?;
-        chat_completions_streaming(builder, handler, &model_category).await
-    }
-}
-
 impl BedrockClient {
     config_get_fn!(access_key_id, get_access_key_id);
     config_get_fn!(secret_access_key, get_secret_access_key);
@@ -83,23 +54,22 @@ impl BedrockClient {
         &self,
         client: &ReqwestClient,
         data: ChatCompletionsData,
-        model_category: &ModelCategory,
     ) -> Result<RequestBuilder> {
         let access_key_id = self.get_access_key_id()?;
         let secret_access_key = self.get_secret_access_key()?;
         let region = self.get_region()?;
-        let host = format!("bedrock-runtime.{region}.amazonaws.com");
         let model_name = &self.model.name();
         let uri = if data.stream {
-            format!("/model/{model_name}/invoke-with-response-stream")
+            format!("/model/{model_name}/converse-stream")
         } else {
-            format!("/model/{model_name}/invoke")
+            format!("/model/{model_name}/converse")
         };
+        let host = format!("bedrock-runtime.{region}.amazonaws.com");
         let headers = IndexMap::new();

-        let mut body = build_chat_completions_body(data, &self.model, model_category)?;
+        let mut body = build_chat_completions_body(data, &self.model)?;
         self.patch_chat_completions_body(&mut body);

         let builder = aws_fetch(
@@ -122,12 +92,61 @@ impl BedrockClient {
         Ok(builder)
     }

+    fn embeddings_builder(
+        &self,
+        client: &ReqwestClient,
+        data: EmbeddingsData,
+    ) -> Result<RequestBuilder> {
+        let access_key_id = self.get_access_key_id()?;
+        let secret_access_key = self.get_secret_access_key()?;
+        let region = self.get_region()?;
+        let host = format!("bedrock-runtime.{region}.amazonaws.com");
+        let uri = format!("/model/{}/invoke", self.model.name());
+        let headers = IndexMap::new();
+
+        let input_type = match data.query {
+            true => "search_query",
+            false => "search_document",
+        };
+
+        let body = json!({
+            "texts": data.texts,
+            "input_type": input_type,
+        });
+
+        let builder = aws_fetch(
+            client,
+            &AwsCredentials {
+                access_key_id,
+                secret_access_key,
+                region,
+            },
+            AwsRequest {
+                method: Method::POST,
+                host,
+                service: "bedrock".into(),
+                uri,
+                querystring: "".into(),
+                headers,
+                body: body.to_string(),
+            },
+        )?;
+
+        Ok(builder)
+    }
 }

-async fn chat_completions(
-    builder: RequestBuilder,
-    model_category: &ModelCategory,
-) -> Result<ChatCompletionsOutput> {
+impl_client_trait!(
+    BedrockClient,
+    chat_completions,
+    chat_completions_streaming,
+    embeddings
+);
+
+async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
     let res = builder.send().await?;
     let status = res.status();
     let data: Value = res.json().await?;
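
The embeddings route added above deliberately keeps the legacy invoke endpoint: the Cohere embed models are not exposed through Converse. A minimal sketch of the body embeddings_builder sends (the two fields are from the code above; the texts are invented):

use serde_json::json;

fn main() {
    // "texts" and "input_type" are the fields embeddings_builder assembles;
    // "search_query" is sent when data.query is true, "search_document" otherwise.
    let body = json!({
        "texts": ["first chunk of a document", "second chunk of a document"],
        "input_type": "search_document",
    });
    println!("{body}");
}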
@@ -137,17 +156,12 @@ async fn chat_completions(
     }

     debug!("non-stream-data: {data}");
-    match model_category {
-        ModelCategory::Anthropic => claude_extract_chat_completions(&data),
-        ModelCategory::MetaLlama3 => llama_extract_chat_completions(&data),
-        ModelCategory::Mistral => mistral_extract_chat_completions(&data),
-    }
+    extract_chat_completions(&data)
 }

 async fn chat_completions_streaming(
     builder: RequestBuilder,
     handler: &mut SseHandler,
-    model_category: &ModelCategory,
 ) -> Result<()> {
     let res = builder.send().await?;
     let status = res.status();
@@ -156,6 +170,11 @@ async fn chat_completions_streaming(
         catch_error(&data, status.as_u16())?;
         bail!("Invalid response data: {data}");
     }
+
+    let mut function_name = String::new();
+    let mut function_arguments = String::new();
+    let mut function_id = String::new();
+
     let mut stream = res.bytes_stream();
     let mut buffer = BytesMut::new();
     let mut decoder = MessageFrameDecoder::new();
@@ -167,31 +186,53 @@ async fn chat_completions_streaming(
             let message_type = response_headers.message_type.as_str();
             let smithy_type = response_headers.smithy_type.as_str();
             match (message_type, smithy_type) {
-                ("event", "chunk") => {
-                    let data: Value = decode_chunk(message.payload()).ok_or_else(|| {
-                        anyhow!("Invalid chunk data: {}", hex_encode(message.payload()))
-                    })?;
-                    debug!("stream-data: {data}");
-                    match model_category {
-                        ModelCategory::Anthropic => {
-                            if let Some(typ) = data["type"].as_str() {
-                                if typ == "content_block_delta" {
-                                    if let Some(text) = data["delta"]["text"].as_str() {
-                                        handler.text(text)?;
-                                    }
-                                }
-                            }
-                        }
-                        ModelCategory::MetaLlama3 => {
-                            if let Some(text) = data["generation"].as_str() {
-                                handler.text(text)?;
-                            }
-                        }
-                        ModelCategory::Mistral => {
-                            if let Some(text) = data["outputs"][0]["text"].as_str() {
-                                handler.text(text)?;
-                            }
-                        }
-                    }
-                }
+                ("event", _) => {
+                    let data: Value = serde_json::from_slice(message.payload())?;
+                    debug!("stream-data: {smithy_type} {data}");
+                    match smithy_type {
+                        "contentBlockStart" => {
+                            if let Some(tool_use) = data["start"]["toolUse"].as_object() {
+                                if let (Some(id), Some(name)) = (
+                                    json_str_from_map(tool_use, "toolUseId"),
+                                    json_str_from_map(tool_use, "name"),
+                                ) {
+                                    if !function_name.is_empty() {
+                                        let arguments: Value =
+                                            function_arguments.parse().with_context(|| {
+                                                format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format")
+                                            })?;
+                                        handler.tool_call(ToolCall::new(
+                                            function_name.clone(),
+                                            arguments,
+                                            Some(function_id.clone()),
+                                        ))?;
+                                    }
+                                    function_arguments.clear();
+                                    function_name = name.into();
+                                    function_id = id.into();
+                                }
+                            }
+                        }
+                        "contentBlockDelta" => {
+                            if let Some(text) = data["delta"]["text"].as_str() {
+                                handler.text(text)?;
+                            } else if let Some(input) = data["delta"]["toolUse"]["input"].as_str() {
+                                function_arguments.push_str(input);
+                            }
+                        }
+                        "contentBlockStop" => {
+                            if !function_name.is_empty() {
+                                let arguments: Value = function_arguments.parse().with_context(|| {
+                                    format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format")
+                                })?;
+                                handler.tool_call(ToolCall::new(
+                                    function_name.clone(),
+                                    arguments,
+                                    Some(function_id.clone()),
+                                ))?;
+                            }
+                        }
+                        _ => {}
+                    }
+                }
                 ("exception", _) => {
@@ -209,124 +250,214 @@ async fn chat_completions_streaming(
     Ok(())
 }

-fn build_chat_completions_body(
-    data: ChatCompletionsData,
-    model: &Model,
-    model_category: &ModelCategory,
-) -> Result<Value> {
-    match model_category {
-        ModelCategory::Anthropic => {
-            let mut body = claude_build_chat_completions_body(data, model)?;
-            if let Some(body_obj) = body.as_object_mut() {
-                body_obj.remove("model");
-                body_obj.remove("stream");
-            }
-            body["anthropic_version"] = "bedrock-2023-05-31".into();
-            Ok(body)
-        }
-        ModelCategory::MetaLlama3 => {
-            meta_llama_build_chat_completions_body(data, model, LLAMA3_PROMPT_FORMAT)
-        }
-        ModelCategory::Mistral => mistral_build_chat_completions_body(data, model),
-    }
-}
-
-fn meta_llama_build_chat_completions_body(
-    data: ChatCompletionsData,
-    model: &Model,
-    pt: PromptFormat,
-) -> Result<Value> {
-    let ChatCompletionsData {
-        messages,
-        temperature,
-        top_p,
-        functions: _,
-        stream: _,
-    } = data;
-    let prompt = generate_prompt(&messages, pt)?;
-    let mut body = json!({ "prompt": prompt });
-    if let Some(v) = model.max_tokens_param() {
-        body["max_gen_len"] = v.into();
-    }
-    if let Some(v) = temperature {
-        body["temperature"] = v.into();
-    }
-    if let Some(v) = top_p {
-        body["top_p"] = v.into();
-    }
-    Ok(body)
-}
-
-fn mistral_build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
-    let ChatCompletionsData {
-        messages,
-        temperature,
-        top_p,
-        functions: _,
-        stream: _,
-    } = data;
-    let prompt = generate_prompt(&messages, MISTRAL_PROMPT_FORMAT)?;
-    let mut body = json!({ "prompt": prompt });
-    if let Some(v) = model.max_tokens_param() {
-        body["max_tokens"] = v.into();
-    }
-    if let Some(v) = temperature {
-        body["temperature"] = v.into();
-    }
-    if let Some(v) = top_p {
-        body["top_p"] = v.into();
-    }
-    Ok(body)
-}
-
-fn llama_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
-    let text = data["generation"]
-        .as_str()
-        .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
-    let output = ChatCompletionsOutput {
-        text: text.to_string(),
-        tool_calls: vec![],
-        id: None,
-        input_tokens: data["prompt_token_count"].as_u64(),
-        output_tokens: data["generation_token_count"].as_u64(),
-    };
-    Ok(output)
-}
-
-fn mistral_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
-    let text = data["outputs"][0]["text"]
-        .as_str()
-        .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
-    Ok(ChatCompletionsOutput::new(text))
-}
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-enum ModelCategory {
-    Anthropic,
-    MetaLlama3,
-    Mistral,
-}
-
-impl FromStr for ModelCategory {
-    type Err = anyhow::Error;
-    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
-        if s.starts_with("anthropic.") {
-            Ok(ModelCategory::Anthropic)
-        } else if s.starts_with("meta.llama3") {
-            Ok(ModelCategory::MetaLlama3)
-        } else if s.starts_with("mistral") {
-            Ok(ModelCategory::Mistral)
-        } else {
-            unsupported_model!(s)
-        }
-    }
-}
+async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
+    let res = builder.send().await?;
+    let status = res.status();
+    let data: Value = res.json().await?;
+    if !status.is_success() {
+        catch_error(&data, status.as_u16())?;
+    }
+    let res_body: EmbeddingsResBody =
+        serde_json::from_value(data).context("Invalid embeddings data")?;
+    Ok(res_body.embeddings)
+}
+
+#[derive(Deserialize)]
+struct EmbeddingsResBody {
+    embeddings: Vec<Vec<f32>>,
+}
+
+fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
+    let ChatCompletionsData {
+        mut messages,
+        temperature,
+        top_p,
+        functions,
+        stream: _,
+    } = data;
+
+    let system_message = extract_system_message(&mut messages);
+
+    let mut network_image_urls = vec![];
+
+    let messages: Vec<Value> = messages
+        .into_iter()
+        .flat_map(|message| {
+            let Message { role, content } = message;
+            match content {
+                MessageContent::Text(text) => vec![json!({
+                    "role": role,
+                    "content": [
+                        {
+                            "text": text,
+                        }
+                    ],
+                })],
+                MessageContent::Array(list) => {
+                    let content: Vec<_> = list
+                        .into_iter()
+                        .map(|item| match item {
+                            MessageContentPart::Text { text } => {
+                                json!({"text": text})
+                            }
+                            MessageContentPart::ImageUrl {
+                                image_url: ImageUrl { url },
+                            } => {
+                                if let Some((mime_type, data)) = url
+                                    .strip_prefix("data:")
+                                    .and_then(|v| v.split_once(";base64,"))
+                                {
+                                    json!({
+                                        "image": {
+                                            "format": mime_type.replace("image/", ""),
+                                            "source": {
+                                                "bytes": data,
+                                            }
+                                        }
+                                    })
+                                } else {
+                                    network_image_urls.push(url.clone());
+                                    json!({ "url": url })
+                                }
+                            }
+                        })
+                        .collect();
+                    vec![json!({
+                        "role": role,
+                        "content": content,
+                    })]
+                }
+                MessageContent::ToolResults((tool_results, text)) => {
+                    let mut assistant_parts = vec![];
+                    let mut user_parts = vec![];
+                    if !text.is_empty() {
+                        assistant_parts.push(json!({
+                            "text": text,
+                        }))
+                    }
+                    for tool_result in tool_results {
+                        assistant_parts.push(json!({
+                            "toolUse": {
+                                "toolUseId": tool_result.call.id,
+                                "name": tool_result.call.name,
+                                "input": tool_result.call.arguments,
+                            }
+                        }));
+                        user_parts.push(json!({
+                            "toolResult": {
+                                "toolUseId": tool_result.call.id,
+                                "content": [
+                                    {
+                                        "json": tool_result.output,
+                                    }
+                                ]
+                            }
+                        }));
+                    }
+                    vec![
+                        json!({
+                            "role": "assistant",
+                            "content": assistant_parts,
+                        }),
+                        json!({
+                            "role": "user",
+                            "content": user_parts,
+                        }),
+                    ]
+                }
+            }
+        })
+        .collect();
+
+    if !network_image_urls.is_empty() {
+        bail!(
+            "The model does not support network images: {:?}",
+            network_image_urls
+        );
+    }
+
+    let mut body = json!({
+        "inferenceConfig": {},
+        "messages": messages,
+    });
+
+    if let Some(v) = system_message {
+        body["system"] = json!([
+            {
+                "text": v,
+            }
+        ])
+    }
+
+    if let Some(v) = model.max_tokens_param() {
+        body["inferenceConfig"]["maxTokens"] = v.into();
+    }
+    if let Some(v) = temperature {
+        body["inferenceConfig"]["temperature"] = v.into();
+    }
+    if let Some(v) = top_p {
+        body["inferenceConfig"]["topP"] = v.into();
+    }
+    if let Some(functions) = functions {
+        let tools: Vec<_> = functions
+            .iter()
+            .map(|v| {
+                json!({
+                    "toolSpec": {
+                        "name": v.name,
+                        "description": v.description,
+                        "inputSchema": {
+                            "json": v.parameters,
+                        },
+                    }
+                })
+            })
+            .collect();
+        body["toolConfig"] = json!({
+            "tools": tools,
+        })
+    }
+    Ok(body)
+}
+
+fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
+    let mut texts = vec![];
+    let mut tool_calls = vec![];
+    if let Some(array) = data["output"]["message"]["content"].as_array() {
+        for item in array {
+            if let Some(text) = item["text"].as_str() {
+                texts.push(text);
+            } else if let Some(tool_use) = item["toolUse"].as_object() {
+                if let (Some(id), Some(name), Some(input)) = (
+                    json_str_from_map(tool_use, "toolUseId"),
+                    json_str_from_map(tool_use, "name"),
+                    tool_use.get("input"),
+                ) {
+                    tool_calls.push(ToolCall::new(
+                        name.to_string(),
+                        input.clone(),
+                        Some(id.to_string()),
+                    ))
+                }
+            }
+        }
+    }
+    if texts.is_empty() && tool_calls.is_empty() {
+        bail!("Invalid response data: {data}");
+    }
+    let output = ChatCompletionsOutput {
+        text: texts.join("\n\n"),
+        tool_calls,
+        id: None,
+        input_tokens: data["usage"]["inputTokens"].as_u64(),
+        output_tokens: data["usage"]["outputTokens"].as_u64(),
+    };
+    Ok(output)
+}

 #[derive(Debug)]
 struct AwsCredentials {
     access_key_id: String,
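
For reference, extract_chat_completions above expects a non-streaming Converse response of roughly this shape: it joins the text blocks, turns each toolUse block into a ToolCall, and reads token counts from usage. A sketch with invented values (the paths are the ones the code reads):

use serde_json::json;

fn main() {
    // Illustrative Converse response; extract_chat_completions walks
    // "output.message.content" and "usage.inputTokens"/"usage.outputTokens".
    let data = json!({
        "output": {
            "message": {
                "role": "assistant",
                "content": [
                    { "text": "Let me check the weather." },
                    { "toolUse": { "toolUseId": "tooluse_1", "name": "get_weather", "input": { "city": "Paris" } } }
                ]
            }
        },
        "usage": { "inputTokens": 25, "outputTokens": 18 }
    });
    assert_eq!(data["usage"]["inputTokens"].as_u64(), Some(25));
}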
@@ -439,10 +570,3 @@ fn gen_signing_key(key: &str, date_stamp: &str, region: &str, service: &str) ->
     let k_service = hmac_sha256(&k_region, service);
     hmac_sha256(&k_service, "aws4_request")
 }
-
-fn decode_chunk(data: &[u8]) -> Option<Value> {
-    let data = serde_json::from_slice::<Value>(data).ok()?;
-    let data = data["bytes"].as_str()?;
-    let data = base64_decode(data).ok()?;
-    serde_json::from_slice(&data).ok()
-}
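
Unrelated to the Converse switch, but visible in the last hunk's context: request signing still uses the standard AWS Signature V4 key-derivation chain (gen_signing_key). A self-contained sketch with the hmac and sha2 crates (the crate choice is an assumption; the client itself uses its own hmac_sha256 helper from crate::utils):

use hmac::{Hmac, Mac};
use sha2::Sha256;

// SigV4 derivation chain: kDate -> kRegion -> kService -> kSigning.
fn hmac_sha256(key: &[u8], data: &str) -> Vec<u8> {
    let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC accepts any key length");
    mac.update(data.as_bytes());
    mac.finalize().into_bytes().to_vec()
}

fn gen_signing_key(secret_key: &str, date_stamp: &str, region: &str, service: &str) -> Vec<u8> {
    let k_date = hmac_sha256(format!("AWS4{secret_key}").as_bytes(), date_stamp);
    let k_region = hmac_sha256(&k_date, region);
    let k_service = hmac_sha256(&k_region, service);
    hmac_sha256(&k_service, "aws4_request")
}

fn main() {
    let key = gen_signing_key("example-secret", "20240801", "us-east-1", "bedrock");
    assert_eq!(key.len(), 32); // HMAC-SHA256 output
}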
