fix: minor code cleanup and pedantic lints (#134)

pull/136/head
Anthony Rubick 12 months ago committed by GitHub
parent ec51b84290
commit 97fc7de675
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -44,7 +44,7 @@ jobs:
run: cargo test --all
- name: Clippy
run: cargo clippy --all --all-targets
run: cargo clippy --all --all-targets -- -D warnings
- name: Format
run: cargo fmt --all --check

@ -1,5 +1,6 @@
use clap::Parser;
#[allow(clippy::struct_excessive_bools, clippy::module_name_repetitions)]
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Cli {

@ -12,6 +12,7 @@ use tokio::time::sleep;
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct ChatGptClient {
config: SharedConfig,
@ -106,16 +107,15 @@ impl ChatGptClient {
let chunk = part?.data;
if chunk == "[DONE]" {
break;
} else {
let data: Value = serde_json::from_str(&chunk)?;
let text = data["choices"][0]["delta"]["content"]
.as_str()
.unwrap_or_default();
if text.is_empty() {
continue;
}
handler.text(text)?;
}
let data: Value = serde_json::from_str(&chunk)?;
let text = data["choices"][0]["delta"]["content"]
.as_str()
.unwrap_or_default();
if text.is_empty() {
continue;
}
handler.text(text)?;
}
Ok(())

@ -43,11 +43,12 @@ impl Conversation {
self.tokens = num_tokens_from_messages(&self.build_emssages(""));
}
#[allow(clippy::unnecessary_wraps)]
pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> {
let mut need_add_msg = true;
if self.messages.is_empty() {
if let Some(role) = self.role.as_ref() {
self.messages.extend(role.build_emssages(input));
self.messages.extend(role.build_messages(input));
need_add_msg = false;
}
}
@ -67,15 +68,15 @@ impl Conversation {
pub fn echo_messages(&self, content: &str) -> String {
let messages = self.build_emssages(content);
serde_yaml::to_string(&messages).unwrap_or("Unable to echo message".into())
serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into())
}
pub fn build_emssages(&self, content: &str) -> Vec<Message> {
let mut messages = self.messages.to_vec();
let mut messages = self.messages.clone();
let mut need_add_msg = true;
if messages.is_empty() {
if let Some(role) = self.role.as_ref() {
messages = role.build_emssages(content);
messages = role.build_messages(content);
need_add_msg = false;
}
};

@ -17,6 +17,7 @@ impl Message {
}
}
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
@ -45,6 +46,6 @@ mod tests {
assert_eq!(
serde_json::to_string(&Message::new("Hello World")).unwrap(),
"{\"role\":\"user\",\"content\":\"Hello World\"}"
)
);
}
}

@ -45,6 +45,7 @@ const SET_COMPLETIONS: [&str; 8] = [
".set dry_run false",
];
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct Config {
@ -109,6 +110,7 @@ impl Default for Config {
}
}
#[allow(clippy::module_name_repetitions)]
pub type SharedConfig = Arc<RwLock<Config>>;
impl Config {
@ -119,7 +121,7 @@ impl Config {
create_config_file(&config_path)?;
}
let mut config = if api_key.is_some() && !config_path.exists() {
Default::default()
Self::default()
} else {
Self::load_config(&config_path)?
};
@ -156,13 +158,12 @@ impl Config {
pub fn config_dir() -> Result<PathBuf> {
let env_name = get_env_name("config_dir");
let path = match env::var_os(env_name) {
Some(v) => PathBuf::from(v),
None => {
let mut dir = dirs::config_dir().ok_or_else(|| anyhow!("Not found config dir"))?;
dir.push(env!("CARGO_CRATE_NAME"));
dir
}
let path = if let Some(v) = env::var_os(env_name) {
PathBuf::from(v)
} else {
let mut dir = dirs::config_dir().ok_or_else(|| anyhow!("Not found config dir"))?;
dir.push(env!("CARGO_CRATE_NAME"));
dir
};
Ok(path)
}
@ -186,18 +187,17 @@ impl Config {
None => {
format!("# CHAT:[{timestamp}]\n{input}\n--------\n{output}\n--------\n\n",)
}
Some(v) if v.is_temp() => {
format!(
"# CHAT:[{timestamp}]\n{}\n{input}\n--------\n{output}\n--------\n\n",
v.prompt
)
}
Some(v) => {
if v.is_temp() {
format!(
"# CHAT:[{timestamp}]\n{}\n{input}\n--------\n{output}\n--------\n\n",
v.prompt
)
} else {
format!(
"# CHAT:[{timestamp}] ({})\n{input}\n--------\n{output}\n--------\n\n",
v.name,
)
}
format!(
"# CHAT:[{timestamp}] ({})\n{input}\n--------\n{output}\n--------\n\n",
v.name,
)
}
};
file.write_all(output.as_bytes())
@ -216,11 +216,10 @@ impl Config {
pub fn roles_file() -> Result<PathBuf> {
let env_name = get_env_name("roles_file");
if let Ok(value) = env::var(env_name) {
Ok(PathBuf::from(value))
} else {
Self::local_file(ROLES_FILE_NAME)
}
env::var(env_name).map_or_else(
|_| Self::local_file(ROLES_FILE_NAME),
|value| Ok(PathBuf::from(value)),
)
}
pub fn history_file() -> Result<PathBuf> {
@ -237,8 +236,8 @@ impl Config {
if let Some(conversation) = self.conversation.as_mut() {
conversation.update_role(&role)?;
}
let output =
serde_yaml::to_string(&role).unwrap_or("Unable to echo role details".into());
let output = serde_yaml::to_string(&role)
.unwrap_or_else(|_| "Unable to echo role details".into());
self.role = Some(role);
Ok(output)
}
@ -271,6 +270,7 @@ impl Config {
}
pub fn echo_messages(&self, content: &str) -> String {
#[allow(clippy::option_if_let_else)]
if let Some(conversation) = self.conversation.as_ref() {
conversation.echo_messages(content)
} else if let Some(role) = self.role.as_ref() {
@ -280,7 +280,7 @@ impl Config {
}
}
pub fn get_connect_timeout(&self) -> Duration {
pub const fn get_connect_timeout(&self) -> Duration {
Duration::from_secs(self.connect_timeout as u64)
}
@ -289,10 +289,11 @@ impl Config {
}
pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> {
#[allow(clippy::option_if_let_else)]
let messages = if let Some(conversation) = self.conversation.as_ref() {
conversation.build_emssages(content)
} else if let Some(role) = self.role.as_ref() {
role.build_emssages(content)
role.build_messages(content)
} else {
let message = Message::new(content);
vec![message]
@ -314,7 +315,7 @@ impl Config {
Ok(())
}
pub fn get_reamind_tokens(&self) -> usize {
pub const fn get_reamind_tokens(&self) -> usize {
let mut tokens = self.model.1;
if let Some(conversation) = self.conversation.as_ref() {
tokens = tokens.saturating_sub(conversation.tokens);
@ -330,21 +331,17 @@ impl Config {
let proxy = self
.proxy
.as_ref()
.map(|v| v.to_string())
.unwrap_or("-".into());
.map_or_else(|| String::from("-"), std::string::ToString::to_string);
let temperature = self
.temperature
.map(|v| v.to_string())
.unwrap_or("-".into());
.map_or_else(|| String::from("-"), |v| v.to_string());
let (api_key, organization_id) = self.get_api_key();
let api_key = mask_text(&api_key, 3, 4);
let organization_id = organization_id
.map(|v| mask_text(&v, 3, 4))
.unwrap_or("-".into());
let organization_id = organization_id.map_or_else(|| "-".into(), |v| mask_text(&v, 3, 4));
let items = vec![
("config_file", file_info(&Config::config_file()?)),
("roles_file", file_info(&Config::roles_file()?)),
("messages_file", file_info(&Config::messages_file()?)),
("config_file", file_info(&Self::config_file()?)),
("roles_file", file_info(&Self::roles_file()?)),
("messages_file", file_info(&Self::messages_file()?)),
("api_key", api_key),
("organization_id", organization_id),
("model", self.model.0.to_string()),
@ -371,8 +368,8 @@ impl Config {
.map(|v| format!(".role {}", v.name))
.collect();
completion.extend(SET_COMPLETIONS.map(|v| v.to_string()));
completion.extend(MODELS.map(|(v, _)| format!(".model {}", v)));
completion.extend(SET_COMPLETIONS.map(std::string::ToString::to_string));
completion.extend(MODELS.map(|(v, _)| format!(".model {v}")));
completion
}
@ -441,7 +438,7 @@ impl Config {
Ok(())
}
pub fn get_render_options(&self) -> (bool, bool) {
pub const fn get_render_options(&self) -> (bool, bool) {
(self.highlight, self.light_theme)
}
@ -449,13 +446,14 @@ impl Config {
if self.dry_run {
if let Ok(messages) = self.build_messages(input) {
let tokens = num_tokens_from_messages(&messages);
println!(">>> The following message consumes {tokens} tokens.")
println!(">>> The following message consumes {tokens} tokens.");
}
}
}
#[allow(clippy::unused_self)] // TODO: do we need to take self here? it's not used in the fn
fn open_message_file(&self) -> Result<File> {
let path = Config::messages_file()?;
let path = Self::messages_file()?;
ensure_parent_exists(&path)?;
OpenOptions::new()
.create(true)
@ -468,7 +466,7 @@ impl Config {
let content = read_to_string(config_path)
.with_context(|| format!("Failed to load config at {}", config_path.display()))?;
let config: Config = serde_yaml::from_str(&content)
let config: Self = serde_yaml::from_str(&content)
.with_context(|| format!("Invalid config at {}", config_path.display()))?;
Ok(config)
}

@ -55,7 +55,7 @@ impl Role {
}
}
pub fn build_emssages(&self, content: &str) -> Vec<Message> {
pub fn build_messages(&self, content: &str) -> Vec<Message> {
if self.embeded() {
let content = merge_prompt_content(&self.prompt, content);
vec![Message {

@ -36,9 +36,9 @@ fn main() -> Result<()> {
exit(0);
}
if cli.list_models {
config::MODELS
.iter()
.for_each(|(name, _)| println!("{}", name));
for (name, _) in &config::MODELS {
println!("{name}");
}
exit(0);
}
if cli.dry_run {
@ -76,18 +76,18 @@ fn main() -> Result<()> {
if let Some(text) = text {
input = format!("{text}\n{input}");
}
start_directive(client, config, &input, no_stream)
start_directive(&client, &config, &input, no_stream)
} else {
match text {
Some(text) => start_directive(client, config, &text, no_stream),
Some(text) => start_directive(&client, &config, &text, no_stream),
None => start_interactive(client, config),
}
}
}
fn start_directive(
client: ChatGptClient,
config: SharedConfig,
client: &ChatGptClient,
config: &SharedConfig,
input: &str,
no_stream: bool,
) -> Result<()> {
@ -113,7 +113,7 @@ fn start_directive(
abort_clone.set_ctrlc();
})
.expect("Error setting Ctrl-C handler");
let output = render_stream(input, &client, config.clone(), false, abort, wg.clone())?;
let output = render_stream(input, client, config, false, abort, wg.clone())?;
wg.wait();
output
};

@ -6,10 +6,11 @@ use crate::repl::{ReplyStreamEvent, SharedAbortSignal};
use anyhow::Result;
use crossbeam::channel::Receiver;
#[allow(clippy::unnecessary_wraps, clippy::module_name_repetitions)]
pub fn cmd_render_stream(
rx: Receiver<ReplyStreamEvent>,
rx: &Receiver<ReplyStreamEvent>,
light_theme: bool,
abort: SharedAbortSignal,
abort: &SharedAbortSignal,
) -> Result<()> {
let mut buffer = String::new();
let mut markdown_render = MarkdownRender::new(light_theme);
@ -25,7 +26,7 @@ pub fn cmd_render_stream(
let mut lines: Vec<&str> = text.split('\n').collect();
buffer = lines.pop().unwrap_or_default().to_string();
let output = lines.join("\n");
print_now!("{}\n", markdown_render.render(&output))
print_now!("{}\n", markdown_render.render(&output));
} else {
buffer = format!("{buffer}{text}");
if !(markdown_render.is_code_block()
@ -36,7 +37,7 @@ pub fn cmd_render_stream(
{
if let Some((output, remain)) = split_line(&buffer) {
print_now!("{}", markdown_render.render_line_stateless(&output));
buffer = remain
buffer = remain;
}
}
}
@ -74,8 +75,8 @@ fn split_line(line: &str) -> Option<(String, String)> {
index += 2;
continue;
}
do_balance(&mut balance, &chars[index..index + 1]);
index += 1
do_balance(&mut balance, &chars[index..=index]);
index += 1;
}
None
@ -101,25 +102,25 @@ impl Kind {
fn from_chars(chars: &[char]) -> Option<Self> {
let kind = match chars.len() {
1 => match chars[0] {
'(' => Kind::ParentheseStart,
')' => Kind::ParentheseEnd,
'[' => Kind::BracketStart,
']' => Kind::BracketEnd,
'*' => Kind::Asterisk,
'\'' => Kind::SingleQuota,
'"' => Kind::DoubleQuota,
'~' => Kind::Tilde,
'`' => Kind::Backtick,
'(' => Self::ParentheseStart,
')' => Self::ParentheseEnd,
'[' => Self::BracketStart,
']' => Self::BracketEnd,
'*' => Self::Asterisk,
'\'' => Self::SingleQuota,
'"' => Self::DoubleQuota,
'~' => Self::Tilde,
'`' => Self::Backtick,
_ => return None,
},
2 if chars[0] == chars[1] => match chars[0] {
'*' => Kind::Asterisk2,
'~' => Kind::Tilde2,
'*' => Self::Asterisk2,
'~' => Self::Tilde2,
_ => return None,
},
3 => {
if chars == ['`', '`', '`'] {
Kind::Backtick3
Self::Backtick3
} else {
return None;
}
@ -131,22 +132,12 @@ impl Kind {
}
fn do_balance(balance: &mut Vec<Kind>, chars: &[char]) -> bool {
if let Some(kind) = Kind::from_chars(chars) {
Kind::from_chars(chars).map_or(false, |kind| {
let last = balance.last();
match (kind, last) {
(Kind::ParentheseStart | Kind::BracketStart, _) => {
balance.push(kind);
true
}
(Kind::ParentheseEnd, Some(&Kind::ParentheseStart)) => {
balance.pop();
true
}
(Kind::BracketEnd, Some(&Kind::BracketStart)) => {
balance.pop();
true
}
(Kind::Asterisk, Some(&Kind::Asterisk))
(Kind::ParentheseEnd, Some(&Kind::ParentheseStart))
| (Kind::BracketEnd, Some(&Kind::BracketStart))
| (Kind::Asterisk, Some(&Kind::Asterisk))
| (Kind::Asterisk2, Some(&Kind::Asterisk2))
| (Kind::SingleQuota, Some(&Kind::SingleQuota))
| (Kind::DoubleQuota, Some(&Kind::DoubleQuota))
@ -157,22 +148,25 @@ fn do_balance(balance: &mut Vec<Kind>, chars: &[char]) -> bool {
balance.pop();
true
}
(Kind::Asterisk, _)
| (Kind::Asterisk2, _)
| (Kind::SingleQuota, _)
| (Kind::DoubleQuota, _)
| (Kind::Tilde, _)
| (Kind::Tilde2, _)
| (Kind::Backtick, _)
| (Kind::Backtick3, _) => {
(
Kind::ParentheseStart
| Kind::BracketStart
| Kind::Asterisk
| Kind::Asterisk2
| Kind::SingleQuota
| Kind::DoubleQuota
| Kind::Tilde
| Kind::Tilde2
| Kind::Backtick
| Kind::Backtick3,
_,
) => {
balance.push(kind);
true
}
_ => false,
}
} else {
false
}
})
}
#[cfg(test)]

@ -8,6 +8,7 @@ use syntect::{easy::HighlightLines, parsing::SyntaxReference};
/// Monokai Extended
const MD_THEME: &[u8] = include_bytes!("../../assets/monokai-extended.theme.bin");
const MD_THEME_LIGHT: &[u8] = include_bytes!("../../assets/monokai-extended-light.theme.bin");
#[allow(clippy::doc_markdown)]
/// Comes from https://github.com/sharkdp/bat/raw/5e77ca37e89c873e4490b42ff556370dc5c6ba4f/assets/syntaxes.bin
const SYNTAXES: &[u8] = include_bytes!("../../assets/syntaxes.bin");
@ -20,6 +21,7 @@ lazy_static! {
};
}
#[allow(clippy::module_name_repetitions)]
pub struct MarkdownRender {
syntax_set: SyntaxSet,
md_theme: Theme,
@ -67,7 +69,7 @@ impl MarkdownRender {
output.unwrap_or_else(|| line.to_string())
}
pub fn is_code_block(&self) -> bool {
pub const fn is_code_block(&self) -> bool {
matches!(
self.prev_line_type,
LineType::CodeBegin | LineType::CodeInner
@ -123,13 +125,14 @@ impl MarkdownRender {
}
fn render_code_line(&self, line: &str) -> Option<String> {
self.code_syntax
.as_ref()
.map(|syntax| self.render_line_inner(line, syntax))
.unwrap_or_else(|| Some(format!("{}", line.with(self.code_color))))
self.code_syntax.as_ref().map_or_else(
|| Some(format!("{}", line.with(self.code_color))),
|syntax| self.render_line_inner(line, syntax),
)
}
fn find_syntax(&self, lang: &str) -> Option<&SyntaxReference> {
#[allow(clippy::option_if_let_else)]
if let Some(new_lang) = LANGE_MAPS.get(&lang.to_ascii_lowercase()) {
self.syntax_set.find_syntax_by_name(new_lang)
} else {
@ -154,17 +157,17 @@ fn as_terminal_escaped(ranges: &[(Style, &str)]) -> String {
let fg = blend_fg_color(style.foreground, style.background);
let mut text = text.with(convert_color(fg));
if style.font_style.contains(FontStyle::BOLD) {
text = text.bold()
text = text.bold();
}
if style.font_style.contains(FontStyle::UNDERLINE) {
text = text.underlined()
text = text.underlined();
}
output.push_str(&text.to_string());
}
output
}
fn convert_color(c: SyntectColor) -> Color {
const fn convert_color(c: SyntectColor) -> Color {
Color::Rgb {
r: c.r,
g: c.g,
@ -176,14 +179,14 @@ fn blend_fg_color(fg: SyntectColor, bg: SyntectColor) -> SyntectColor {
if fg.a == 0xff {
return fg;
}
let ratio = fg.a as u32;
let r = (fg.r as u32 * ratio + bg.r as u32 * (255 - ratio)) / 255;
let g = (fg.g as u32 * ratio + bg.g as u32 * (255 - ratio)) / 255;
let b = (fg.b as u32 * ratio + bg.b as u32 * (255 - ratio)) / 255;
let ratio = u32::from(fg.a);
let r = (u32::from(fg.r) * ratio + u32::from(bg.r) * (255 - ratio)) / 255;
let g = (u32::from(fg.g) * ratio + u32::from(bg.g) * (255 - ratio)) / 255;
let b = (u32::from(fg.b) * ratio + u32::from(bg.b) * (255 - ratio)) / 255;
SyntectColor {
r: r as u8,
g: g as u8,
b: b as u8,
r: u8::try_from(r).unwrap_or(u8::MAX),
g: u8::try_from(g).unwrap_or(u8::MAX),
b: u8::try_from(b).unwrap_or(u8::MAX),
a: 255,
}
}
@ -209,8 +212,7 @@ fn get_code_color(theme: &Theme) -> Color {
});
scope
.and_then(|v| v.style.foreground)
.map(convert_color)
.unwrap_or_else(|| Color::Yellow)
.map_or_else(|| Color::Yellow, convert_color)
}
#[cfg(test)]

@ -3,6 +3,7 @@ mod markdown;
mod repl;
use self::cmd::cmd_render_stream;
#[allow(clippy::module_name_repetitions)]
pub use self::markdown::MarkdownRender;
use self::repl::repl_render_stream;
@ -16,10 +17,11 @@ use crossbeam::channel::unbounded;
use crossbeam::sync::WaitGroup;
use std::thread::spawn;
#[allow(clippy::module_name_repetitions)]
pub fn render_stream(
input: &str,
client: &ChatGptClient,
config: SharedConfig,
config: &SharedConfig,
repl: bool,
abort: SharedAbortSignal,
wg: WaitGroup,
@ -30,9 +32,9 @@ pub fn render_stream(
let abort_clone = abort.clone();
spawn(move || {
let err = if repl {
repl_render_stream(rx, light_theme, abort)
repl_render_stream(&rx, light_theme, &abort)
} else {
cmd_render_stream(rx, light_theme, abort)
cmd_render_stream(&rx, light_theme, &abort)
};
if let Err(err) = err {
let err = format!("{err:?}");

@ -16,10 +16,11 @@ use std::{
};
use unicode_width::UnicodeWidthStr;
#[allow(clippy::module_name_repetitions)]
pub fn repl_render_stream(
rx: Receiver<ReplyStreamEvent>,
rx: &Receiver<ReplyStreamEvent>,
light_theme: bool,
abort: SharedAbortSignal,
abort: &SharedAbortSignal,
) -> Result<()> {
enable_raw_mode()?;
let mut stdout = io::stdout();
@ -32,9 +33,9 @@ pub fn repl_render_stream(
}
fn repl_render_stream_inner(
rx: Receiver<ReplyStreamEvent>,
rx: &Receiver<ReplyStreamEvent>,
light_theme: bool,
abort: SharedAbortSignal,
abort: &SharedAbortSignal,
writer: &mut Stdout,
) -> Result<()> {
let mut last_tick = Instant::now();
@ -118,7 +119,8 @@ fn repl_render_stream_inner(
}
fn recover_cursor(writer: &mut Stdout, terminal_columns: u16, buffer: &str) -> Result<()> {
let buffer_rows = (buffer.width() as u16 + terminal_columns - 1) / terminal_columns;
let buffer_rows = (u16::try_from(buffer.width()).unwrap_or(u16::MAX) + terminal_columns - 1)
/ terminal_columns;
let (_, row) = cursor::position()?;
if buffer_rows == 0 {
queue!(writer, cursor::MoveTo(0, row))?;

@ -5,6 +5,7 @@ use std::sync::{
pub type SharedAbortSignal = Arc<AbortSignal>;
#[allow(clippy::module_name_repetitions)]
pub struct AbortSignal {
ctrlc: AtomicBool,
ctrld: AtomicBool,

@ -24,6 +24,7 @@ pub enum ReplCmd {
Copy,
}
#[allow(clippy::module_name_repetitions)]
pub struct ReplCmdHandler {
client: ChatGptClient,
config: SharedConfig,
@ -32,6 +33,7 @@ pub struct ReplCmdHandler {
}
impl ReplCmdHandler {
#[allow(clippy::unnecessary_wraps)]
pub fn init(
client: ChatGptClient,
config: SharedConfig,
@ -58,7 +60,7 @@ impl ReplCmdHandler {
let ret = render_stream(
&input,
&self.client,
self.config.clone(),
&self.config,
true,
self.abort.clone(),
wg.clone(),
@ -113,6 +115,7 @@ impl ReplCmdHandler {
}
}
#[allow(clippy::module_name_repetitions)]
pub struct ReplyStreamHandler {
sender: Option<Sender<ReplyStreamEvent>>,
buffer: String,
@ -154,22 +157,19 @@ impl ReplyStreamHandler {
}
pub fn done(&mut self) -> Result<()> {
match self.sender.as_ref() {
Some(tx) => {
let ret = tx
.send(ReplyStreamEvent::Done)
.with_context(|| "Failed to send StreamEvent:Done");
self.safe_ret(ret)?;
if let Some(tx) = self.sender.as_ref() {
let ret = tx
.send(ReplyStreamEvent::Done)
.with_context(|| "Failed to send StreamEvent:Done");
self.safe_ret(ret)?;
} else {
if !self.buffer.ends_with('\n') {
print_now!("\n");
}
None => {
if !self.buffer.ends_with('\n') {
print_now!("\n")
}
if self.repl {
if self.repl {
print_now!("\n");
if cfg!(macos) {
print_now!("\n");
if cfg!(macos) {
print_now!("\n")
}
}
}
}

@ -5,6 +5,7 @@ use reedline::{Highlighter, StyledText};
const MATCH_COLOR: Color = Color::Green;
#[allow(clippy::module_name_repetitions)]
pub struct ReplHighlighter {
external_commands: Vec<String>,
config: SharedConfig,
@ -12,10 +13,10 @@ pub struct ReplHighlighter {
impl ReplHighlighter {
/// Construct the default highlighter with a given set of extern commands/keywords to detect and highlight
pub fn new(config: SharedConfig, external_commands: Vec<String>) -> ReplHighlighter {
pub fn new(config: SharedConfig, external_commands: Vec<String>) -> Self {
Self {
config,
external_commands,
config,
}
}
}
@ -28,9 +29,10 @@ impl Highlighter for ReplHighlighter {
} else {
Color::White
};
let match_color = match self.config.read().highlight {
true => MATCH_COLOR,
false => color,
let match_color = if self.config.read().highlight {
MATCH_COLOR
} else {
color
};
if self
@ -45,7 +47,7 @@ impl Highlighter for ReplHighlighter {
.filter(|c| line.contains(*c))
.map(std::ops::Deref::deref)
.collect();
let longest_match = matches.iter().fold("".to_string(), |acc, &item| {
let longest_match = matches.iter().fold(String::new(), |acc, &item| {
if item.len() > acc.len() {
item.to_string()
} else {

@ -24,7 +24,7 @@ impl Repl {
.map(|(v, _)| v.to_string())
.collect();
let completer = Self::create_completer(config.clone(), &commands);
let completer = Self::create_completer(&config, &commands);
let highlighter = ReplHighlighter::new(config.clone(), commands);
let keybindings = Self::create_keybindings();
let history = Self::create_history()?;
@ -45,7 +45,7 @@ impl Repl {
Ok(Self { editor, prompt })
}
fn create_completer(config: SharedConfig, commands: &[String]) -> DefaultCompleter {
fn create_completer(config: &SharedConfig, commands: &[String]) -> DefaultCompleter {
let mut completion = commands.to_vec();
completion.extend(config.read().repl_completions());
let mut completer =

@ -55,7 +55,7 @@ impl Repl {
Ok(Signal::Success(line)) => {
already_ctrlc = false;
abort.reset();
match self.handle_line(handler.clone(), line) {
match self.handle_line(&handler, &line) {
Ok(quit) => {
if quit {
break;
@ -69,12 +69,11 @@ impl Repl {
}
Ok(Signal::CtrlC) => {
abort.set_ctrlc();
if !already_ctrlc {
already_ctrlc = true;
print_now!("(To exit, press Ctrl+C again or Ctrl+D or type .exit)\n\n");
} else {
if already_ctrlc {
break;
}
already_ctrlc = true;
print_now!("(To exit, press Ctrl+C again or Ctrl+D or type .exit)\n\n");
}
Ok(Signal::CtrlD) => {
abort.set_ctrld();
@ -86,9 +85,9 @@ impl Repl {
Ok(())
}
fn handle_line(&mut self, handler: Arc<ReplCmdHandler>, line: String) -> Result<bool> {
let line = clean_multiline_symbols(&line);
match parse_command(&line) {
fn handle_line(&mut self, handler: &Arc<ReplCmdHandler>, line: &str) -> Result<bool> {
let line = clean_multiline_symbols(line);
match parse_command(line.as_ref()) {
Some((cmd, args)) => match cmd {
".exit" => {
return Ok(true);
@ -176,7 +175,7 @@ Press Ctrl+C to abort readline, Ctrl+D to exit the REPL
fn clean_multiline_symbols(line: &str) -> Cow<str> {
let trimed_line = line.trim();
match trimed_line.chars().next() {
Some('{') | Some('[') | Some('(') => trimed_line[1..trimed_line.len() - 1].into(),
Some('{' | '[' | '(') => trimed_line[1..trimed_line.len() - 1].into(),
_ => Cow::Borrowed(line),
}
}

@ -9,6 +9,7 @@ const PROMPT_MULTILINE_COLOR: nu_ansi_term::Color = nu_ansi_term::Color::LightBl
const INDICATOR_COLOR: Color = Color::Cyan;
const PROMPT_RIGHT_COLOR: Color = Color::AnsiValue(5);
#[allow(clippy::module_name_repetitions)]
#[derive(Clone)]
pub struct ReplPrompt {
config: SharedConfig,
@ -21,7 +22,7 @@ pub struct ReplPrompt {
impl ReplPrompt {
pub fn new(config: SharedConfig) -> Self {
let (prompt_color, prompt_multiline_color, indicator_color, prompt_right_color) =
Self::get_colors(config.clone());
Self::get_colors(&config);
Self {
config,
prompt_color,
@ -32,14 +33,14 @@ impl ReplPrompt {
}
pub fn sync_config(&mut self) {
let (prompt_color, prompt_multiline_color, indicator_color, prompt_right_color) =
Self::get_colors(self.config.clone());
Self::get_colors(&self.config);
self.prompt_color = prompt_color;
self.prompt_multiline_color = prompt_multiline_color;
self.indicator_color = indicator_color;
self.prompt_right_color = prompt_right_color;
}
pub fn get_colors(config: SharedConfig) -> (Color, nu_ansi_term::Color, Color, Color) {
pub fn get_colors(config: &SharedConfig) -> (Color, nu_ansi_term::Color, Color, Color) {
let (highlight, light_theme) = config.read().get_render_options();
if highlight {
(
@ -68,11 +69,11 @@ impl ReplPrompt {
impl Prompt for ReplPrompt {
fn render_prompt_left(&self) -> Cow<str> {
if let Some(role) = self.config.read().role.as_ref() {
role.name.to_string().into()
} else {
Cow::Borrowed("")
}
self.config
.read()
.role
.as_ref()
.map_or(Cow::Borrowed(""), |role| Cow::Owned(role.name.clone()))
}
fn render_prompt_right(&self) -> Cow<str> {

@ -1,6 +1,7 @@
use reedline::{ValidationResult, Validator};
/// A default validator which checks for mismatched quotes and brackets
#[allow(clippy::module_name_repetitions)]
pub struct ReplValidator;
impl Validator for ReplValidator {
@ -30,7 +31,7 @@ fn incomplete_brackets(line: &str) -> bool {
None => match c {
'{' | '(' => {
balance.push(c);
symbol = Some(c)
symbol = Some(c);
}
_ => {}
},

@ -14,7 +14,7 @@ macro_rules! print_now {
};
}
pub fn print_now<T: ToString>(text: T) {
pub fn print_now<T: ToString>(text: &T) {
print!("{}", text.to_string());
let _ = stdout().flush();
}

@ -1,6 +1,6 @@
//! Use tiktoken to count tokens
//!
//! Copy from https://github.com/dust-tt/dust/tree/main/core/src/providers/tiktoken
//! Copy from [https://github.com/dust-tt/dust/tree/main/core/src/providers/tiktoken](https://github.com/dust-tt/dust/tree/main/core/src/providers/tiktoken)
#![allow(unused)]
@ -26,7 +26,7 @@ pub fn text_to_tokens(text: &str) -> Vec<usize> {
}
/// Convert tokens to plain text
pub fn tokens_to_text(tokens: Vec<usize>) -> Result<String> {
pub fn tokens_to_text(tokens: &[usize]) -> Result<String> {
cl100k_base_singleton().lock().decode(tokens)
}
@ -43,11 +43,11 @@ pub fn cl100k_base() -> Result<CoreBPE> {
}
let mut special_tokens = HashMap::default();
special_tokens.insert(String::from("<|endoftext|>"), 100257);
special_tokens.insert(String::from("<|fim_prefix|>"), 100258);
special_tokens.insert(String::from("<|fim_middle|>"), 100259);
special_tokens.insert(String::from("<|fim_suffix|>"), 100260);
special_tokens.insert(String::from("<|endofprompt|>"), 100276);
special_tokens.insert(String::from("<|endoftext|>"), 100_257);
special_tokens.insert(String::from("<|fim_prefix|>"), 100_258);
special_tokens.insert(String::from("<|fim_middle|>"), 100_259);
special_tokens.insert(String::from("<|fim_suffix|>"), 100_260);
special_tokens.insert(String::from("<|endofprompt|>"), 100_276);
CoreBPE::new(
encoder,
@ -64,6 +64,7 @@ pub fn cl100k_base_singleton() -> Arc<Mutex<CoreBPE>> {
}
fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>> {
#[allow(clippy::range_plus_one)]
let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect();
// If you have n parts and m merges, this does O(mn) work
@ -156,11 +157,11 @@ pub struct CoreBPE {
}
impl CoreBPE {
fn _get_regex(&self) -> &Regex {
const fn _get_regex(&self) -> &Regex {
&self.regex
}
fn _get_special_regex(&self) -> &Regex {
const fn _get_special_regex(&self) -> &Regex {
&self.special_regex
}
@ -262,15 +263,12 @@ impl CoreBPE {
// Here is a quick and dirty fix:
{
let token_is_all_space = |token| {
self.decoder
.get(token)
.map(|token_bytes| {
token_bytes
.iter()
.rev()
.all(|&b| [b' ', b'\n', b'\t'].contains(&b))
})
.unwrap_or(false)
self.decoder.get(token).map_or(false, |token_bytes| {
token_bytes
.iter()
.rev()
.all(|&b| [b' ', b'\n', b'\t'].contains(&b))
})
};
if last_piece_token_len > 0
&& token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
@ -342,6 +340,7 @@ impl CoreBPE {
&& self.sorted_token_bytes[point].starts_with(suffix)
{
let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
#[allow(clippy::option_if_let_else)]
let encoded = match std::str::from_utf8(&possibility) {
// Morally, this is byte_pair_encode(&possibility, &self.encoder)
// But we might have introduced a regex split which would prevent merges.
@ -386,7 +385,7 @@ impl CoreBPE {
if unstable_bytes.len() > 1 {
let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
if unstable_bytes.len() - last_decoded.1 > 0
&& last_decoded.0.map_or(false, |c| c.is_whitespace())
&& last_decoded.0.map_or(false, char::is_whitespace)
{
let mut reencoded = byte_pair_encode(
&unstable_bytes[..unstable_bytes.len() - last_decoded.1],
@ -413,11 +412,11 @@ impl CoreBPE {
let regex = Regex::new(pattern)?;
let special_regex = {
let _parts = special_tokens_encoder
let parts = special_tokens_encoder
.keys()
.map(|s| fancy_regex::escape(s))
.collect::<Vec<_>>();
Regex::new(&_parts.join("|"))?
Regex::new(&parts.join("|"))?
};
let decoder: HashMap<usize, Vec<u8>> =
@ -434,7 +433,7 @@ impl CoreBPE {
let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
sorted_token_bytes.sort();
Ok(CoreBPE {
Ok(Self {
encoder,
special_tokens_encoder,
decoder,
@ -453,15 +452,15 @@ impl CoreBPE {
self._encode_ordinary_native(text)
}
pub fn encode(&self, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> {
self._encode_native(text, &allowed_special).0
pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> Vec<usize> {
self._encode_native(text, allowed_special).0
}
pub fn encode_with_special_tokens(&self, text: &str) -> Vec<usize> {
let allowed_special = self
.special_tokens_encoder
.keys()
.map(|s| s.as_str())
.map(std::string::String::as_str)
.collect();
self._encode_native(text, &allowed_special).0
}
@ -495,9 +494,9 @@ impl CoreBPE {
fn encode_with_unstable(
&self,
text: &str,
allowed_special: HashSet<&str>,
allowed_special: &HashSet<&str>,
) -> (Vec<usize>, HashSet<Vec<usize>>) {
self._encode_unstable_native(text, &allowed_special)
self._encode_unstable_native(text, allowed_special)
}
#[allow(dead_code)]
@ -525,12 +524,12 @@ impl CoreBPE {
// Decoding
// ====================
pub fn decode_bytes(&self, tokens: Vec<usize>) -> Vec<u8> {
self._decode_native(&tokens)
pub fn decode_bytes(&self, tokens: &[usize]) -> Vec<u8> {
self._decode_native(tokens)
}
pub fn decode(&self, tokens: Vec<usize>) -> Result<String> {
match String::from_utf8(self._decode_native(&tokens)) {
pub fn decode(&self, tokens: &[usize]) -> Result<String> {
match String::from_utf8(self._decode_native(tokens)) {
Ok(text) => Ok(text),
Err(e) => Err(anyhow!("Unable to decode into a valid UTF-8 string: {}", e)),
}
@ -576,7 +575,7 @@ mod tests {
fn cl100k_base_test() {
let bpe = cl100k_base().unwrap();
let tokens = bpe.encode_with_special_tokens("This is a test with a lot of spaces");
let decoded = bpe.decode(tokens.clone()).unwrap();
let decoded = bpe.decode(&tokens).unwrap();
assert_eq!(decoded, "This is a test with a lot of spaces");
assert_eq!(
tokens,

Loading…
Cancel
Save