bedrock client supports cohere models

pull/747/head
sigoden 3 months ago
parent e5a9cd90ca
commit a4f68bd2fe

@ -426,6 +426,30 @@
max_input_tokens: 128000
input_price: 3
output_price: 9
- name: cohere.command-r-plus-v1:0
max_input_tokens: 128000
input_price: 3
output_price: 15
supports_function_calling: true
- name: cohere.command-r-v1:0
max_input_tokens: 128000
input_price: 0.5
output_price: 1.5
supports_function_calling: true
- name: cohere.embed-english-v3
type: embedding
max_input_tokens: 512
input_price: 0.1
output_vector_size: 1024
default_chunk_size: 1000
max_batch_size: 96
- name: cohere.embed-multilingual-v3
type: embedding
max_input_tokens: 512
input_price: 0.1
output_vector_size: 1024
default_chunk_size: 1000
max_batch_size: 96
- platform: cloudflare
# docs:

@ -3,7 +3,6 @@ use super::*;
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
use aws_smithy_eventstream::smithy::parse_response_headers;
use bytes::BytesMut;
@ -30,30 +29,6 @@ pub struct BedrockConfig {
pub extra: Option<ExtraConfig>,
}
#[async_trait]
impl Client for BedrockClient {
client_common_fns!();
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
let builder = self.chat_completions_builder(client, data)?;
chat_completions(builder).await
}
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: ChatCompletionsData,
) -> Result<()> {
let builder = self.chat_completions_builder(client, data)?;
chat_completions_streaming(builder, handler).await
}
}
impl BedrockClient {
config_get_fn!(access_key_id, get_access_key_id);
config_get_fn!(secret_access_key, get_secret_access_key);
@ -83,6 +58,7 @@ impl BedrockClient {
let access_key_id = self.get_access_key_id()?;
let secret_access_key = self.get_secret_access_key()?;
let region = self.get_region()?;
let host = format!("bedrock-runtime.{region}.amazonaws.com");
let model_name = &self.model.name();
let uri = if data.stream {
@ -90,7 +66,6 @@ impl BedrockClient {
} else {
format!("/model/{model_name}/converse")
};
let host = format!("bedrock-runtime.{region}.amazonaws.com");
let headers = IndexMap::new();
@ -117,8 +92,60 @@ impl BedrockClient {
Ok(builder)
}
fn embeddings_builder(
&self,
client: &ReqwestClient,
data: EmbeddingsData,
) -> Result<RequestBuilder> {
let access_key_id = self.get_access_key_id()?;
let secret_access_key = self.get_secret_access_key()?;
let region = self.get_region()?;
let host = format!("bedrock-runtime.{region}.amazonaws.com");
let uri = format!("/model/{}/invoke", self.model.name());
let headers = IndexMap::new();
let input_type = match data.query {
true => "search_query",
false => "search_document",
};
let body = json!({
"texts": data.texts,
"input_type": input_type,
});
let builder = aws_fetch(
client,
&AwsCredentials {
access_key_id,
secret_access_key,
region,
},
AwsRequest {
method: Method::POST,
host,
service: "bedrock".into(),
uri,
querystring: "".into(),
headers,
body: body.to_string(),
},
)?;
Ok(builder)
}
}
impl_client_trait!(
BedrockClient,
chat_completions,
chat_completions_streaming,
embeddings
);
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
@ -223,6 +250,25 @@ async fn chat_completions_streaming(
Ok(())
}
async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
let res_body: EmbeddingsResBody =
serde_json::from_value(data).context("Invalid embeddings data")?;
Ok(res_body.embeddings)
}
#[derive(Deserialize)]
struct EmbeddingsResBody {
embeddings: Vec<Vec<f32>>,
}
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
let ChatCompletionsData {
mut messages,

Loading…
Cancel
Save