feat: support bedrock client (#450)

pull/452/head
sigoden 2 weeks ago committed by GitHub
parent 615bab215b
commit 1f2b626703
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

90
Cargo.lock generated

@ -35,6 +35,7 @@ dependencies = [
"arboard",
"async-recursion",
"async-trait",
"aws-smithy-eventstream",
"base64 0.22.0",
"bincode",
"bitflags 2.5.0",
@ -46,10 +47,12 @@ dependencies = [
"dirs",
"fancy-regex",
"futures-util",
"hmac",
"http",
"http-body-util",
"hyper",
"hyper-util",
"indexmap",
"inquire",
"is-terminal",
"lazy_static",
@ -74,6 +77,7 @@ dependencies = [
"tokio-graceful",
"tokio-stream",
"unicode-width",
"urlencoding",
]
[[package]]
@ -199,6 +203,35 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80"
[[package]]
name = "aws-smithy-eventstream"
version = "0.60.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6363078f927f612b970edf9d1903ef5cef9a64d1e8423525ebb1f0a1633c858"
dependencies = [
"aws-smithy-types",
"bytes",
"crc32fast",
]
[[package]]
name = "aws-smithy-types"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "abe14dceea1e70101d38fbf2a99e6a34159477c0fb95e68e05c66bd7ae4c3729"
dependencies = [
"base64-simd",
"bytes",
"bytes-utils",
"itoa",
"num-integer",
"pin-project-lite",
"pin-utils",
"ryu",
"serde",
"time",
]
[[package]]
name = "backtrace"
version = "0.3.71"
@ -226,6 +259,16 @@ version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51"
[[package]]
name = "base64-simd"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195"
dependencies = [
"outref",
"vsimd",
]
[[package]]
name = "bincode"
version = "1.3.3"
@ -315,6 +358,16 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9"
[[package]]
name = "bytes-utils"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35"
dependencies = [
"bytes",
"either",
]
[[package]]
name = "cc"
version = "1.0.92"
@ -518,6 +571,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
"subtle",
]
[[package]]
@ -850,6 +904,15 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]]
name = "hmac"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
dependencies = [
"digest",
]
[[package]]
name = "home"
version = "0.5.9"
@ -1263,6 +1326,15 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.18"
@ -1379,6 +1451,12 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "outref"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4030760ffd992bef45b0ae3f10ce1aba99e33464c90d14dd7c039884963ddc7a"
[[package]]
name = "overload"
version = "0.1.1"
@ -2450,6 +2528,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "urlencoding"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
[[package]]
name = "utf8parse"
version = "0.2.1"
@ -2468,6 +2552,12 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "vsimd"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64"
[[package]]
name = "vte"
version = "0.11.1"

@ -52,6 +52,10 @@ http-body-util = "0.1"
hyper = { version = "1.0", features = ["full"] }
hyper-util = { version = "0.1", features = ["server-auto", "client-legacy"] }
time = { version = "0.3.36", features = ["macros"] }
indexmap = "2.2.6"
hmac = "0.12.1"
aws-smithy-eventstream = "0.60.4"
urlencoding = "2.1.3"
[dependencies.reqwest]
version = "0.12.0"

@ -116,6 +116,12 @@ clients:
# Optional field, possible values: BLOCK_NONE, BLOCK_ONLY_HIGH, BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE
block_threshold: BLOCK_ONLY_HIGH
# See https://docs.aws.amazon.com/bedrock/latest/userguide/
- type: bedrock
access_key_id: xxxxxxxxxxxxxxxxxxxx
secret_access_key: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
region: xxx
# See https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html
- type: ernie
api_key: xxxxxxxxxxxxxxxxxxxxxxxx

@ -97,7 +97,7 @@
- type: mistral
# docs:
# - https://docs.mistral.ai/platform/endpoints/
# - https://docs.mistral.ai/getting-started/models/
# - https://mistral.ai/technology/#pricing
# - https://docs.mistral.ai/api/
models:
@ -225,33 +225,98 @@
models:
- name: gemini-1.0-pro
max_input_tokens: 24568
max_output_tokens: 24568
max_output_tokens?: 8193
input_price: 0.125
output_price: 0.375
- name: gemini-1.0-pro-vision
max_input_tokens: 14336
max_output_tokens: 14336
max_output_tokens?: 2049
input_price: 0.125
output_price: 0.375
supports_vision: true
- name: gemini-1.5-pro-preview-0409
max_input_tokens: 1000000
max_output_tokens: 1000000
max_output_tokens?: 8193
input_price: 2.5
output_price: 7.5
supports_vision: true
- name: claude-3-opus@20240229
max_input_tokens: 200000
max_output_tokens: 4096
input_price: 15
output_price: 75
supports_vision: true
- name: claude-3-sonnet@20240229
max_input_tokens: 200000
max_output_tokens: 4096
input_price: 3
output_price: 15
supports_vision: true
- name: claude-3-haiku@20240307
max_input_tokens: 200000
max_output_tokens: 4096
input_price: 0.25
output_price: 1.25
supports_vision: true
- type: bedrock
# docs:
# - https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns
# - https://aws.amazon.com/bedrock/pricing/
models:
- name: anthropic.claude-3-opus-20240229-v1:0
max_input_tokens: 200000
max_output_tokens: 4096
input_price: 15
output_price: 75
supports_vision: true
- name: anthropic.claude-3-sonnet-20240229-v1:0
max_input_tokens: 200000
max_output_tokens: 4096
input_price: 3
output_price: 15
supports_vision: true
- name: anthropic.claude-3-haiku-20240307-v1:0
max_input_tokens: 200000
max_output_tokens: 4096
input_price: 0.25
output_price: 1.25
supports_vision: true
- name: meta.llama2-13b-chat-v1
max_input_tokens: 4096
max_output_tokens: 2048
input_price: 0.75
output_price: 1
- name: meta.llama2-70b-chat-v1
max_input_tokens: 4096
max_output_tokens: 2048
input_price: 1.95
output_price: 2.56
- name: meta.llama3-8b-instruct-v1:0
max_input_tokens: 8192
max_output_tokens: 4096
input_price: 0.4
output_price: 0.6
- name: meta.llama3-70b-instruct-v1:0
max_input_tokens: 8192
max_output_tokens: 4096
input_price: 2.65
output_price: 3.5
- name: mistral.mistral-7b-instruct-v0:2
max_input_tokens: 32000
max_output_tokens: 8192
input_price: 0.15
output_price: 0.2
- name: mistral.mixtral-8x7b-instruct-v0:1
max_input_tokens: 32000
max_output_tokens: 4096
input_price: 0.45
output_price: 0.7
- name: mistral.mistral-large-2402-v1:0
max_input_tokens: 32000
max_output_tokens: 8192
input_price: 8
output_price: 24
- type: ernie
# docs:

@ -0,0 +1,443 @@
use super::claude::claude_build_body;
use super::{
catch_error, generate_prompt, BedrockClient, Client, ExtraConfig, Model, ModelConfig,
PromptFormat, PromptType, ReplyHandler, SendData, LLAMA2_PROMPT_FORMAT, LLAMA3_PROMPT_FORMAT,
};
use crate::utils::PromptKind;
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
use aws_smithy_eventstream::smithy::parse_response_headers;
use base64::{engine::general_purpose::STANDARD, Engine};
use bytes::BytesMut;
use chrono::{DateTime, Utc};
use futures_util::StreamExt;
use hmac::{Hmac, Mac};
use indexmap::IndexMap;
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
Client as ReqwestClient, Method, RequestBuilder,
};
use serde::Deserialize;
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::str::FromStr;
/// User-facing configuration for the AWS Bedrock client, deserialized from
/// the `clients` section of the config file.
#[derive(Debug, Clone, Deserialize)]
pub struct BedrockConfig {
    // Optional display name for this client entry.
    pub name: Option<String>,
    // AWS credentials; optional here because they may be supplied
    // interactively via `PROMPTS` (see `BedrockClient::PROMPTS`).
    pub access_key_id: Option<String>,
    pub secret_access_key: Option<String>,
    // AWS region used to build the bedrock-runtime endpoint host.
    pub region: Option<String>,
    // Models exposed by this client; empty list by default.
    #[serde(default)]
    pub models: Vec<ModelConfig>,
    // Extra client options shared across client types.
    pub extra: Option<ExtraConfig>,
}
#[async_trait]
impl Client for BedrockClient {
    // Shared trait plumbing (config/model accessors etc.) expanded by the macro.
    client_common_fns!();

    /// Sends a non-streaming invoke request. The model family (Anthropic /
    /// Llama / Mistral) is parsed from the model name and drives both the
    /// request body shape and the response parsing.
    async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
        let model_category = ModelCategory::from_str(&self.model.name)?;
        let builder = self.request_builder(client, data, &model_category)?;
        send_message(builder, &model_category).await
    }

    /// Streaming variant: same dispatch, but decodes the AWS event-stream
    /// response and forwards text deltas to `handler`.
    async fn send_message_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut ReplyHandler,
        data: SendData,
    ) -> Result<()> {
        let model_category = ModelCategory::from_str(&self.model.name)?;
        let builder = self.request_builder(client, data, &model_category)?;
        send_message_streaming(builder, handler, &model_category).await
    }
}
impl BedrockClient {
    // Generated accessors for required config fields (expanded by config_get_fn!).
    config_get_fn!(access_key_id, get_access_key_id);
    config_get_fn!(secret_access_key, get_secret_access_key);
    config_get_fn!(region, get_region);

    /// Interactive setup prompts for this client; each tuple appears to be
    /// (config key, label, required, input kind).
    pub const PROMPTS: [PromptType<'static>; 3] = [
        (
            "access_key_id",
            "AWS Access Key ID",
            true,
            PromptKind::String,
        ),
        (
            "secret_access_key",
            "AWS Secret Access Key",
            true,
            PromptKind::String,
        ),
        ("region", "AWS Region", true, PromptKind::String),
    ];

    /// Builds a SigV4-signed request against the Bedrock runtime.
    ///
    /// Picks the streaming or non-streaming invoke path from `data.stream`,
    /// renders the model-family-specific JSON body, then hands everything to
    /// `aws_fetch` for signing.
    fn request_builder(
        &self,
        client: &ReqwestClient,
        data: SendData,
        model_category: &ModelCategory,
    ) -> Result<RequestBuilder> {
        let access_key_id = self.get_access_key_id()?;
        let secret_access_key = self.get_secret_access_key()?;
        let region = self.get_region()?;

        let model_name = &self.model.name;
        // Bedrock addresses the model via the URL path, not the payload.
        let uri = if data.stream {
            format!("/model/{model_name}/invoke-with-response-stream")
        } else {
            format!("/model/{model_name}/invoke")
        };
        let host = format!("bedrock-runtime.{region}.amazonaws.com");

        // Extra headers beyond host/x-amz-date/authorization (none today);
        // IndexMap preserves insertion order for signing.
        let headers = IndexMap::new();

        let mut body = build_body(data, &self.model, model_category)?;
        // Merge user-configured extra fields into the request body.
        self.model.merge_extra_fields(&mut body);

        let builder = aws_fetch(
            client,
            &AwsCredentials {
                access_key_id,
                secret_access_key,
                region,
            },
            AwsRequest {
                method: Method::POST,
                host,
                service: "bedrock".into(),
                uri,
                querystring: "".into(),
                headers,
                body: body.to_string(),
            },
        )?;
        Ok(builder)
    }
}
/// Performs a non-streaming Bedrock invoke and extracts the generated text
/// according to the model family's response schema.
async fn send_message(builder: RequestBuilder, model_category: &ModelCategory) -> Result<String> {
    let response = builder.send().await?;
    let status = response.status();
    let data: Value = response.json().await?;
    if status != 200 {
        catch_error(&data, status.as_u16())?;
    }
    // Each family nests the completion text at a different JSON path.
    let text = match model_category {
        ModelCategory::Anthropic => data["content"][0]["text"].as_str(),
        ModelCategory::MetaLlama2 | ModelCategory::MetaLlama3 => data["generation"].as_str(),
        ModelCategory::Mistral => data["outputs"][0]["text"].as_str(),
    }
    .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
    Ok(text.to_string())
}
/// Streams a Bedrock `invoke-with-response-stream` response, decoding the
/// AWS event-stream framing and forwarding generated text to `handler`.
async fn send_message_streaming(
    builder: RequestBuilder,
    handler: &mut ReplyHandler,
    model_category: &ModelCategory,
) -> Result<()> {
    let res = builder.send().await?;
    let status = res.status();
    if status != 200 {
        // Error responses are plain JSON, not event-stream frames.
        let data: Value = res.json().await?;
        catch_error(&data, status.as_u16())?;
        bail!("Invalid response data: {data}");
    }
    let mut stream = res.bytes_stream();
    // Frames may span HTTP chunk boundaries: bytes are accumulated in
    // `buffer` and the decoder is polled until it reports an incomplete frame.
    let mut buffer = BytesMut::new();
    let mut decoder = MessageFrameDecoder::new();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        buffer.extend_from_slice(&chunk);
        while let DecodedFrame::Complete(message) = decoder.decode_frame(&mut buffer)? {
            let response_headers = parse_response_headers(&message)?;
            let message_type = response_headers.message_type.as_str();
            let smithy_type = response_headers.smithy_type.as_str();
            match (message_type, smithy_type) {
                ("event", "chunk") => {
                    // Payload is JSON carrying a base64 "bytes" field that in
                    // turn holds the model's JSON delta (see decode_chunk).
                    let data: Value = decode_chunk(message.payload()).ok_or_else(|| {
                        anyhow!("Invalid chunk data: {}", hex_encode(message.payload()))
                    })?;
                    debug!("bedrock chunk: {data}");
                    // Each model family embeds the text delta differently.
                    match model_category {
                        ModelCategory::Anthropic => {
                            if let Some(typ) = data["type"].as_str() {
                                if typ == "content_block_delta" {
                                    if let Some(text) = data["delta"]["text"].as_str() {
                                        handler.text(text)?;
                                    }
                                }
                            }
                        }
                        ModelCategory::MetaLlama2 | ModelCategory::MetaLlama3 => {
                            if let Some(text) = data["generation"].as_str() {
                                handler.text(text)?;
                            }
                        }
                        ModelCategory::Mistral => {
                            if let Some(text) = data["outputs"][0]["text"].as_str() {
                                handler.text(text)?;
                            }
                        }
                    }
                }
                ("exception", _) => {
                    // Exception payloads are base64-encoded; surface the
                    // decoded text verbatim in the error.
                    let payload = STANDARD.decode(message.payload())?;
                    let data = String::from_utf8_lossy(&payload);
                    bail!("Invalid response data: {data} (smithy_type: {smithy_type})")
                }
                _ => {
                    bail!("Unrecognized message, message_type: {message_type}, smithy_type: {smithy_type}",);
                }
            }
        }
    }
    Ok(())
}
/// Renders the request payload for the given model family.
fn build_body(data: SendData, model: &Model, model_category: &ModelCategory) -> Result<Value> {
    match model_category {
        ModelCategory::Anthropic => {
            let mut payload = claude_build_body(data, model)?;
            // Bedrock selects the model via the URL path, so the "model"
            // field is dropped and an explicit anthropic_version is required.
            if let Some(map) = payload.as_object_mut() {
                map.remove("model");
            }
            payload["anthropic_version"] = "bedrock-2023-05-31".into();
            Ok(payload)
        }
        ModelCategory::MetaLlama2 => meta_llama_build_body(data, model, LLAMA2_PROMPT_FORMAT),
        ModelCategory::MetaLlama3 => meta_llama_build_body(data, model, LLAMA3_PROMPT_FORMAT),
        ModelCategory::Mistral => mistral_build_body(data, model),
    }
}
/// Builds the payload for Meta Llama models: a single rendered prompt plus
/// optional generation parameters.
fn meta_llama_build_body(data: SendData, model: &Model, pt: PromptFormat) -> Result<Value> {
    let SendData {
        messages,
        temperature,
        top_p,
        stream: _,
    } = data;
    let rendered_prompt = generate_prompt(&messages, pt)?;
    let mut payload = json!({ "prompt": rendered_prompt });
    // Only include parameters the user/model actually configured.
    if let Some(max_tokens) = model.max_output_tokens {
        payload["max_gen_len"] = max_tokens.into();
    }
    if let Some(temp) = temperature {
        payload["temperature"] = temp.into();
    }
    if let Some(p) = top_p {
        payload["top_p"] = p.into();
    }
    Ok(payload)
}
/// Builds the payload for Mistral models. The prompt is rendered with the
/// Llama2 [INST] format (same template these models use on Bedrock), and the
/// max-tokens key is "max_tokens" rather than Llama's "max_gen_len".
fn mistral_build_body(data: SendData, model: &Model) -> Result<Value> {
    let SendData {
        messages,
        temperature,
        top_p,
        stream: _,
    } = data;
    let rendered_prompt = generate_prompt(&messages, LLAMA2_PROMPT_FORMAT)?;
    let mut payload = json!({ "prompt": rendered_prompt });
    if let Some(max_tokens) = model.max_output_tokens {
        payload["max_tokens"] = max_tokens.into();
    }
    if let Some(temp) = temperature {
        payload["temperature"] = temp.into();
    }
    if let Some(p) = top_p {
        payload["top_p"] = p.into();
    }
    Ok(payload)
}
/// Bedrock model families; each has its own request payload shape and
/// response/stream-chunk schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelCategory {
    Anthropic,
    MetaLlama2,
    MetaLlama3,
    Mistral,
}
impl FromStr for ModelCategory {
    type Err = anyhow::Error;

    /// Classifies a Bedrock model id (e.g. "anthropic.claude-3-…",
    /// "meta.llama3-…") into its family by vendor prefix.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let prefix_table = [
            ("anthropic.", ModelCategory::Anthropic),
            ("meta.llama2", ModelCategory::MetaLlama2),
            ("meta.llama3", ModelCategory::MetaLlama3),
            ("mistral", ModelCategory::Mistral),
        ];
        match prefix_table
            .into_iter()
            .find(|(prefix, _)| s.starts_with(prefix))
        {
            Some((_, category)) => Ok(category),
            None => unsupported_model!(s),
        }
    }
}
/// AWS credentials and region used for SigV4 signing.
/// NOTE(review): Debug is derived, so avoid logging this struct — it would
/// expose the secret access key.
#[derive(Debug)]
struct AwsCredentials {
    access_key_id: String,
    secret_access_key: String,
    region: String,
}
/// The pieces of an HTTP request that participate in SigV4 signing.
#[derive(Debug)]
struct AwsRequest {
    method: Method,
    // Fully-qualified endpoint host, e.g. "bedrock-runtime.<region>.amazonaws.com".
    host: String,
    // AWS service name for the credential scope (here: "bedrock").
    service: String,
    // Request path; percent-encoded per segment during canonicalization.
    uri: String,
    // Canonical query string ("" when there are no query parameters).
    querystring: String,
    // Headers to sign; IndexMap so iteration follows insertion order.
    headers: IndexMap<String, String>,
    body: String,
}
/// Signs `request` with AWS Signature Version 4 and returns a ready-to-send
/// reqwest builder.
///
/// Implements the standard SigV4 flow: canonical request -> string-to-sign
/// -> derived signing key -> Authorization header.
///
/// NOTE(review): SigV4 requires canonical/signed headers ordered by lowercase
/// header name. This function signs headers in IndexMap insertion order;
/// "host" then "x-amz-date" happens to be alphabetical today, but adding a
/// caller-supplied header could silently break the signature — confirm
/// ordering before extending.
fn aws_fetch(
    client: &ReqwestClient,
    credentials: &AwsCredentials,
    request: AwsRequest,
) -> Result<RequestBuilder> {
    let AwsRequest {
        method,
        host,
        service,
        uri,
        querystring,
        mut headers,
        body,
    } = request;

    let region = &credentials.region;
    let endpoint = format!("https://{}{}", host, uri);

    // SigV4 timestamps: basic ISO-8601 datetime plus the date-only scope part.
    let now: DateTime<Utc> = Utc::now();
    let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
    let date_stamp = amz_date[0..8].to_string();

    headers.insert("host".into(), host.clone());
    headers.insert("x-amz-date".into(), amz_date.clone());

    // Canonical headers: "key:value\n" pairs; signed headers: ";"-joined keys.
    let canonical_headers = headers
        .iter()
        .map(|(key, value)| format!("{}:{}\n", key, value))
        .collect::<Vec<_>>()
        .join("");
    let signed_headers = headers
        .iter()
        .map(|(key, _)| key.as_str())
        .collect::<Vec<_>>()
        .join(";");

    let payload_hash = sha256(&body);
    let canonical_request = format!(
        "{}\n{}\n{}\n{}\n{}\n{}",
        method,
        encode_uri(&uri),
        querystring,
        canonical_headers,
        signed_headers,
        payload_hash
    );

    let algorithm = "AWS4-HMAC-SHA256";
    let credential_scope = format!("{}/{}/{}/aws4_request", date_stamp, region, service);
    let string_to_sign = format!(
        "{}\n{}\n{}\n{}",
        algorithm,
        amz_date,
        credential_scope,
        sha256(&canonical_request)
    );

    // Derive the per-day/region/service signing key, then sign.
    let signing_key = gen_signing_key(
        &credentials.secret_access_key,
        &date_stamp,
        region,
        &service,
    );
    let signature = sign(&signing_key, &string_to_sign);
    let signature = hex_encode(&signature);

    let authorization_header = format!(
        "{} Credential={}/{}, SignedHeaders={}, Signature={}",
        algorithm, credentials.access_key_id, credential_scope, signed_headers, signature
    );
    headers.insert("authorization".into(), authorization_header);

    let mut req_headers = HeaderMap::new();
    for (k, v) in &headers {
        req_headers.insert(HeaderName::from_str(k)?, HeaderValue::from_str(v)?);
    }

    debug!("Bedrock Request: {endpoint} {body}");

    // (fixed local name: was `requst_builder`)
    let request_builder = client
        .request(method, endpoint)
        .headers(req_headers)
        .body(body);
    Ok(request_builder)
}
/// Returns the lowercase hex SHA-256 digest of `data`.
fn sha256(data: &str) -> String {
    let digest = Sha256::digest(data.as_bytes());
    format!("{:x}", digest)
}
/// Computes HMAC-SHA256 of `msg` under `key` and returns the raw tag bytes.
fn sign(key: &[u8], msg: &str) -> Vec<u8> {
    let mut hmac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC can take key of any size");
    hmac.update(msg.as_bytes());
    let tag = hmac.finalize();
    tag.into_bytes().to_vec()
}
fn gen_signing_key(key: &str, date_stamp: &str, region: &str, service: &str) -> Vec<u8> {
let k_date = sign(format!("AWS4{}", key).as_bytes(), date_stamp);
let k_region = sign(&k_date, region);
let k_service = sign(&k_region, service);
sign(&k_service, "aws4_request")
}
/// Encodes a byte slice as a lowercase hexadecimal string.
///
/// Writes into one pre-sized `String` via `fmt::Write` instead of the
/// original `fold` that built a fresh `String` (via `format!` + `+`) for
/// every byte — same output, O(n) allocations removed.
fn hex_encode(bytes: &[u8]) -> String {
    use std::fmt::Write;
    let mut out = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        // Writing into a String cannot fail.
        let _ = write!(out, "{:02x}", byte);
    }
    out
}
/// Percent-encodes each path segment of `uri` while preserving the `/`
/// separators, as required for the SigV4 canonical URI.
/// (Passes `urlencoding::encode` directly instead of a redundant closure.)
fn encode_uri(uri: &str) -> String {
    uri.split('/')
        .map(urlencoding::encode)
        .collect::<Vec<_>>()
        .join("/")
}
/// Decodes an event-stream "chunk" payload: JSON envelope with a base64
/// "bytes" field whose decoded contents are themselves JSON.
/// Returns `None` on any parse/decode failure.
fn decode_chunk(data: &[u8]) -> Option<Value> {
    let envelope: Value = serde_json::from_slice(data).ok()?;
    let encoded = envelope["bytes"].as_str()?;
    let decoded = STANDARD.decode(encoded).ok()?;
    serde_json::from_slice(&decoded).ok()
}

@ -2,11 +2,13 @@
mod common;
mod message;
mod model;
mod prompt_format;
mod reply_handler;
pub use common::*;
pub use message::*;
pub use model::*;
pub use prompt_format::*;
pub use reply_handler::*;
register_client!(
@ -31,6 +33,7 @@ register_client!(
AzureOpenAIClient
),
(vertexai, "vertexai", VertexAIConfig, VertexAIClient),
(bedrock, "bedrock", BedrockConfig, BedrockClient),
(ernie, "ernie", ErnieConfig, ErnieClient),
(qianwen, "qianwen", QianwenConfig, QianwenClient),
(moonshot, "moonshot", MoonshotConfig, MoonshotClient),

@ -0,0 +1,81 @@
use super::message::*;
/// Template tokens for rendering a chat transcript into a single text prompt
/// for completion-style models. Each role's content is wrapped as
/// `pre_message` + content + `post_message`, after an initial `bos_token`.
pub struct PromptFormat<'a> {
    pub bos_token: &'a str,
    pub system_pre_message: &'a str,
    pub system_post_message: &'a str,
    pub user_pre_message: &'a str,
    pub user_post_message: &'a str,
    pub assistant_pre_message: &'a str,
    pub assistant_post_message: &'a str,
}
/// Llama 2 chat template ([INST]/<<SYS>> style).
pub const LLAMA2_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    bos_token: "<s>",
    system_pre_message: "[INST] <<SYS>>",
    system_post_message: "<</SYS>> [/INST]",
    user_pre_message: "[INST]",
    user_post_message: "[/INST]",
    assistant_pre_message: "",
    assistant_post_message: "</s>",
};
/// Llama 3 chat template (header-id / eot-id style).
pub const LLAMA3_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    bos_token: "<|begin_of_text|>",
    system_pre_message: "<|start_header_id|>system<|end_header_id|>\n\n",
    system_post_message: "<|eot_id|>",
    user_pre_message: "<|start_header_id|>user<|end_header_id|>\n\n",
    user_post_message: "<|eot_id|>",
    assistant_pre_message: "<|start_header_id|>assistant<|end_header_id|>\n\n",
    assistant_post_message: "<|eot_id|>",
};
/// Renders a chat transcript into a single text prompt using `format`,
/// appending the assistant pre-message at the end to cue generation.
///
/// Errors if any message carries an image part, since completion-style
/// prompts cannot represent images.
pub fn generate_prompt(messages: &[Message], format: PromptFormat) -> anyhow::Result<String> {
    let PromptFormat {
        bos_token,
        system_pre_message,
        system_post_message,
        user_pre_message,
        user_post_message,
        assistant_pre_message,
        assistant_post_message,
    } = format;

    let mut output = String::from(bos_token);
    let mut rejected_images = vec![];

    for message in messages {
        // Flatten multi-part content to text, collecting any image URLs so
        // they can be reported as unsupported afterwards.
        let content_text = match &message.content {
            MessageContent::Text(value) => value.clone(),
            MessageContent::Array(items) => {
                let mut chunks = vec![];
                for part in items {
                    match part {
                        MessageContentPart::Text { text } => chunks.push(text.clone()),
                        MessageContentPart::ImageUrl {
                            image_url: ImageUrl { url },
                        } => rejected_images.push(url.clone()),
                    }
                }
                chunks.join("\n\n")
            }
        };
        let (pre, post) = match &message.role {
            MessageRole::System => (system_pre_message, system_post_message),
            MessageRole::Assistant => (assistant_pre_message, assistant_post_message),
            MessageRole::User => (user_pre_message, user_post_message),
        };
        output.push_str(pre);
        output.push_str(&content_text);
        output.push_str(post);
    }

    if !rejected_images.is_empty() {
        anyhow::bail!("The model does not support images: {:?}", rejected_images);
    }

    // Open the assistant turn so the model continues from here.
    output.push_str(assistant_pre_message);
    Ok(output)
}
Loading…
Cancel
Save