|
|
@ -3,7 +3,6 @@ use super::*;
|
|
|
|
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
|
|
|
|
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
|
|
|
|
|
|
|
|
|
|
|
|
use anyhow::{bail, Context, Result};
|
|
|
|
use anyhow::{bail, Context, Result};
|
|
|
|
use async_trait::async_trait;
|
|
|
|
|
|
|
|
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
|
|
|
|
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
|
|
|
|
use aws_smithy_eventstream::smithy::parse_response_headers;
|
|
|
|
use aws_smithy_eventstream::smithy::parse_response_headers;
|
|
|
|
use bytes::BytesMut;
|
|
|
|
use bytes::BytesMut;
|
|
|
@ -30,30 +29,6 @@ pub struct BedrockConfig {
|
|
|
|
pub extra: Option<ExtraConfig>,
|
|
|
|
pub extra: Option<ExtraConfig>,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[async_trait]
|
|
|
|
|
|
|
|
impl Client for BedrockClient {
|
|
|
|
|
|
|
|
client_common_fns!();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async fn chat_completions_inner(
|
|
|
|
|
|
|
|
&self,
|
|
|
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
|
|
|
|
data: ChatCompletionsData,
|
|
|
|
|
|
|
|
) -> Result<ChatCompletionsOutput> {
|
|
|
|
|
|
|
|
let builder = self.chat_completions_builder(client, data)?;
|
|
|
|
|
|
|
|
chat_completions(builder).await
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async fn chat_completions_streaming_inner(
|
|
|
|
|
|
|
|
&self,
|
|
|
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
|
|
|
|
handler: &mut SseHandler,
|
|
|
|
|
|
|
|
data: ChatCompletionsData,
|
|
|
|
|
|
|
|
) -> Result<()> {
|
|
|
|
|
|
|
|
let builder = self.chat_completions_builder(client, data)?;
|
|
|
|
|
|
|
|
chat_completions_streaming(builder, handler).await
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
impl BedrockClient {
|
|
|
|
impl BedrockClient {
|
|
|
|
config_get_fn!(access_key_id, get_access_key_id);
|
|
|
|
config_get_fn!(access_key_id, get_access_key_id);
|
|
|
|
config_get_fn!(secret_access_key, get_secret_access_key);
|
|
|
|
config_get_fn!(secret_access_key, get_secret_access_key);
|
|
|
@ -83,6 +58,7 @@ impl BedrockClient {
|
|
|
|
let access_key_id = self.get_access_key_id()?;
|
|
|
|
let access_key_id = self.get_access_key_id()?;
|
|
|
|
let secret_access_key = self.get_secret_access_key()?;
|
|
|
|
let secret_access_key = self.get_secret_access_key()?;
|
|
|
|
let region = self.get_region()?;
|
|
|
|
let region = self.get_region()?;
|
|
|
|
|
|
|
|
let host = format!("bedrock-runtime.{region}.amazonaws.com");
|
|
|
|
|
|
|
|
|
|
|
|
let model_name = &self.model.name();
|
|
|
|
let model_name = &self.model.name();
|
|
|
|
let uri = if data.stream {
|
|
|
|
let uri = if data.stream {
|
|
|
@ -90,7 +66,6 @@ impl BedrockClient {
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
format!("/model/{model_name}/converse")
|
|
|
|
format!("/model/{model_name}/converse")
|
|
|
|
};
|
|
|
|
};
|
|
|
|
let host = format!("bedrock-runtime.{region}.amazonaws.com");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let headers = IndexMap::new();
|
|
|
|
let headers = IndexMap::new();
|
|
|
|
|
|
|
|
|
|
|
@ -117,8 +92,60 @@ impl BedrockClient {
|
|
|
|
|
|
|
|
|
|
|
|
Ok(builder)
|
|
|
|
Ok(builder)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn embeddings_builder(
|
|
|
|
|
|
|
|
&self,
|
|
|
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
|
|
|
|
data: EmbeddingsData,
|
|
|
|
|
|
|
|
) -> Result<RequestBuilder> {
|
|
|
|
|
|
|
|
let access_key_id = self.get_access_key_id()?;
|
|
|
|
|
|
|
|
let secret_access_key = self.get_secret_access_key()?;
|
|
|
|
|
|
|
|
let region = self.get_region()?;
|
|
|
|
|
|
|
|
let host = format!("bedrock-runtime.{region}.amazonaws.com");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let uri = format!("/model/{}/invoke", self.model.name());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let headers = IndexMap::new();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let input_type = match data.query {
|
|
|
|
|
|
|
|
true => "search_query",
|
|
|
|
|
|
|
|
false => "search_document",
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let body = json!({
|
|
|
|
|
|
|
|
"texts": data.texts,
|
|
|
|
|
|
|
|
"input_type": input_type,
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let builder = aws_fetch(
|
|
|
|
|
|
|
|
client,
|
|
|
|
|
|
|
|
&AwsCredentials {
|
|
|
|
|
|
|
|
access_key_id,
|
|
|
|
|
|
|
|
secret_access_key,
|
|
|
|
|
|
|
|
region,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
AwsRequest {
|
|
|
|
|
|
|
|
method: Method::POST,
|
|
|
|
|
|
|
|
host,
|
|
|
|
|
|
|
|
service: "bedrock".into(),
|
|
|
|
|
|
|
|
uri,
|
|
|
|
|
|
|
|
querystring: "".into(),
|
|
|
|
|
|
|
|
headers,
|
|
|
|
|
|
|
|
body: body.to_string(),
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ok(builder)
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
impl_client_trait!(
|
|
|
|
|
|
|
|
BedrockClient,
|
|
|
|
|
|
|
|
chat_completions,
|
|
|
|
|
|
|
|
chat_completions_streaming,
|
|
|
|
|
|
|
|
embeddings
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
|
|
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
|
|
|
|
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
|
|
|
|
let res = builder.send().await?;
|
|
|
|
let res = builder.send().await?;
|
|
|
|
let status = res.status();
|
|
|
|
let status = res.status();
|
|
|
@ -223,6 +250,25 @@ async fn chat_completions_streaming(
|
|
|
|
Ok(())
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
|
|
|
|
|
|
|
|
let res = builder.send().await?;
|
|
|
|
|
|
|
|
let status = res.status();
|
|
|
|
|
|
|
|
let data: Value = res.json().await?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if !status.is_success() {
|
|
|
|
|
|
|
|
catch_error(&data, status.as_u16())?;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let res_body: EmbeddingsResBody =
|
|
|
|
|
|
|
|
serde_json::from_value(data).context("Invalid embeddings data")?;
|
|
|
|
|
|
|
|
Ok(res_body.embeddings)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Deserialize)]
|
|
|
|
|
|
|
|
struct EmbeddingsResBody {
|
|
|
|
|
|
|
|
embeddings: Vec<Vec<f32>>,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
|
|
|
|
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
|
|
|
|
let ChatCompletionsData {
|
|
|
|
let ChatCompletionsData {
|
|
|
|
mut messages,
|
|
|
|
mut messages,
|
|
|
|