refactor: improve system message handling (#634)

pull/635/head
sigoden authored 2 weeks ago, committed by GitHub
parent 3826d808d8
commit 52a847743e

@@ -3,7 +3,7 @@ model: openai:gpt-3.5-turbo # Specify the language model to use
temperature: null # Set default temperature parameter
top_p: null # Set default top-p parameter
# ---- appearance ----
highlight: true # Controls syntax highlighting
light_theme: false # Activates a light color theme when true. ENV: AICHAT_LIGHT_THEME
# Custom REPL prompt, see https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt for more details
@@ -12,7 +12,7 @@ left_prompt:
right_prompt:
  '{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}'
# ---- behavior ----
save: true # Indicates whether to persist the message
wrap: no # Controls text wrapping (no, auto, <max-width>)
wrap_code: false # Enables or disables wrapping of code blocks
@@ -61,7 +61,7 @@ rag_min_score_rerank: 0 # Specifies the minimum relevance score for re
rag_template: |
  Use the following context as your learned knowledge, inside <context></context> XML tags.
  <context>
  __CONTEXT__
  </context>

  When answering the user:
@@ -73,7 +73,7 @@ rag_template: |
  Given the context information, answer the query.
  Query: __INPUT__
-# ---- clients ----
+# ---- client ----
clients:
  # All clients have the following configuration:
  # - type: xxxx

@@ -507,16 +507,15 @@
      input_price: 1.68
      output_price: 1.68
      supports_function_calling: true
    - name: ernie-3.5-128k
      max_input_tokens: 8192
      input_price: 6.72
      output_price: 13.44
      supports_function_calling: true
    - name: ernie-speed-128k
      max_input_tokens: 128000
      input_price: 0
      output_price: 0
    - name: ernie-lite-8k
      max_input_tokens: 8192
      max_output_tokens: 2048
      require_max_tokens: true
      input_price: 0
      output_price: 0
    - name: embedding-v1
      type: embedding
      max_input_tokens: 384

@@ -238,7 +238,7 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Value
        stream,
    } = data;
-    patch_system_message(&mut messages);
+    let system_message = extract_system_message(&mut messages);
    let messages: Vec<Value> = messages
        .into_iter()
@@ -269,6 +269,10 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Value
        "messages": messages,
    });
+    if let Some(v) = system_message {
+        body["system"] = v.into();
+    }
    if let Some(v) = model.max_tokens_param() {
        body["max_output_tokens"] = v.into();
    }
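The extract_system_message helper used above is introduced elsewhere in this commit and its definition is not part of this excerpt. Judging from how it is used (its result becomes the top-level system field here, and the text of systemInstruction in the Gemini hunk below), a minimal sketch consistent with that usage might be:

// Hypothetical sketch only; the actual helper is defined elsewhere in this
// commit. It pops a leading system message off the list and returns its text.
fn extract_system_message(messages: &mut Vec<Message>) -> Option<String> {
    if messages[0].role.is_system() {
        let system_message = messages.remove(0);
        if let MessageContent::Text(text) = system_message.content {
            return Some(text);
        }
    }
    None
}

This mirrors the shape of patch_system_message (shown below), except that the system text is handed back to the caller instead of being merged into the next message.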

@@ -21,6 +21,30 @@ impl Message {
    pub fn new(role: MessageRole, content: MessageContent) -> Self {
        Self { role, content }
    }
+
+    pub fn merge_system(&mut self, system: &str) {
+        match &mut self.content {
+            MessageContent::Text(text) => {
+                self.content = MessageContent::Array(vec![
+                    MessageContentPart::Text {
+                        text: system.to_string(),
+                    },
+                    MessageContentPart::Text {
+                        text: text.to_string(),
+                    },
+                ]);
+            }
+            MessageContent::Array(list) => {
+                list.insert(
+                    0,
+                    MessageContentPart::Text {
+                        text: system.to_string(),
+                    },
+                );
+            }
+            _ => {}
+        }
+    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
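For illustration, this is what the new merge_system does to a plain-text message (a hypothetical snippet; it assumes the MessageRole::User variant alongside the constructors shown above):

// Hypothetical usage of the new merge_system method.
let mut message = Message::new(
    MessageRole::User,
    MessageContent::Text("What is Rust?".into()),
);
message.merge_system("You are a concise assistant.");
// message.content is now a MessageContent::Array with the system text as
// the first part and the original user text as the second. Content that is
// already an Array gets the system text inserted at index 0; any other
// content kind is left unchanged.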
@@ -124,12 +148,10 @@ pub struct ImageUrl {
pub fn patch_system_message(messages: &mut Vec<Message>) {
    if messages[0].role.is_system() {
        let system_message = messages.remove(0);
-        if let (Some(message), MessageContent::Text(system_text)) =
+        if let (Some(message), MessageContent::Text(system)) =
            (messages.get_mut(0), system_message.content)
        {
-            if let MessageContent::Text(text) = message.content.clone() {
-                message.content = MessageContent::Text(format!("{}\n\n{}", system_text, text))
-            }
+            message.merge_system(&system);
        }
    }
}
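The observable change is in how a leading system message is folded into the message that follows it: the old code concatenated the two texts with a blank line, while merge_system turns the target into a multi-part content array. A hypothetical before/after, assuming a MessageRole::System variant:

// Hypothetical: one system message followed by one text user message.
let mut messages = vec![
    Message::new(MessageRole::System, MessageContent::Text("Be brief.".into())),
    Message::new(MessageRole::User, MessageContent::Text("Hi".into())),
];
patch_system_message(&mut messages);
// Before this commit: one user message with Text("Be brief.\n\nHi").
// After: one user message whose content is an Array of two text parts.
assert_eq!(messages.len(), 1);

Keeping the system prompt as a separate part also covers messages whose content is already an array of parts (for example, text plus images), a case the old text-only concatenation silently skipped.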

@@ -262,7 +262,12 @@ pub fn gemini_build_chat_completions_body(
        stream: _,
    } = data;
-    patch_system_message(&mut messages);
+    let system_message = if model.name().starts_with("gemini-1.5-") {
+        extract_system_message(&mut messages)
+    } else {
+        patch_system_message(&mut messages);
+        None
+    };
    let mut network_image_urls = vec![];
    let contents: Vec<Value> = messages
@@ -333,6 +338,10 @@ pub fn gemini_build_chat_completions_body(
    let mut body = json!({ "contents": contents, "generationConfig": {} });
+    if let Some(v) = system_message {
+        body["systemInstruction"] = json!({ "parts": [{"text": v }] });
+    }
    if let Some(v) = model.max_tokens_param() {
        body["generationConfig"]["maxOutputTokens"] = v.into();
    }
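For gemini-1.5-* models, the system prompt now travels out-of-band instead of being merged into the first message. Assuming one user message and a configured max-tokens value, the resulting body would be shaped roughly like this (illustrative values; systemInstruction and generationConfig are the field names used above):

// Illustrative request body for a gemini-1.5-* model.
let body = serde_json::json!({
    "contents": [
        { "role": "user", "parts": [{ "text": "Hi" }] }
    ],
    "generationConfig": { "maxOutputTokens": 4096 },
    "systemInstruction": { "parts": [{ "text": "Be brief." }] },
});

Everything outside the gemini-1.5-* family falls back to patch_system_message (and a None system_message), reflecting the branch's assumption that those models lack a dedicated system-instruction field.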
