From 52a847743efccb72bcd031cb1d419c84ed9d9afa Mon Sep 17 00:00:00 2001
From: sigoden <sigoden@gmail.com>
Date: Sun, 23 Jun 2024 06:00:58 +0800
Subject: [PATCH] refactor: improve system message handling (#634)

---
 config.example.yaml    |  8 ++++----
 models.yaml            | 11 +++++------
 src/client/ernie.rs    |  6 +++++-
 src/client/message.rs  | 30 ++++++++++++++++++++++++++----
 src/client/vertexai.rs | 11 ++++++++++-
 5 files changed, 50 insertions(+), 16 deletions(-)

diff --git a/config.example.yaml b/config.example.yaml
index a152fcc..8425cc7 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -3,7 +3,7 @@ model: openai:gpt-3.5-turbo # Specify the language model to use
 temperature: null # Set default temperature parameter
 top_p: null # Set default top-p parameter
 
-# ---- apperence ---- 
+# ---- apperence ----
 highlight: true # Controls syntax highlighting
 light_theme: false # Activates a light color theme when true. ENV: AICHAT_LIGHT_THEME
 # Custom REPL prompt, see https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt for more detils
@@ -12,7 +12,7 @@ left_prompt:
 right_prompt:
   '{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}'
 
-# ---- behavior ---- 
+# ---- behavior ----
 save: true # Indicates whether to persist the message
 wrap: no # Controls text wrapping (no, auto, <max-width>)
 wrap_code: false # Enables or disables wrapping of code blocks
@@ -61,7 +61,7 @@ rag_min_score_rerank: 0 # Specifies the minimum relevance score for re
 rag_template: |
   Use the following context as your learned knowledge, inside <context></context> XML tags.
   <context>
-  __CONTEXT__ 
+  __CONTEXT__
   </context>
 
   When answer to user:
@@ -73,7 +73,7 @@ rag_template: |
   Given the context information, answer the query.
   Query: __INPUT__
 
-# ---- clients ---- 
+# ---- clients ----
 clients:
   # All clients have the following configuration:
   # - type: xxxx
diff --git a/models.yaml b/models.yaml
index 10c6d50..246c607 100644
--- a/models.yaml
+++ b/models.yaml
@@ -507,16 +507,15 @@
       input_price: 1.68
       output_price: 1.68
       supports_function_calling: true
+    - name: ernie-3.5-128k
+      max_input_tokens: 128000
+      input_price: 6.72
+      output_price: 13.44
+      supports_function_calling: true
     - name: ernie-speed-128k
       max_input_tokens: 128000
       input_price: 0
       output_price: 0
-    - name: ernie-lite-8k
-      max_input_tokens: 8192
-      max_output_tokens: 2048
-      require_max_tokens: true
-      input_price: 0
-      output_price: 0
     - name: embedding-v1
       type: embedding
       max_input_tokens: 384
diff --git a/src/client/ernie.rs b/src/client/ernie.rs
index 21362c1..428d263 100644
--- a/src/client/ernie.rs
+++ b/src/client/ernie.rs
@@ -238,7 +238,7 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu
         stream,
     } = data;
 
-    patch_system_message(&mut messages);
+    let system_message = extract_system_message(&mut messages);
 
     let messages: Vec<Value> = messages
         .into_iter()
@@ -269,6 +269,10 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu
         "messages": messages,
     });
 
+    if let Some(v) = system_message {
+        body["system"] = v.into();
+    }
+
     if let Some(v) = model.max_tokens_param() {
         body["max_output_tokens"] = v.into();
     }
diff --git a/src/client/message.rs b/src/client/message.rs
index d7ba698..77adf47 100644
--- a/src/client/message.rs
+++ b/src/client/message.rs
@@ -21,6 +21,30 @@ impl Message {
     pub fn new(role: MessageRole, content: MessageContent) -> Self {
         Self { role, content }
     }
+
+    pub fn merge_system(&mut self, system: &str) {
+        match &mut self.content {
+            MessageContent::Text(text) => {
+                self.content = MessageContent::Array(vec![
+                    MessageContentPart::Text {
+                        text: system.to_string(),
+                    },
+                    MessageContentPart::Text {
+                        text: text.to_string(),
+                    },
+                ]);
+            }
+            MessageContent::Array(list) => {
+                list.insert(
+                    0,
+                    MessageContentPart::Text {
+                        text: system.to_string(),
+                    },
+                );
+            }
+            _ => {}
+        }
+    }
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
@@ -124,12 +148,10 @@ pub struct ImageUrl {
 pub fn patch_system_message(messages: &mut Vec<Message>) {
     if messages[0].role.is_system() {
         let system_message = messages.remove(0);
-        if let (Some(message), MessageContent::Text(system_text)) =
+        if let (Some(message), MessageContent::Text(system)) =
             (messages.get_mut(0), system_message.content)
         {
-            if let MessageContent::Text(text) = message.content.clone() {
-                message.content = MessageContent::Text(format!("{}\n\n{}", system_text, text))
-            }
+            message.merge_system(&system);
         }
     }
 }
diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs
index 16910b6..cc7e464 100644
--- a/src/client/vertexai.rs
+++ b/src/client/vertexai.rs
@@ -262,7 +262,12 @@ pub fn gemini_build_chat_completions_body(
         stream: _,
     } = data;
 
-    patch_system_message(&mut messages);
+    let system_message = if model.name().starts_with("gemini-1.5-") {
+        extract_system_message(&mut messages)
+    } else {
+        patch_system_message(&mut messages);
+        None
+    };
     let mut network_image_urls = vec![];
     let contents: Vec<Value> = messages
         .into_iter()
@@ -333,6 +338,10 @@ pub fn gemini_build_chat_completions_body(
 
     let mut body = json!({ "contents": contents, "generationConfig": {} });
 
+    if let Some(v) = system_message {
+        body["systemInstruction"] = json!({ "parts": [{"text": v }] });
+    }
+
     if let Some(v) = model.max_tokens_param() {
         body["generationConfig"]["maxOutputTokens"] = v.into();
     }
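
Note: both call sites above use extract_system_message(), which is not touched
by this diff and therefore presumably already exists in src/client/message.rs.
Inferred only from the call sites (ernie.rs stores the result in body["system"],
vertexai.rs wraps it in a systemInstruction part, so it must return an
Option<String>), a minimal sketch could look like the following; the to_text()
helper on MessageContent is an assumption, not confirmed by this patch:

    // Sketch, not the confirmed implementation: pop a leading system message
    // (if any) and return its text, so provider-specific body builders can
    // put it in their own field ("system" for ernie, "systemInstruction" for
    // gemini-1.5-*) instead of merging it into the first user message.
    pub fn extract_system_message(messages: &mut Vec<Message>) -> Option<String> {
        if messages[0].role.is_system() {
            let system_message = messages.remove(0);
            return Some(system_message.content.to_text()); // to_text() assumed
        }
        None
    }

Like patch_system_message() above, this indexes messages[0] directly and so
assumes the message list is non-empty.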