refactor: improve system message handling (#634)

pull/635/head
sigoden 2 weeks ago committed by GitHub
parent 3826d808d8
commit 52a847743e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -73,7 +73,7 @@ rag_template: |
Given the context information, answer the query.
Query: __INPUT__
# ---- clients ----
# ---- client ----
clients:
# All clients have the following configuration:
# - type: xxxx

@ -507,16 +507,15 @@
input_price: 1.68
output_price: 1.68
supports_function_calling: true
- name: ernie-3.5-128k
max_input_tokens: 8192
input_price: 6.72
output_price: 13.44
supports_function_calling: true
- name: ernie-speed-128k
max_input_tokens: 128000
input_price: 0
output_price: 0
- name: ernie-lite-8k
max_input_tokens: 8192
max_output_tokens: 2048
require_max_tokens: true
input_price: 0
output_price: 0
- name: embedding-v1
type: embedding
max_input_tokens: 384

@ -238,7 +238,7 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu
stream,
} = data;
patch_system_message(&mut messages);
let system_message = extract_system_message(&mut messages);
let messages: Vec<Value> = messages
.into_iter()
@ -269,6 +269,10 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu
"messages": messages,
});
if let Some(v) = system_message {
body["system"] = v.into();
}
if let Some(v) = model.max_tokens_param() {
body["max_output_tokens"] = v.into();
}

@ -21,6 +21,30 @@ impl Message {
pub fn new(role: MessageRole, content: MessageContent) -> Self {
Self { role, content }
}
pub fn merge_system(&mut self, system: &str) {
match &mut self.content {
MessageContent::Text(text) => {
self.content = MessageContent::Array(vec![
MessageContentPart::Text {
text: system.to_string(),
},
MessageContentPart::Text {
text: text.to_string(),
},
]);
}
MessageContent::Array(list) => {
list.insert(
0,
MessageContentPart::Text {
text: system.to_string(),
},
);
}
_ => {}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
@ -124,12 +148,10 @@ pub struct ImageUrl {
pub fn patch_system_message(messages: &mut Vec<Message>) {
if messages[0].role.is_system() {
let system_message = messages.remove(0);
if let (Some(message), MessageContent::Text(system_text)) =
if let (Some(message), MessageContent::Text(system)) =
(messages.get_mut(0), system_message.content)
{
if let MessageContent::Text(text) = message.content.clone() {
message.content = MessageContent::Text(format!("{}\n\n{}", system_text, text))
}
message.merge_system(&system);
}
}
}

@ -262,7 +262,12 @@ pub fn gemini_build_chat_completions_body(
stream: _,
} = data;
let system_message = if model.name().starts_with("gemini-1.5-") {
extract_system_message(&mut messages)
} else {
patch_system_message(&mut messages);
None
};
let mut network_image_urls = vec![];
let contents: Vec<Value> = messages
@ -333,6 +338,10 @@ pub fn gemini_build_chat_completions_body(
let mut body = json!({ "contents": contents, "generationConfig": {} });
if let Some(v) = system_message {
body["systemInstruction"] = json!({ "parts": [{"text": v }] });
}
if let Some(v) = model.max_tokens_param() {
body["generationConfig"]["maxOutputTokens"] = v.into();
}

Loading…
Cancel
Save