refactor: improve system message handling (#634)

Branch: pull/635/head
Commit: 52a847743e (parent 3826d808d8)
Author: sigoden, committed via GitHub 2 weeks ago
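This PR stops unconditionally folding a leading system message into the first chat message. Clients whose APIs accept a dedicated system field now pull the prompt out with a new extract_system_message helper and send it at the top level: a `system` body field in one chat-completions client, and `systemInstruction` on gemini-1.5-* models. patch_system_message still covers the remaining clients, now delegating to Message::merge_system so the system prompt is kept as a separate text part instead of being concatenated into the message text. The ERNIE model list is updated along the way (ernie-3.5-128k added, ernie-lite-8k removed).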

@@ -3,7 +3,7 @@ model: openai:gpt-3.5-turbo # Specify the language model to use
 temperature: null # Set default temperature parameter
 top_p: null # Set default top-p parameter
 
 # ---- appearance ----
 highlight: true # Controls syntax highlighting
 light_theme: false # Activates a light color theme when true. ENV: AICHAT_LIGHT_THEME
 # Custom REPL prompt, see https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt for more details
@@ -12,7 +12,7 @@ left_prompt:
 right_prompt:
   '{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}'
 
 # ---- behavior ----
 save: true # Indicates whether to persist the message
 wrap: no # Controls text wrapping (no, auto, <max-width>)
 wrap_code: false # Enables or disables wrapping of code blocks
@@ -61,7 +61,7 @@ rag_min_score_rerank: 0 # Specifies the minimum relevance score for re
 rag_template: |
   Use the following context as your learned knowledge, inside <context></context> XML tags.
   <context>
   __CONTEXT__
   </context>
 
   When answering the user:
@@ -73,7 +73,7 @@ rag_template: |
   Given the context information, answer the query.
   Query: __INPUT__
 
-# ---- clients ----
+# ---- client ----
 clients:
   # All clients have the following configuration:
   # - type: xxxx

@@ -507,16 +507,15 @@
       input_price: 1.68
       output_price: 1.68
       supports_function_calling: true
+    - name: ernie-3.5-128k
+      max_input_tokens: 8192
+      input_price: 6.72
+      output_price: 13.44
+      supports_function_calling: true
     - name: ernie-speed-128k
       max_input_tokens: 128000
       input_price: 0
       output_price: 0
-    - name: ernie-lite-8k
-      max_input_tokens: 8192
-      max_output_tokens: 2048
-      require_max_tokens: true
-      input_price: 0
-      output_price: 0
     - name: embedding-v1
       type: embedding
       max_input_tokens: 384

@@ -238,7 +238,7 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu
         stream,
     } = data;
 
-    patch_system_message(&mut messages);
+    let system_message = extract_system_message(&mut messages);
 
     let messages: Vec<Value> = messages
         .into_iter()
@@ -269,6 +269,10 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu
         "messages": messages,
     });
+
+    if let Some(v) = system_message {
+        body["system"] = v.into();
+    }
     if let Some(v) = model.max_tokens_param() {
         body["max_output_tokens"] = v.into();
     }
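Note: extract_system_message itself is not part of this diff. Judging from its call sites here and from patch_system_message below, it presumably pops a leading system-role message and returns its text. A minimal sketch, with the signature and behavior assumed rather than taken from the source:

    // Hypothetical sketch of the helper introduced by this PR; the real
    // implementation is not shown in the diff. Assumes a leading system
    // message always carries MessageContent::Text.
    pub fn extract_system_message(messages: &mut Vec<Message>) -> Option<String> {
        if messages.first().map(|m| m.role.is_system()).unwrap_or(false) {
            let system_message = messages.remove(0);
            if let MessageContent::Text(text) = system_message.content {
                return Some(text);
            }
        }
        None
    }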

@@ -21,6 +21,30 @@ impl Message {
     pub fn new(role: MessageRole, content: MessageContent) -> Self {
         Self { role, content }
     }
+
+    pub fn merge_system(&mut self, system: &str) {
+        match &mut self.content {
+            MessageContent::Text(text) => {
+                self.content = MessageContent::Array(vec![
+                    MessageContentPart::Text {
+                        text: system.to_string(),
+                    },
+                    MessageContentPart::Text {
+                        text: text.to_string(),
+                    },
+                ]);
+            }
+            MessageContent::Array(list) => {
+                list.insert(
+                    0,
+                    MessageContentPart::Text {
+                        text: system.to_string(),
+                    },
+                );
+            }
+            _ => {}
+        }
+    }
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
@@ -124,12 +148,10 @@ pub struct ImageUrl {
 pub fn patch_system_message(messages: &mut Vec<Message>) {
     if messages[0].role.is_system() {
         let system_message = messages.remove(0);
-        if let (Some(message), MessageContent::Text(system_text)) =
+        if let (Some(message), MessageContent::Text(system)) =
             (messages.get_mut(0), system_message.content)
         {
-            if let MessageContent::Text(text) = message.content.clone() {
-                message.content = MessageContent::Text(format!("{}\n\n{}", system_text, text))
-            }
+            message.merge_system(&system);
         }
     }
 }
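With merge_system in place, folding the system prompt into the next message keeps it as a distinct text part instead of a "{system}\n\n{text}" string. A rough usage sketch, with the MessageRole variants and construction assumed:

    // Hypothetical usage; constructors and enum variants assumed.
    let mut messages = vec![
        Message::new(MessageRole::System, MessageContent::Text("Be terse.".into())),
        Message::new(MessageRole::User, MessageContent::Text("Hello".into())),
    ];
    patch_system_message(&mut messages);
    // One user message remains; its content is now
    // MessageContent::Array([Text("Be terse."), Text("Hello")]).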

@@ -262,7 +262,12 @@ pub fn gemini_build_chat_completions_body(
         stream: _,
     } = data;
 
-    patch_system_message(&mut messages);
+    let system_message = if model.name().starts_with("gemini-1.5-") {
+        extract_system_message(&mut messages)
+    } else {
+        patch_system_message(&mut messages);
+        None
+    };
 
     let mut network_image_urls = vec![];
     let contents: Vec<Value> = messages
@@ -333,6 +338,10 @@ pub fn gemini_build_chat_completions_body(
     let mut body = json!({ "contents": contents, "generationConfig": {} });
+
+    if let Some(v) = system_message {
+        body["systemInstruction"] = json!({ "parts": [{"text": v }] });
+    }
 
     if let Some(v) = model.max_tokens_param() {
         body["generationConfig"]["maxOutputTokens"] = v.into();
     }
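For gemini-1.5-* models the extracted system prompt therefore rides alongside contents as a top-level systemInstruction field. Roughly, the finished body looks like the following (field values illustrative, not taken from the source):

    // Illustrative shape of the final Gemini request body.
    let body = serde_json::json!({
        "contents": [/* converted chat turns */],
        "generationConfig": { "maxOutputTokens": 4096 },
        "systemInstruction": { "parts": [{ "text": "Be terse." }] }
    });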
