From f6da06dad9b2a76016209a7d58ad923c2c72f150 Mon Sep 17 00:00:00 2001 From: sigoden Date: Wed, 1 Nov 2023 22:15:55 +0800 Subject: [PATCH] refactor: improve code quanity (#194) - extends ModelInfo for tokens calculating - refactor config/session.rs, improve export, render, getter/setter - modify main.rs, allow --model override session.model --- Cargo.toml | 4 +- README.md | 12 ++-- src/client/azure_openai.rs | 4 +- src/client/localai.rs | 4 +- src/client/openai.rs | 6 +- src/config/message.rs | 19 +++--- src/config/mod.rs | 64 ++++++++++---------- src/config/model_info.rs | 64 ++++++++++++++++++-- src/config/session.rs | 120 +++++++++++++++++++++++++++---------- src/main.rs | 8 +-- src/render/repl.rs | 2 +- src/repl/prompt.rs | 2 +- 12 files changed, 209 insertions(+), 100 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 42aece0..8400b51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ keywords = ["chatgpt", "localai", "gpt", "repl"] [dependencies] anyhow = "1.0.69" bytes = "1.4.0" -clap = { version = "4.1.8", features = ["derive", "string"] } +clap = { version = "4.1.8", features = ["derive"] } dirs = "5.0.0" futures-util = "0.3.26" inquire = "0.6.2" @@ -43,7 +43,7 @@ reqwest-eventsource = "0.5.0" [dependencies.reqwest] version = "0.11.14" -features = ["json", "stream", "socks", "rustls-tls", "rustls-tls-native-roots"] +features = ["json", "socks", "rustls-tls", "rustls-tls-native-roots"] default-features = false [dependencies.syntect] diff --git a/README.md b/README.md index 04142e6..38db895 100644 --- a/README.md +++ b/README.md @@ -235,17 +235,21 @@ You should run aichat with `-s/--session` or use the `.session` command to start ``` 〉.session -temp)1 to 5, odd only 4089 + +temp)1 to 5, odd only 0 1, 3, 5 -temp)to 7 4070 +temp)to 7 19(0.46%) 1, 3, 5, 7 -temp).exit session +temp).exit session 42(1.03%) +? Save session? (y/N) -〉 ``` +The prompt on the right side is about the current usage of tokens and the proportion of tokens used, +compared to the maximum number of tokens allowed by the model. + ### `.set` - modify the configuration temporarily diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs index d0fa31f..3a48145 100644 --- a/src/client/azure_openai.rs +++ b/src/client/azure_openai.rs @@ -1,4 +1,4 @@ -use super::openai::openai_build_body; +use super::openai::{openai_build_body, openai_tokens_formula}; use super::{AzureOpenAIClient, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData}; use anyhow::{anyhow, Result}; @@ -46,7 +46,7 @@ impl AzureOpenAIClient { local_config .models .iter() - .map(|v| ModelInfo::new(client, &v.name, v.max_tokens, index)) + .map(|v| openai_tokens_formula(ModelInfo::new(index, client, &v.name).set_max_tokens(v.max_tokens))) .collect() } diff --git a/src/client/localai.rs b/src/client/localai.rs index 291fc96..026ae27 100644 --- a/src/client/localai.rs +++ b/src/client/localai.rs @@ -1,4 +1,4 @@ -use super::openai::openai_build_body; +use super::openai::{openai_build_body, openai_tokens_formula}; use super::{ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData}; use anyhow::Result; @@ -45,7 +45,7 @@ impl LocalAIClient { local_config .models .iter() - .map(|v| ModelInfo::new(client, &v.name, v.max_tokens, index)) + .map(|v| openai_tokens_formula(ModelInfo::new(index, client, &v.name).set_max_tokens(v.max_tokens))) .collect() } diff --git a/src/client/openai.rs b/src/client/openai.rs index b554f35..387baf0 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -38,7 +38,7 @@ impl OpenAIClient { let client = Self::name(local_config); MODELS .into_iter() - .map(|(name, max_tokens)| ModelInfo::new(client, name, Some(max_tokens), index)) + .map(|(name, max_tokens)| openai_tokens_formula(ModelInfo::new(index, client, name).set_max_tokens(Some(max_tokens)))) .collect() } @@ -135,3 +135,7 @@ pub fn openai_build_body(data: SendData, model: String) -> Value { } body } + +pub fn openai_tokens_formula(model: ModelInfo) -> ModelInfo { + model.set_tokens_formula(5, 2) +} diff --git a/src/config/message.rs b/src/config/message.rs index 5882337..55b2663 100644 --- a/src/config/message.rs +++ b/src/config/message.rs @@ -1,5 +1,3 @@ -use crate::utils::count_tokens; - use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Deserialize, Serialize)] @@ -25,22 +23,19 @@ pub enum MessageRole { User, } +#[allow(dead_code)] impl MessageRole { - #[allow(dead_code)] pub fn is_system(&self) -> bool { matches!(self, MessageRole::System) } -} -pub fn num_tokens_from_messages(messages: &[Message]) -> usize { - let mut num_tokens = 0; - for message in messages.iter() { - num_tokens += 4; - num_tokens += count_tokens(&message.content); - num_tokens += 1; // role always take 1 token + pub fn is_user(&self) -> bool { + matches!(self, MessageRole::User) + } + + pub fn is_assistant(&self) -> bool { + matches!(self, MessageRole::Assistant) } - num_tokens += 2; - num_tokens } #[cfg(test)] diff --git a/src/config/mod.rs b/src/config/mod.rs index 18f1731..eff9bcf 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -12,7 +12,6 @@ use crate::client::{ all_models, create_client_config, list_client_types, ClientConfig, ExtraConfig, OpenAIClient, SendData, }; -use crate::config::message::num_tokens_from_messages; use crate::render::RenderOptions; use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err}; @@ -274,7 +273,7 @@ impl Config { pub fn set_temperature(&mut self, value: Option) -> Result<()> { self.temperature = value; if let Some(session) = self.session.as_mut() { - session.temperature = value; + session.set_temperature(value); } Ok(()) } @@ -298,13 +297,6 @@ impl Config { let message = Message::new(content); vec![message] }; - let tokens = num_tokens_from_messages(&messages); - if let Some(max_tokens) = self.model_info.max_tokens { - if tokens >= max_tokens { - bail!("Exceed max tokens limit") - } - } - Ok(messages) } @@ -326,7 +318,7 @@ impl Config { let models = all_models(self); let mut model_info = None; if value.contains(':') { - if let Some(model) = models.iter().find(|v| v.stringify() == value) { + if let Some(model) = models.iter().find(|v| v.full_name() == value) { model_info = Some(model.clone()); } } else if let Some(model) = models.iter().find(|v| v.client == value) { @@ -336,7 +328,7 @@ impl Config { None => bail!("Unknown model '{}'", value), Some(model_info) => { if let Some(session) = self.session.as_mut() { - session.set_model(&model_info.stringify())?; + session.set_model(model_info.clone())?; } self.model_info = model_info; Ok(()) @@ -361,7 +353,7 @@ impl Config { ("roles_file", path_info(&Self::roles_file()?)), ("messages_file", path_info(&Self::messages_file()?)), ("sessions_dir", path_info(&Self::sessions_dir()?)), - ("model", self.model_info.stringify()), + ("model", self.model_info.full_name()), ("temperature", temperature), ("save", self.save.to_string()), ("highlight", self.highlight.to_string()), @@ -389,7 +381,7 @@ impl Config { completion.extend( all_models(self) .iter() - .map(|v| format!(".model {}", v.stringify())), + .map(|v| format!(".model {}", v.full_name())), ); let sessions = self.list_sessions().unwrap_or_default(); completion.extend(sessions.iter().map(|v| format!(".session {}", v))); @@ -444,7 +436,7 @@ impl Config { } self.session = Some(Session::new( TEMP_SESSION_NAME, - &self.model_info.stringify(), + self.model_info.clone(), self.role.clone(), )); } @@ -453,13 +445,13 @@ impl Config { if !session_path.exists() { self.session = Some(Session::new( name, - &self.model_info.stringify(), + self.model_info.clone(), self.role.clone(), )); } else { let session = Session::load(name, &session_path)?; - let model = session.model.clone(); - self.temperature = session.temperature; + let model = session.model().to_string(); + self.temperature = session.temperature(); self.session = Some(session); self.set_model(&model)?; } @@ -472,7 +464,8 @@ impl Config { "Start a session that incorporates the last question and answer?", ) .with_default(false) - .prompt()?; + .prompt() + .map_err(prompt_op_err)?; if ans { session.add_message(input, output)?; } @@ -487,13 +480,19 @@ impl Config { self.last_message = None; self.temperature = self.default_temperature; if session.should_save() { - let ans = Confirm::new("Save session?").with_default(true).prompt()?; + let ans = Confirm::new("Save session?") + .with_default(false) + .prompt() + .map_err(prompt_op_err)?; if !ans { return Ok(()); } - let mut name = session.name.clone(); + let mut name = session.name().to_string(); if session.is_temp() { - name = Text::new("Session name:").with_default(&name).prompt()?; + name = Text::new("Session name:") + .with_default(&name) + .prompt() + .map_err(prompt_op_err)?; } let session_path = Self::session_file(&name)?; let sessions_dir = session_path.parent().ok_or_else(|| { @@ -558,17 +557,13 @@ impl Config { pub fn render_prompt_right(&self) -> String { if let Some(session) = &self.session { - let tokens = session.tokens; - // 10000(%32) - match self.model_info.max_tokens { - Some(max_tokens) => { - let ratio = tokens as f32 / max_tokens as f32; - let percent = ratio * 100.0; - let percent = (percent * 100.0).round() / 100.0; - format!("{tokens}({percent}%)") - } - None => format!("{tokens}"), - } + let (tokens, percent) = session.tokens_and_percent(); + let percent = if percent == 0.0 { + String::new() + } else { + format!("({percent}%)") + }; + format!("{tokens}{percent}") } else { String::new() } @@ -576,6 +571,7 @@ impl Config { pub fn prepare_send_data(&self, content: &str, stream: bool) -> Result { let messages = self.build_messages(content)?; + self.model_info.max_tokens_limit(&messages)?; Ok(SendData { messages, temperature: self.get_temperature(), @@ -586,7 +582,7 @@ impl Config { pub fn maybe_print_send_tokens(&self, input: &str) { if self.dry_run { if let Ok(messages) = self.build_messages(input) { - let tokens = num_tokens_from_messages(&messages); + let tokens = self.model_info.totatl_tokens(&messages); println!(">>> This message consumes {tokens} tokens. <<<"); } } @@ -642,7 +638,7 @@ impl Config { bail!("No available model"); } - models[0].stringify() + models[0].full_name() } }; self.set_model(&model)?; diff --git a/src/config/model_info.rs b/src/config/model_info.rs index 1f8d6f0..fa51b91 100644 --- a/src/config/model_info.rs +++ b/src/config/model_info.rs @@ -1,27 +1,79 @@ +use super::Message; + +use crate::utils::count_tokens; + +use anyhow::{bail, Result}; + #[derive(Debug, Clone)] pub struct ModelInfo { pub client: String, pub name: String, - pub max_tokens: Option, pub index: usize, + pub max_tokens: Option, + pub per_message_tokens: usize, + pub bias_tokens: usize, } impl Default for ModelInfo { fn default() -> Self { - ModelInfo::new("", "", None, 0) + ModelInfo::new(0, "", "") } } impl ModelInfo { - pub fn new(client: &str, name: &str, max_tokens: Option, index: usize) -> Self { + pub fn new(index: usize, client: &str, name: &str) -> Self { Self { + index, client: client.into(), name: name.into(), - max_tokens, - index, + max_tokens: None, + per_message_tokens: 0, + bias_tokens: 0, } } - pub fn stringify(&self) -> String { + + pub fn set_max_tokens(mut self, max_tokens: Option) -> Self { + match max_tokens { + None | Some(0) => self.max_tokens = None, + _ => self.max_tokens = max_tokens, + } + self + } + + pub fn set_tokens_formula(mut self, per_message_token: usize, bias_tokens: usize) -> Self { + self.per_message_tokens = per_message_token; + self.bias_tokens = bias_tokens; + self + } + + pub fn full_name(&self) -> String { format!("{}:{}", self.client, self.name) } + + pub fn messages_tokens(&self, messages: &[Message]) -> usize { + messages.iter().map(|v| count_tokens(&v.content)).sum() + } + + pub fn totatl_tokens(&self, messages: &[Message]) -> usize { + if messages.is_empty() { + return 0; + } + let num_messages = messages.len(); + let message_tokens = self.messages_tokens(messages); + if messages[num_messages - 1].role.is_user() { + num_messages * self.per_message_tokens + message_tokens + } else { + (num_messages - 1) * self.per_message_tokens + message_tokens + } + } + + pub fn max_tokens_limit(&self, messages: &[Message]) -> Result<()> { + let total_tokens = self.totatl_tokens(messages) + self.bias_tokens; + if let Some(max_tokens) = self.max_tokens { + if total_tokens >= max_tokens { + bail!("Exceed max tokens limit") + } + } + Ok(()) + } } diff --git a/src/config/session.rs b/src/config/session.rs index 4cc172d..d07e4cf 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -1,45 +1,47 @@ -use super::message::{num_tokens_from_messages, Message, MessageRole}; +use super::message::{Message, MessageRole}; use super::role::Role; +use super::ModelInfo; use crate::render::MarkdownRender; use anyhow::{bail, Context, Result}; use serde::{Deserialize, Serialize}; +use serde_json::json; use std::fs::{self, read_to_string}; use std::path::Path; pub const TEMP_SESSION_NAME: &str = "temp"; -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct Session { + model: String, + temperature: Option, + messages: Vec, + #[serde(skip)] + pub name: String, + #[serde(skip)] pub path: Option, - pub model: String, - pub tokens: usize, - pub temperature: Option, - pub messages: Vec, #[serde(skip)] pub dirty: bool, #[serde(skip)] pub role: Option, #[serde(skip)] - pub name: String, + pub model_info: ModelInfo, } impl Session { - pub fn new(name: &str, model: &str, role: Option) -> Self { + pub fn new(name: &str, model_info: ModelInfo, role: Option) -> Self { let temperature = role.as_ref().and_then(|v| v.temperature); - let mut value = Self { - path: None, - model: model.to_string(), + Self { + model: model_info.full_name(), temperature, - tokens: 0, messages: vec![], + name: name.to_string(), + path: None, dirty: false, role, - name: name.to_string(), - }; - value.update_tokens(); - value + model_info, + } } pub fn load(name: &str, path: &Path) -> Result { @@ -54,22 +56,64 @@ impl Session { Ok(session) } + pub fn name(&self) -> &str { + &self.name + } + + pub fn model(&self) -> &str { + &self.model + } + + pub fn temperature(&self) -> Option { + self.temperature + } + + pub fn tokens(&self) -> usize { + self.model_info.totatl_tokens(&self.messages) + } + pub fn export(&self) -> Result { self.guard_save()?; - let output = serde_yaml::to_string(&self) + let (tokens, percent) = self.tokens_and_percent(); + let mut data = json!({ + "path": self.path, + "model": self.model(), + }); + if let Some(temperature) = self.temperature() { + data["temperature"] = temperature.into(); + } + data["total-tokens"] = tokens.into(); + if let Some(max_tokens) = self.model_info.max_tokens { + data["max-tokens"] = max_tokens.into(); + } + if percent != 0.0 { + data["total/max-tokens"] = format!("{}%", percent).into(); + } + data["messages"] = json!(self.messages); + + let output = serde_yaml::to_string(&data) .with_context(|| format!("Unable to show info about session {}", &self.name))?; Ok(output) } pub fn render(&self, render: &mut MarkdownRender) -> Result { + let path = self.path.clone().unwrap_or_else(|| "-".to_string()); + let temperature = self - .temperature + .temperature() .map_or_else(|| String::from("-"), |v| v.to_string()); + + let max_tokens = self + .model_info + .max_tokens + .map(|v| v.to_string()) + .unwrap_or_else(|| '-'.to_string()); + let items = vec![ - ("path", self.path.clone().unwrap_or_else(|| "-".into())), - ("model", self.model.clone()), - ("tokens", self.tokens.to_string()), + ("path", path), + ("model", self.model().to_string()), ("temperature", temperature), + ("max_tokens", max_tokens), ]; let mut lines = vec![]; for (name, value) in items { @@ -94,17 +138,32 @@ impl Session { Ok(output) } + pub fn tokens_and_percent(&self) -> (usize, f32) { + let tokens = self.tokens(); + let max_tokens = self.model_info.max_tokens.unwrap_or_default(); + let percent = if max_tokens == 0 { + 0.0 + } else { + let percent = tokens as f32 / max_tokens as f32 * 100.0; + (percent * 100.0).round() / 100.0 + }; + (tokens, percent) + } + pub fn update_role(&mut self, role: Option) -> Result<()> { self.guard_empty()?; self.temperature = role.as_ref().and_then(|v| v.temperature); self.role = role; - self.update_tokens(); Ok(()) } - pub fn set_model(&mut self, model: &str) -> Result<()> { - self.model = model.to_string(); - self.update_tokens(); + pub fn set_temperature(&mut self, value: Option) { + self.temperature = value; + } + + pub fn set_model(&mut self, model_info: ModelInfo) -> Result<()> { + self.model = model_info.full_name(); + self.model_info = model_info; Ok(()) } @@ -112,7 +171,8 @@ impl Session { if !self.should_save() { return Ok(()); } - self.dirty = false; + self.path = Some(session_path.display().to_string()); + let content = serde_yaml::to_string(&self) .with_context(|| format!("Failed to serde session {}", self.name))?; fs::write(session_path, content).with_context(|| { @@ -122,6 +182,9 @@ impl Session { session_path.display() ) })?; + + self.dirty = false; + Ok(()) } @@ -151,10 +214,6 @@ impl Session { self.messages.is_empty() } - pub fn update_tokens(&mut self) { - self.tokens = num_tokens_from_messages(&self.build_emssages("")); - } - pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> { let mut need_add_msg = true; if self.messages.is_empty() { @@ -173,7 +232,6 @@ impl Session { role: MessageRole::Assistant, content: output.to_string(), }); - self.tokens = num_tokens_from_messages(&self.messages); self.dirty = true; Ok(()) } diff --git a/src/main.rs b/src/main.rs index b2985e2..6492b4a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -37,7 +37,7 @@ fn main() -> Result<()> { } if cli.list_models { for model in all_models(&config.read()) { - println!("{}", model.stringify()); + println!("{}", model.full_name()); } exit(0); } @@ -55,15 +55,15 @@ fn main() -> Result<()> { if cli.dry_run { config.write().dry_run = true; } - if let Some(model) = &cli.model { - config.write().set_model(model)?; - } if let Some(name) = &cli.role { config.write().set_role(name)?; } if let Some(session) = &cli.session { config.write().start_session(session)?; } + if let Some(model) = &cli.model { + config.write().set_model(model)?; + } if cli.no_highlight { config.write().highlight = false; } diff --git a/src/render/repl.rs b/src/render/repl.rs index 845a5e3..7363923 100644 --- a/src/render/repl.rs +++ b/src/render/repl.rs @@ -60,7 +60,7 @@ fn repl_render_stream_inner( } if row + 1 >= clear_rows { - queue!(writer, cursor::MoveTo(0, row - clear_rows))?; + queue!(writer, cursor::MoveTo(0, row.saturating_sub(clear_rows)))?; } else { let scroll_rows = clear_rows - row - 1; queue!( diff --git a/src/repl/prompt.rs b/src/repl/prompt.rs index 7bed5e0..7007b53 100644 --- a/src/repl/prompt.rs +++ b/src/repl/prompt.rs @@ -23,7 +23,7 @@ impl ReplPrompt { impl Prompt for ReplPrompt { fn render_prompt_left(&self) -> Cow { if let Some(session) = &self.config.read().session { - Cow::Owned(session.name.clone()) + Cow::Owned(session.name().to_string()) } else if let Some(role) = &self.config.read().role { Cow::Owned(role.name.clone()) } else {