diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e2ccc71..4fc84b3 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -44,7 +44,7 @@ jobs: run: cargo test --all - name: Clippy - run: cargo clippy --all --all-targets + run: cargo clippy --all --all-targets -- -D warnings - name: Format run: cargo fmt --all --check \ No newline at end of file diff --git a/src/cli.rs b/src/cli.rs index e578bb3..42285e6 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,5 +1,6 @@ use clap::Parser; +#[allow(clippy::struct_excessive_bools, clippy::module_name_repetitions)] #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] pub struct Cli { diff --git a/src/client.rs b/src/client.rs index d366509..74871a0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -12,6 +12,7 @@ use tokio::time::sleep; const API_URL: &str = "https://api.openai.com/v1/chat/completions"; +#[allow(clippy::module_name_repetitions)] #[derive(Debug)] pub struct ChatGptClient { config: SharedConfig, @@ -106,16 +107,15 @@ impl ChatGptClient { let chunk = part?.data; if chunk == "[DONE]" { break; - } else { - let data: Value = serde_json::from_str(&chunk)?; - let text = data["choices"][0]["delta"]["content"] - .as_str() - .unwrap_or_default(); - if text.is_empty() { - continue; - } - handler.text(text)?; } + let data: Value = serde_json::from_str(&chunk)?; + let text = data["choices"][0]["delta"]["content"] + .as_str() + .unwrap_or_default(); + if text.is_empty() { + continue; + } + handler.text(text)?; } Ok(()) diff --git a/src/config/conversation.rs b/src/config/conversation.rs index 8343400..b9a0793 100644 --- a/src/config/conversation.rs +++ b/src/config/conversation.rs @@ -43,11 +43,12 @@ impl Conversation { self.tokens = num_tokens_from_messages(&self.build_emssages("")); } + #[allow(clippy::unnecessary_wraps)] pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> { let mut need_add_msg = true; if self.messages.is_empty() { if let Some(role) = self.role.as_ref() { - self.messages.extend(role.build_emssages(input)); + self.messages.extend(role.build_messages(input)); need_add_msg = false; } } @@ -67,15 +68,15 @@ impl Conversation { pub fn echo_messages(&self, content: &str) -> String { let messages = self.build_emssages(content); - serde_yaml::to_string(&messages).unwrap_or("Unable to echo message".into()) + serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into()) } pub fn build_emssages(&self, content: &str) -> Vec { - let mut messages = self.messages.to_vec(); + let mut messages = self.messages.clone(); let mut need_add_msg = true; if messages.is_empty() { if let Some(role) = self.role.as_ref() { - messages = role.build_emssages(content); + messages = role.build_messages(content); need_add_msg = false; } }; diff --git a/src/config/message.rs b/src/config/message.rs index 2c8a330..3c23634 100644 --- a/src/config/message.rs +++ b/src/config/message.rs @@ -17,6 +17,7 @@ impl Message { } } +#[allow(clippy::module_name_repetitions)] #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] pub enum MessageRole { @@ -45,6 +46,6 @@ mod tests { assert_eq!( serde_json::to_string(&Message::new("Hello World")).unwrap(), "{\"role\":\"user\",\"content\":\"Hello World\"}" - ) + ); } } diff --git a/src/config/mod.rs b/src/config/mod.rs index aec9c93..6e68283 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -45,6 +45,7 @@ const SET_COMPLETIONS: [&str; 8] = [ ".set dry_run false", ]; +#[allow(clippy::struct_excessive_bools)] #[derive(Debug, Clone, Deserialize)] #[serde(default)] pub struct Config { @@ -109,6 +110,7 @@ impl Default for Config { } } +#[allow(clippy::module_name_repetitions)] pub type SharedConfig = Arc>; impl Config { @@ -119,7 +121,7 @@ impl Config { create_config_file(&config_path)?; } let mut config = if api_key.is_some() && !config_path.exists() { - Default::default() + Self::default() } else { Self::load_config(&config_path)? }; @@ -156,13 +158,12 @@ impl Config { pub fn config_dir() -> Result { let env_name = get_env_name("config_dir"); - let path = match env::var_os(env_name) { - Some(v) => PathBuf::from(v), - None => { - let mut dir = dirs::config_dir().ok_or_else(|| anyhow!("Not found config dir"))?; - dir.push(env!("CARGO_CRATE_NAME")); - dir - } + let path = if let Some(v) = env::var_os(env_name) { + PathBuf::from(v) + } else { + let mut dir = dirs::config_dir().ok_or_else(|| anyhow!("Not found config dir"))?; + dir.push(env!("CARGO_CRATE_NAME")); + dir }; Ok(path) } @@ -186,18 +187,17 @@ impl Config { None => { format!("# CHAT:[{timestamp}]\n{input}\n--------\n{output}\n--------\n\n",) } + Some(v) if v.is_temp() => { + format!( + "# CHAT:[{timestamp}]\n{}\n{input}\n--------\n{output}\n--------\n\n", + v.prompt + ) + } Some(v) => { - if v.is_temp() { - format!( - "# CHAT:[{timestamp}]\n{}\n{input}\n--------\n{output}\n--------\n\n", - v.prompt - ) - } else { - format!( - "# CHAT:[{timestamp}] ({})\n{input}\n--------\n{output}\n--------\n\n", - v.name, - ) - } + format!( + "# CHAT:[{timestamp}] ({})\n{input}\n--------\n{output}\n--------\n\n", + v.name, + ) } }; file.write_all(output.as_bytes()) @@ -216,11 +216,10 @@ impl Config { pub fn roles_file() -> Result { let env_name = get_env_name("roles_file"); - if let Ok(value) = env::var(env_name) { - Ok(PathBuf::from(value)) - } else { - Self::local_file(ROLES_FILE_NAME) - } + env::var(env_name).map_or_else( + |_| Self::local_file(ROLES_FILE_NAME), + |value| Ok(PathBuf::from(value)), + ) } pub fn history_file() -> Result { @@ -237,8 +236,8 @@ impl Config { if let Some(conversation) = self.conversation.as_mut() { conversation.update_role(&role)?; } - let output = - serde_yaml::to_string(&role).unwrap_or("Unable to echo role details".into()); + let output = serde_yaml::to_string(&role) + .unwrap_or_else(|_| "Unable to echo role details".into()); self.role = Some(role); Ok(output) } @@ -271,6 +270,7 @@ impl Config { } pub fn echo_messages(&self, content: &str) -> String { + #[allow(clippy::option_if_let_else)] if let Some(conversation) = self.conversation.as_ref() { conversation.echo_messages(content) } else if let Some(role) = self.role.as_ref() { @@ -280,7 +280,7 @@ impl Config { } } - pub fn get_connect_timeout(&self) -> Duration { + pub const fn get_connect_timeout(&self) -> Duration { Duration::from_secs(self.connect_timeout as u64) } @@ -289,10 +289,11 @@ impl Config { } pub fn build_messages(&self, content: &str) -> Result> { + #[allow(clippy::option_if_let_else)] let messages = if let Some(conversation) = self.conversation.as_ref() { conversation.build_emssages(content) } else if let Some(role) = self.role.as_ref() { - role.build_emssages(content) + role.build_messages(content) } else { let message = Message::new(content); vec![message] @@ -314,7 +315,7 @@ impl Config { Ok(()) } - pub fn get_reamind_tokens(&self) -> usize { + pub const fn get_reamind_tokens(&self) -> usize { let mut tokens = self.model.1; if let Some(conversation) = self.conversation.as_ref() { tokens = tokens.saturating_sub(conversation.tokens); @@ -330,21 +331,17 @@ impl Config { let proxy = self .proxy .as_ref() - .map(|v| v.to_string()) - .unwrap_or("-".into()); + .map_or_else(|| String::from("-"), std::string::ToString::to_string); let temperature = self .temperature - .map(|v| v.to_string()) - .unwrap_or("-".into()); + .map_or_else(|| String::from("-"), |v| v.to_string()); let (api_key, organization_id) = self.get_api_key(); let api_key = mask_text(&api_key, 3, 4); - let organization_id = organization_id - .map(|v| mask_text(&v, 3, 4)) - .unwrap_or("-".into()); + let organization_id = organization_id.map_or_else(|| "-".into(), |v| mask_text(&v, 3, 4)); let items = vec![ - ("config_file", file_info(&Config::config_file()?)), - ("roles_file", file_info(&Config::roles_file()?)), - ("messages_file", file_info(&Config::messages_file()?)), + ("config_file", file_info(&Self::config_file()?)), + ("roles_file", file_info(&Self::roles_file()?)), + ("messages_file", file_info(&Self::messages_file()?)), ("api_key", api_key), ("organization_id", organization_id), ("model", self.model.0.to_string()), @@ -371,8 +368,8 @@ impl Config { .map(|v| format!(".role {}", v.name)) .collect(); - completion.extend(SET_COMPLETIONS.map(|v| v.to_string())); - completion.extend(MODELS.map(|(v, _)| format!(".model {}", v))); + completion.extend(SET_COMPLETIONS.map(std::string::ToString::to_string)); + completion.extend(MODELS.map(|(v, _)| format!(".model {v}"))); completion } @@ -441,7 +438,7 @@ impl Config { Ok(()) } - pub fn get_render_options(&self) -> (bool, bool) { + pub const fn get_render_options(&self) -> (bool, bool) { (self.highlight, self.light_theme) } @@ -449,13 +446,14 @@ impl Config { if self.dry_run { if let Ok(messages) = self.build_messages(input) { let tokens = num_tokens_from_messages(&messages); - println!(">>> The following message consumes {tokens} tokens.") + println!(">>> The following message consumes {tokens} tokens."); } } } + #[allow(clippy::unused_self)] // TODO: do we need to take self here? it's not used in the fn fn open_message_file(&self) -> Result { - let path = Config::messages_file()?; + let path = Self::messages_file()?; ensure_parent_exists(&path)?; OpenOptions::new() .create(true) @@ -468,7 +466,7 @@ impl Config { let content = read_to_string(config_path) .with_context(|| format!("Failed to load config at {}", config_path.display()))?; - let config: Config = serde_yaml::from_str(&content) + let config: Self = serde_yaml::from_str(&content) .with_context(|| format!("Invalid config at {}", config_path.display()))?; Ok(config) } diff --git a/src/config/role.rs b/src/config/role.rs index 2def03d..32ade08 100644 --- a/src/config/role.rs +++ b/src/config/role.rs @@ -55,7 +55,7 @@ impl Role { } } - pub fn build_emssages(&self, content: &str) -> Vec { + pub fn build_messages(&self, content: &str) -> Vec { if self.embeded() { let content = merge_prompt_content(&self.prompt, content); vec![Message { diff --git a/src/main.rs b/src/main.rs index af772ca..b201fd6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,9 +36,9 @@ fn main() -> Result<()> { exit(0); } if cli.list_models { - config::MODELS - .iter() - .for_each(|(name, _)| println!("{}", name)); + for (name, _) in &config::MODELS { + println!("{name}"); + } exit(0); } if cli.dry_run { @@ -76,18 +76,18 @@ fn main() -> Result<()> { if let Some(text) = text { input = format!("{text}\n{input}"); } - start_directive(client, config, &input, no_stream) + start_directive(&client, &config, &input, no_stream) } else { match text { - Some(text) => start_directive(client, config, &text, no_stream), + Some(text) => start_directive(&client, &config, &text, no_stream), None => start_interactive(client, config), } } } fn start_directive( - client: ChatGptClient, - config: SharedConfig, + client: &ChatGptClient, + config: &SharedConfig, input: &str, no_stream: bool, ) -> Result<()> { @@ -113,7 +113,7 @@ fn start_directive( abort_clone.set_ctrlc(); }) .expect("Error setting Ctrl-C handler"); - let output = render_stream(input, &client, config.clone(), false, abort, wg.clone())?; + let output = render_stream(input, client, config, false, abort, wg.clone())?; wg.wait(); output }; diff --git a/src/render/cmd.rs b/src/render/cmd.rs index 62a0a34..bf7a006 100644 --- a/src/render/cmd.rs +++ b/src/render/cmd.rs @@ -6,10 +6,11 @@ use crate::repl::{ReplyStreamEvent, SharedAbortSignal}; use anyhow::Result; use crossbeam::channel::Receiver; +#[allow(clippy::unnecessary_wraps, clippy::module_name_repetitions)] pub fn cmd_render_stream( - rx: Receiver, + rx: &Receiver, light_theme: bool, - abort: SharedAbortSignal, + abort: &SharedAbortSignal, ) -> Result<()> { let mut buffer = String::new(); let mut markdown_render = MarkdownRender::new(light_theme); @@ -25,7 +26,7 @@ pub fn cmd_render_stream( let mut lines: Vec<&str> = text.split('\n').collect(); buffer = lines.pop().unwrap_or_default().to_string(); let output = lines.join("\n"); - print_now!("{}\n", markdown_render.render(&output)) + print_now!("{}\n", markdown_render.render(&output)); } else { buffer = format!("{buffer}{text}"); if !(markdown_render.is_code_block() @@ -36,7 +37,7 @@ pub fn cmd_render_stream( { if let Some((output, remain)) = split_line(&buffer) { print_now!("{}", markdown_render.render_line_stateless(&output)); - buffer = remain + buffer = remain; } } } @@ -74,8 +75,8 @@ fn split_line(line: &str) -> Option<(String, String)> { index += 2; continue; } - do_balance(&mut balance, &chars[index..index + 1]); - index += 1 + do_balance(&mut balance, &chars[index..=index]); + index += 1; } None @@ -101,25 +102,25 @@ impl Kind { fn from_chars(chars: &[char]) -> Option { let kind = match chars.len() { 1 => match chars[0] { - '(' => Kind::ParentheseStart, - ')' => Kind::ParentheseEnd, - '[' => Kind::BracketStart, - ']' => Kind::BracketEnd, - '*' => Kind::Asterisk, - '\'' => Kind::SingleQuota, - '"' => Kind::DoubleQuota, - '~' => Kind::Tilde, - '`' => Kind::Backtick, + '(' => Self::ParentheseStart, + ')' => Self::ParentheseEnd, + '[' => Self::BracketStart, + ']' => Self::BracketEnd, + '*' => Self::Asterisk, + '\'' => Self::SingleQuota, + '"' => Self::DoubleQuota, + '~' => Self::Tilde, + '`' => Self::Backtick, _ => return None, }, 2 if chars[0] == chars[1] => match chars[0] { - '*' => Kind::Asterisk2, - '~' => Kind::Tilde2, + '*' => Self::Asterisk2, + '~' => Self::Tilde2, _ => return None, }, 3 => { if chars == ['`', '`', '`'] { - Kind::Backtick3 + Self::Backtick3 } else { return None; } @@ -131,22 +132,12 @@ impl Kind { } fn do_balance(balance: &mut Vec, chars: &[char]) -> bool { - if let Some(kind) = Kind::from_chars(chars) { + Kind::from_chars(chars).map_or(false, |kind| { let last = balance.last(); match (kind, last) { - (Kind::ParentheseStart | Kind::BracketStart, _) => { - balance.push(kind); - true - } - (Kind::ParentheseEnd, Some(&Kind::ParentheseStart)) => { - balance.pop(); - true - } - (Kind::BracketEnd, Some(&Kind::BracketStart)) => { - balance.pop(); - true - } - (Kind::Asterisk, Some(&Kind::Asterisk)) + (Kind::ParentheseEnd, Some(&Kind::ParentheseStart)) + | (Kind::BracketEnd, Some(&Kind::BracketStart)) + | (Kind::Asterisk, Some(&Kind::Asterisk)) | (Kind::Asterisk2, Some(&Kind::Asterisk2)) | (Kind::SingleQuota, Some(&Kind::SingleQuota)) | (Kind::DoubleQuota, Some(&Kind::DoubleQuota)) @@ -157,22 +148,25 @@ fn do_balance(balance: &mut Vec, chars: &[char]) -> bool { balance.pop(); true } - (Kind::Asterisk, _) - | (Kind::Asterisk2, _) - | (Kind::SingleQuota, _) - | (Kind::DoubleQuota, _) - | (Kind::Tilde, _) - | (Kind::Tilde2, _) - | (Kind::Backtick, _) - | (Kind::Backtick3, _) => { + ( + Kind::ParentheseStart + | Kind::BracketStart + | Kind::Asterisk + | Kind::Asterisk2 + | Kind::SingleQuota + | Kind::DoubleQuota + | Kind::Tilde + | Kind::Tilde2 + | Kind::Backtick + | Kind::Backtick3, + _, + ) => { balance.push(kind); true } _ => false, } - } else { - false - } + }) } #[cfg(test)] diff --git a/src/render/markdown.rs b/src/render/markdown.rs index 5bdca9d..75904af 100644 --- a/src/render/markdown.rs +++ b/src/render/markdown.rs @@ -8,6 +8,7 @@ use syntect::{easy::HighlightLines, parsing::SyntaxReference}; /// Monokai Extended const MD_THEME: &[u8] = include_bytes!("../../assets/monokai-extended.theme.bin"); const MD_THEME_LIGHT: &[u8] = include_bytes!("../../assets/monokai-extended-light.theme.bin"); +#[allow(clippy::doc_markdown)] /// Comes from https://github.com/sharkdp/bat/raw/5e77ca37e89c873e4490b42ff556370dc5c6ba4f/assets/syntaxes.bin const SYNTAXES: &[u8] = include_bytes!("../../assets/syntaxes.bin"); @@ -20,6 +21,7 @@ lazy_static! { }; } +#[allow(clippy::module_name_repetitions)] pub struct MarkdownRender { syntax_set: SyntaxSet, md_theme: Theme, @@ -67,7 +69,7 @@ impl MarkdownRender { output.unwrap_or_else(|| line.to_string()) } - pub fn is_code_block(&self) -> bool { + pub const fn is_code_block(&self) -> bool { matches!( self.prev_line_type, LineType::CodeBegin | LineType::CodeInner @@ -123,13 +125,14 @@ impl MarkdownRender { } fn render_code_line(&self, line: &str) -> Option { - self.code_syntax - .as_ref() - .map(|syntax| self.render_line_inner(line, syntax)) - .unwrap_or_else(|| Some(format!("{}", line.with(self.code_color)))) + self.code_syntax.as_ref().map_or_else( + || Some(format!("{}", line.with(self.code_color))), + |syntax| self.render_line_inner(line, syntax), + ) } fn find_syntax(&self, lang: &str) -> Option<&SyntaxReference> { + #[allow(clippy::option_if_let_else)] if let Some(new_lang) = LANGE_MAPS.get(&lang.to_ascii_lowercase()) { self.syntax_set.find_syntax_by_name(new_lang) } else { @@ -154,17 +157,17 @@ fn as_terminal_escaped(ranges: &[(Style, &str)]) -> String { let fg = blend_fg_color(style.foreground, style.background); let mut text = text.with(convert_color(fg)); if style.font_style.contains(FontStyle::BOLD) { - text = text.bold() + text = text.bold(); } if style.font_style.contains(FontStyle::UNDERLINE) { - text = text.underlined() + text = text.underlined(); } output.push_str(&text.to_string()); } output } -fn convert_color(c: SyntectColor) -> Color { +const fn convert_color(c: SyntectColor) -> Color { Color::Rgb { r: c.r, g: c.g, @@ -176,14 +179,14 @@ fn blend_fg_color(fg: SyntectColor, bg: SyntectColor) -> SyntectColor { if fg.a == 0xff { return fg; } - let ratio = fg.a as u32; - let r = (fg.r as u32 * ratio + bg.r as u32 * (255 - ratio)) / 255; - let g = (fg.g as u32 * ratio + bg.g as u32 * (255 - ratio)) / 255; - let b = (fg.b as u32 * ratio + bg.b as u32 * (255 - ratio)) / 255; + let ratio = u32::from(fg.a); + let r = (u32::from(fg.r) * ratio + u32::from(bg.r) * (255 - ratio)) / 255; + let g = (u32::from(fg.g) * ratio + u32::from(bg.g) * (255 - ratio)) / 255; + let b = (u32::from(fg.b) * ratio + u32::from(bg.b) * (255 - ratio)) / 255; SyntectColor { - r: r as u8, - g: g as u8, - b: b as u8, + r: u8::try_from(r).unwrap_or(u8::MAX), + g: u8::try_from(g).unwrap_or(u8::MAX), + b: u8::try_from(b).unwrap_or(u8::MAX), a: 255, } } @@ -209,8 +212,7 @@ fn get_code_color(theme: &Theme) -> Color { }); scope .and_then(|v| v.style.foreground) - .map(convert_color) - .unwrap_or_else(|| Color::Yellow) + .map_or_else(|| Color::Yellow, convert_color) } #[cfg(test)] diff --git a/src/render/mod.rs b/src/render/mod.rs index 5980302..20d4d0e 100644 --- a/src/render/mod.rs +++ b/src/render/mod.rs @@ -3,6 +3,7 @@ mod markdown; mod repl; use self::cmd::cmd_render_stream; +#[allow(clippy::module_name_repetitions)] pub use self::markdown::MarkdownRender; use self::repl::repl_render_stream; @@ -16,10 +17,11 @@ use crossbeam::channel::unbounded; use crossbeam::sync::WaitGroup; use std::thread::spawn; +#[allow(clippy::module_name_repetitions)] pub fn render_stream( input: &str, client: &ChatGptClient, - config: SharedConfig, + config: &SharedConfig, repl: bool, abort: SharedAbortSignal, wg: WaitGroup, @@ -30,9 +32,9 @@ pub fn render_stream( let abort_clone = abort.clone(); spawn(move || { let err = if repl { - repl_render_stream(rx, light_theme, abort) + repl_render_stream(&rx, light_theme, &abort) } else { - cmd_render_stream(rx, light_theme, abort) + cmd_render_stream(&rx, light_theme, &abort) }; if let Err(err) = err { let err = format!("{err:?}"); diff --git a/src/render/repl.rs b/src/render/repl.rs index 04dee73..2bdf3b2 100644 --- a/src/render/repl.rs +++ b/src/render/repl.rs @@ -16,10 +16,11 @@ use std::{ }; use unicode_width::UnicodeWidthStr; +#[allow(clippy::module_name_repetitions)] pub fn repl_render_stream( - rx: Receiver, + rx: &Receiver, light_theme: bool, - abort: SharedAbortSignal, + abort: &SharedAbortSignal, ) -> Result<()> { enable_raw_mode()?; let mut stdout = io::stdout(); @@ -32,9 +33,9 @@ pub fn repl_render_stream( } fn repl_render_stream_inner( - rx: Receiver, + rx: &Receiver, light_theme: bool, - abort: SharedAbortSignal, + abort: &SharedAbortSignal, writer: &mut Stdout, ) -> Result<()> { let mut last_tick = Instant::now(); @@ -118,7 +119,8 @@ fn repl_render_stream_inner( } fn recover_cursor(writer: &mut Stdout, terminal_columns: u16, buffer: &str) -> Result<()> { - let buffer_rows = (buffer.width() as u16 + terminal_columns - 1) / terminal_columns; + let buffer_rows = (u16::try_from(buffer.width()).unwrap_or(u16::MAX) + terminal_columns - 1) + / terminal_columns; let (_, row) = cursor::position()?; if buffer_rows == 0 { queue!(writer, cursor::MoveTo(0, row))?; diff --git a/src/repl/abort.rs b/src/repl/abort.rs index f76abb6..5377f43 100644 --- a/src/repl/abort.rs +++ b/src/repl/abort.rs @@ -5,6 +5,7 @@ use std::sync::{ pub type SharedAbortSignal = Arc; +#[allow(clippy::module_name_repetitions)] pub struct AbortSignal { ctrlc: AtomicBool, ctrld: AtomicBool, diff --git a/src/repl/handler.rs b/src/repl/handler.rs index 495413c..2057b8b 100644 --- a/src/repl/handler.rs +++ b/src/repl/handler.rs @@ -24,6 +24,7 @@ pub enum ReplCmd { Copy, } +#[allow(clippy::module_name_repetitions)] pub struct ReplCmdHandler { client: ChatGptClient, config: SharedConfig, @@ -32,6 +33,7 @@ pub struct ReplCmdHandler { } impl ReplCmdHandler { + #[allow(clippy::unnecessary_wraps)] pub fn init( client: ChatGptClient, config: SharedConfig, @@ -58,7 +60,7 @@ impl ReplCmdHandler { let ret = render_stream( &input, &self.client, - self.config.clone(), + &self.config, true, self.abort.clone(), wg.clone(), @@ -113,6 +115,7 @@ impl ReplCmdHandler { } } +#[allow(clippy::module_name_repetitions)] pub struct ReplyStreamHandler { sender: Option>, buffer: String, @@ -154,22 +157,19 @@ impl ReplyStreamHandler { } pub fn done(&mut self) -> Result<()> { - match self.sender.as_ref() { - Some(tx) => { - let ret = tx - .send(ReplyStreamEvent::Done) - .with_context(|| "Failed to send StreamEvent:Done"); - self.safe_ret(ret)?; + if let Some(tx) = self.sender.as_ref() { + let ret = tx + .send(ReplyStreamEvent::Done) + .with_context(|| "Failed to send StreamEvent:Done"); + self.safe_ret(ret)?; + } else { + if !self.buffer.ends_with('\n') { + print_now!("\n"); } - None => { - if !self.buffer.ends_with('\n') { - print_now!("\n") - } - if self.repl { + if self.repl { + print_now!("\n"); + if cfg!(macos) { print_now!("\n"); - if cfg!(macos) { - print_now!("\n") - } } } } diff --git a/src/repl/highlighter.rs b/src/repl/highlighter.rs index 593c307..b1ec023 100644 --- a/src/repl/highlighter.rs +++ b/src/repl/highlighter.rs @@ -5,6 +5,7 @@ use reedline::{Highlighter, StyledText}; const MATCH_COLOR: Color = Color::Green; +#[allow(clippy::module_name_repetitions)] pub struct ReplHighlighter { external_commands: Vec, config: SharedConfig, @@ -12,10 +13,10 @@ pub struct ReplHighlighter { impl ReplHighlighter { /// Construct the default highlighter with a given set of extern commands/keywords to detect and highlight - pub fn new(config: SharedConfig, external_commands: Vec) -> ReplHighlighter { + pub fn new(config: SharedConfig, external_commands: Vec) -> Self { Self { - config, external_commands, + config, } } } @@ -28,9 +29,10 @@ impl Highlighter for ReplHighlighter { } else { Color::White }; - let match_color = match self.config.read().highlight { - true => MATCH_COLOR, - false => color, + let match_color = if self.config.read().highlight { + MATCH_COLOR + } else { + color }; if self @@ -45,7 +47,7 @@ impl Highlighter for ReplHighlighter { .filter(|c| line.contains(*c)) .map(std::ops::Deref::deref) .collect(); - let longest_match = matches.iter().fold("".to_string(), |acc, &item| { + let longest_match = matches.iter().fold(String::new(), |acc, &item| { if item.len() > acc.len() { item.to_string() } else { diff --git a/src/repl/init.rs b/src/repl/init.rs index 2ec30bf..089e855 100644 --- a/src/repl/init.rs +++ b/src/repl/init.rs @@ -24,7 +24,7 @@ impl Repl { .map(|(v, _)| v.to_string()) .collect(); - let completer = Self::create_completer(config.clone(), &commands); + let completer = Self::create_completer(&config, &commands); let highlighter = ReplHighlighter::new(config.clone(), commands); let keybindings = Self::create_keybindings(); let history = Self::create_history()?; @@ -45,7 +45,7 @@ impl Repl { Ok(Self { editor, prompt }) } - fn create_completer(config: SharedConfig, commands: &[String]) -> DefaultCompleter { + fn create_completer(config: &SharedConfig, commands: &[String]) -> DefaultCompleter { let mut completion = commands.to_vec(); completion.extend(config.read().repl_completions()); let mut completer = diff --git a/src/repl/mod.rs b/src/repl/mod.rs index 8f3d184..a728e2e 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -55,7 +55,7 @@ impl Repl { Ok(Signal::Success(line)) => { already_ctrlc = false; abort.reset(); - match self.handle_line(handler.clone(), line) { + match self.handle_line(&handler, &line) { Ok(quit) => { if quit { break; @@ -69,12 +69,11 @@ impl Repl { } Ok(Signal::CtrlC) => { abort.set_ctrlc(); - if !already_ctrlc { - already_ctrlc = true; - print_now!("(To exit, press Ctrl+C again or Ctrl+D or type .exit)\n\n"); - } else { + if already_ctrlc { break; } + already_ctrlc = true; + print_now!("(To exit, press Ctrl+C again or Ctrl+D or type .exit)\n\n"); } Ok(Signal::CtrlD) => { abort.set_ctrld(); @@ -86,9 +85,9 @@ impl Repl { Ok(()) } - fn handle_line(&mut self, handler: Arc, line: String) -> Result { - let line = clean_multiline_symbols(&line); - match parse_command(&line) { + fn handle_line(&mut self, handler: &Arc, line: &str) -> Result { + let line = clean_multiline_symbols(line); + match parse_command(line.as_ref()) { Some((cmd, args)) => match cmd { ".exit" => { return Ok(true); @@ -176,7 +175,7 @@ Press Ctrl+C to abort readline, Ctrl+D to exit the REPL fn clean_multiline_symbols(line: &str) -> Cow { let trimed_line = line.trim(); match trimed_line.chars().next() { - Some('{') | Some('[') | Some('(') => trimed_line[1..trimed_line.len() - 1].into(), + Some('{' | '[' | '(') => trimed_line[1..trimed_line.len() - 1].into(), _ => Cow::Borrowed(line), } } diff --git a/src/repl/prompt.rs b/src/repl/prompt.rs index 8c670f4..c9fa005 100644 --- a/src/repl/prompt.rs +++ b/src/repl/prompt.rs @@ -9,6 +9,7 @@ const PROMPT_MULTILINE_COLOR: nu_ansi_term::Color = nu_ansi_term::Color::LightBl const INDICATOR_COLOR: Color = Color::Cyan; const PROMPT_RIGHT_COLOR: Color = Color::AnsiValue(5); +#[allow(clippy::module_name_repetitions)] #[derive(Clone)] pub struct ReplPrompt { config: SharedConfig, @@ -21,7 +22,7 @@ pub struct ReplPrompt { impl ReplPrompt { pub fn new(config: SharedConfig) -> Self { let (prompt_color, prompt_multiline_color, indicator_color, prompt_right_color) = - Self::get_colors(config.clone()); + Self::get_colors(&config); Self { config, prompt_color, @@ -32,14 +33,14 @@ impl ReplPrompt { } pub fn sync_config(&mut self) { let (prompt_color, prompt_multiline_color, indicator_color, prompt_right_color) = - Self::get_colors(self.config.clone()); + Self::get_colors(&self.config); self.prompt_color = prompt_color; self.prompt_multiline_color = prompt_multiline_color; self.indicator_color = indicator_color; self.prompt_right_color = prompt_right_color; } - pub fn get_colors(config: SharedConfig) -> (Color, nu_ansi_term::Color, Color, Color) { + pub fn get_colors(config: &SharedConfig) -> (Color, nu_ansi_term::Color, Color, Color) { let (highlight, light_theme) = config.read().get_render_options(); if highlight { ( @@ -68,11 +69,11 @@ impl ReplPrompt { impl Prompt for ReplPrompt { fn render_prompt_left(&self) -> Cow { - if let Some(role) = self.config.read().role.as_ref() { - role.name.to_string().into() - } else { - Cow::Borrowed("") - } + self.config + .read() + .role + .as_ref() + .map_or(Cow::Borrowed(""), |role| Cow::Owned(role.name.clone())) } fn render_prompt_right(&self) -> Cow { diff --git a/src/repl/validator.rs b/src/repl/validator.rs index e1f6e99..d829911 100644 --- a/src/repl/validator.rs +++ b/src/repl/validator.rs @@ -1,6 +1,7 @@ use reedline::{ValidationResult, Validator}; /// A default validator which checks for mismatched quotes and brackets +#[allow(clippy::module_name_repetitions)] pub struct ReplValidator; impl Validator for ReplValidator { @@ -30,7 +31,7 @@ fn incomplete_brackets(line: &str) -> bool { None => match c { '{' | '(' => { balance.push(c); - symbol = Some(c) + symbol = Some(c); } _ => {} }, diff --git a/src/utils/mod.rs b/src/utils/mod.rs index ec57fed..8b5dae6 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -14,7 +14,7 @@ macro_rules! print_now { }; } -pub fn print_now(text: T) { +pub fn print_now(text: &T) { print!("{}", text.to_string()); let _ = stdout().flush(); } diff --git a/src/utils/tiktoken.rs b/src/utils/tiktoken.rs index eb8819d..940d6f9 100644 --- a/src/utils/tiktoken.rs +++ b/src/utils/tiktoken.rs @@ -1,6 +1,6 @@ //! Use tiktoken for count tokens //! -//! Copy from https://github.com/dust-tt/dust/tree/main/core/src/providers/tiktoken +//! Copy from [https://github.com/dust-tt/dust/tree/main/core/src/providers/tiktoken](https://github.com/dust-tt/dust/tree/main/core/src/providers/tiktoken) #![allow(unused)] @@ -26,7 +26,7 @@ pub fn text_to_tokens(text: &str) -> Vec { } /// Convert tokens to plan text -pub fn tokens_to_text(tokens: Vec) -> Result { +pub fn tokens_to_text(tokens: &[usize]) -> Result { cl100k_base_singleton().lock().decode(tokens) } @@ -43,11 +43,11 @@ pub fn cl100k_base() -> Result { } let mut special_tokens = HashMap::default(); - special_tokens.insert(String::from("<|endoftext|>"), 100257); - special_tokens.insert(String::from("<|fim_prefix|>"), 100258); - special_tokens.insert(String::from("<|fim_middle|>"), 100259); - special_tokens.insert(String::from("<|fim_suffix|>"), 100260); - special_tokens.insert(String::from("<|endofprompt|>"), 100276); + special_tokens.insert(String::from("<|endoftext|>"), 100_257); + special_tokens.insert(String::from("<|fim_prefix|>"), 100_258); + special_tokens.insert(String::from("<|fim_middle|>"), 100_259); + special_tokens.insert(String::from("<|fim_suffix|>"), 100_260); + special_tokens.insert(String::from("<|endofprompt|>"), 100_276); CoreBPE::new( encoder, @@ -64,6 +64,7 @@ pub fn cl100k_base_singleton() -> Arc> { } fn _byte_pair_merge(piece: &[u8], ranks: &HashMap, usize>) -> Vec> { + #[allow(clippy::range_plus_one)] let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect(); // If you have n parts and m merges, this does O(mn) work @@ -156,11 +157,11 @@ pub struct CoreBPE { } impl CoreBPE { - fn _get_regex(&self) -> &Regex { + const fn _get_regex(&self) -> &Regex { &self.regex } - fn _get_special_regex(&self) -> &Regex { + const fn _get_special_regex(&self) -> &Regex { &self.special_regex } @@ -262,15 +263,12 @@ impl CoreBPE { // Here is a quick and dirty fix: { let token_is_all_space = |token| { - self.decoder - .get(token) - .map(|token_bytes| { - token_bytes - .iter() - .rev() - .all(|&b| [b' ', b'\n', b'\t'].contains(&b)) - }) - .unwrap_or(false) + self.decoder.get(token).map_or(false, |token_bytes| { + token_bytes + .iter() + .rev() + .all(|&b| [b' ', b'\n', b'\t'].contains(&b)) + }) }; if last_piece_token_len > 0 && token_is_all_space(&tokens[tokens.len() - last_piece_token_len]) @@ -342,6 +340,7 @@ impl CoreBPE { && self.sorted_token_bytes[point].starts_with(suffix) { let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat(); + #[allow(clippy::option_if_let_else)] let encoded = match std::str::from_utf8(&possibility) { // Morally, this is byte_pair_encode(&possibility, &self.encoder) // But we might have introduced a regex split which would prevent merges. @@ -386,7 +385,7 @@ impl CoreBPE { if unstable_bytes.len() > 1 { let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice()); if unstable_bytes.len() - last_decoded.1 > 0 - && last_decoded.0.map_or(false, |c| c.is_whitespace()) + && last_decoded.0.map_or(false, char::is_whitespace) { let mut reencoded = byte_pair_encode( &unstable_bytes[..unstable_bytes.len() - last_decoded.1], @@ -413,11 +412,11 @@ impl CoreBPE { let regex = Regex::new(pattern)?; let special_regex = { - let _parts = special_tokens_encoder + let parts = special_tokens_encoder .keys() .map(|s| fancy_regex::escape(s)) .collect::>(); - Regex::new(&_parts.join("|"))? + Regex::new(&parts.join("|"))? }; let decoder: HashMap> = @@ -434,7 +433,7 @@ impl CoreBPE { let mut sorted_token_bytes: Vec> = encoder.keys().cloned().collect(); sorted_token_bytes.sort(); - Ok(CoreBPE { + Ok(Self { encoder, special_tokens_encoder, decoder, @@ -453,15 +452,15 @@ impl CoreBPE { self._encode_ordinary_native(text) } - pub fn encode(&self, text: &str, allowed_special: HashSet<&str>) -> Vec { - self._encode_native(text, &allowed_special).0 + pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> Vec { + self._encode_native(text, allowed_special).0 } pub fn encode_with_special_tokens(&self, text: &str) -> Vec { let allowed_special = self .special_tokens_encoder .keys() - .map(|s| s.as_str()) + .map(std::string::String::as_str) .collect(); self._encode_native(text, &allowed_special).0 } @@ -495,9 +494,9 @@ impl CoreBPE { fn encode_with_unstable( &self, text: &str, - allowed_special: HashSet<&str>, + allowed_special: &HashSet<&str>, ) -> (Vec, HashSet>) { - self._encode_unstable_native(text, &allowed_special) + self._encode_unstable_native(text, allowed_special) } #[allow(dead_code)] @@ -525,12 +524,12 @@ impl CoreBPE { // Decoding // ==================== - pub fn decode_bytes(&self, tokens: Vec) -> Vec { - self._decode_native(&tokens) + pub fn decode_bytes(&self, tokens: &[usize]) -> Vec { + self._decode_native(tokens) } - pub fn decode(&self, tokens: Vec) -> Result { - match String::from_utf8(self._decode_native(&tokens)) { + pub fn decode(&self, tokens: &[usize]) -> Result { + match String::from_utf8(self._decode_native(tokens)) { Ok(text) => Ok(text), Err(e) => Err(anyhow!("Unable to decode into a valid UTF-8 string: {}", e)), } @@ -576,7 +575,7 @@ mod tests { fn cl100k_base_test() { let bpe = cl100k_base().unwrap(); let tokens = bpe.encode_with_special_tokens("This is a test with a lot of spaces"); - let decoded = bpe.decode(tokens.clone()).unwrap(); + let decoded = bpe.decode(&tokens).unwrap(); assert_eq!(decoded, "This is a test with a lot of spaces"); assert_eq!( tokens,