From a4f68bd2fe9bc20b8f9b0fd4ca9093b5c7f07527 Mon Sep 17 00:00:00 2001 From: sigoden Date: Fri, 26 Jul 2024 03:34:48 +0000 Subject: [PATCH] bedrock client supports cohere models --- models.yaml | 24 +++++++++++ src/client/bedrock.rs | 98 +++++++++++++++++++++++++++++++------------ 2 files changed, 96 insertions(+), 26 deletions(-) diff --git a/models.yaml b/models.yaml index 522e4cd..b79fafc 100644 --- a/models.yaml +++ b/models.yaml @@ -426,6 +426,30 @@ max_input_tokens: 128000 input_price: 3 output_price: 9 + - name: cohere.command-r-plus-v1:0 + max_input_tokens: 128000 + input_price: 3 + output_price: 15 + supports_function_calling: true + - name: cohere.command-r-v1:0 + max_input_tokens: 128000 + input_price: 0.5 + output_price: 1.5 + supports_function_calling: true + - name: cohere.embed-english-v3 + type: embedding + max_input_tokens: 512 + input_price: 0.1 + output_vector_size: 1024 + default_chunk_size: 1000 + max_batch_size: 96 + - name: cohere.embed-multilingual-v3 + type: embedding + max_input_tokens: 512 + input_price: 0.1 + output_vector_size: 1024 + default_chunk_size: 1000 + max_batch_size: 96 - platform: cloudflare # docs: diff --git a/src/client/bedrock.rs b/src/client/bedrock.rs index fca76b3..be418ed 100644 --- a/src/client/bedrock.rs +++ b/src/client/bedrock.rs @@ -3,7 +3,6 @@ use super::*; use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256}; use anyhow::{bail, Context, Result}; -use async_trait::async_trait; use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder}; use aws_smithy_eventstream::smithy::parse_response_headers; use bytes::BytesMut; @@ -30,30 +29,6 @@ pub struct BedrockConfig { pub extra: Option, } -#[async_trait] -impl Client for BedrockClient { - client_common_fns!(); - - async fn chat_completions_inner( - &self, - client: &ReqwestClient, - data: ChatCompletionsData, - ) -> Result { - let builder = self.chat_completions_builder(client, data)?; - chat_completions(builder).await - } - - async fn chat_completions_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut SseHandler, - data: ChatCompletionsData, - ) -> Result<()> { - let builder = self.chat_completions_builder(client, data)?; - chat_completions_streaming(builder, handler).await - } -} - impl BedrockClient { config_get_fn!(access_key_id, get_access_key_id); config_get_fn!(secret_access_key, get_secret_access_key); @@ -83,6 +58,7 @@ impl BedrockClient { let access_key_id = self.get_access_key_id()?; let secret_access_key = self.get_secret_access_key()?; let region = self.get_region()?; + let host = format!("bedrock-runtime.{region}.amazonaws.com"); let model_name = &self.model.name(); let uri = if data.stream { @@ -90,7 +66,6 @@ impl BedrockClient { } else { format!("/model/{model_name}/converse") }; - let host = format!("bedrock-runtime.{region}.amazonaws.com"); let headers = IndexMap::new(); @@ -117,8 +92,60 @@ impl BedrockClient { Ok(builder) } + + fn embeddings_builder( + &self, + client: &ReqwestClient, + data: EmbeddingsData, + ) -> Result { + let access_key_id = self.get_access_key_id()?; + let secret_access_key = self.get_secret_access_key()?; + let region = self.get_region()?; + let host = format!("bedrock-runtime.{region}.amazonaws.com"); + + let uri = format!("/model/{}/invoke", self.model.name()); + + let headers = IndexMap::new(); + + let input_type = match data.query { + true => "search_query", + false => "search_document", + }; + + let body = json!({ + "texts": data.texts, + "input_type": input_type, + }); + + let builder = aws_fetch( + client, + &AwsCredentials { + access_key_id, + secret_access_key, + region, + }, + AwsRequest { + method: Method::POST, + host, + service: "bedrock".into(), + uri, + querystring: "".into(), + headers, + body: body.to_string(), + }, + )?; + + Ok(builder) + } } +impl_client_trait!( + BedrockClient, + chat_completions, + chat_completions_streaming, + embeddings +); + async fn chat_completions(builder: RequestBuilder) -> Result { let res = builder.send().await?; let status = res.status(); @@ -223,6 +250,25 @@ async fn chat_completions_streaming( Ok(()) } +async fn embeddings(builder: RequestBuilder) -> Result { + let res = builder.send().await?; + let status = res.status(); + let data: Value = res.json().await?; + + if !status.is_success() { + catch_error(&data, status.as_u16())?; + } + + let res_body: EmbeddingsResBody = + serde_json::from_value(data).context("Invalid embeddings data")?; + Ok(res_body.embeddings) +} + +#[derive(Deserialize)] +struct EmbeddingsResBody { + embeddings: Vec>, +} + fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result { let ChatCompletionsData { mut messages,