fix: minor code cleanup and pedantic lints (#134)

pull/136/head
Anthony Rubick 1 year ago committed by GitHub
parent ec51b84290
commit 97fc7de675
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -44,7 +44,7 @@ jobs:
run: cargo test --all run: cargo test --all
- name: Clippy - name: Clippy
run: cargo clippy --all --all-targets run: cargo clippy --all --all-targets -- -D warnings
- name: Format - name: Format
run: cargo fmt --all --check run: cargo fmt --all --check

@ -1,5 +1,6 @@
use clap::Parser; use clap::Parser;
#[allow(clippy::struct_excessive_bools, clippy::module_name_repetitions)]
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
pub struct Cli { pub struct Cli {

@ -12,6 +12,7 @@ use tokio::time::sleep;
const API_URL: &str = "https://api.openai.com/v1/chat/completions"; const API_URL: &str = "https://api.openai.com/v1/chat/completions";
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)] #[derive(Debug)]
pub struct ChatGptClient { pub struct ChatGptClient {
config: SharedConfig, config: SharedConfig,
@ -106,16 +107,15 @@ impl ChatGptClient {
let chunk = part?.data; let chunk = part?.data;
if chunk == "[DONE]" { if chunk == "[DONE]" {
break; break;
} else {
let data: Value = serde_json::from_str(&chunk)?;
let text = data["choices"][0]["delta"]["content"]
.as_str()
.unwrap_or_default();
if text.is_empty() {
continue;
}
handler.text(text)?;
} }
let data: Value = serde_json::from_str(&chunk)?;
let text = data["choices"][0]["delta"]["content"]
.as_str()
.unwrap_or_default();
if text.is_empty() {
continue;
}
handler.text(text)?;
} }
Ok(()) Ok(())

@ -43,11 +43,12 @@ impl Conversation {
self.tokens = num_tokens_from_messages(&self.build_emssages("")); self.tokens = num_tokens_from_messages(&self.build_emssages(""));
} }
#[allow(clippy::unnecessary_wraps)]
pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> { pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> {
let mut need_add_msg = true; let mut need_add_msg = true;
if self.messages.is_empty() { if self.messages.is_empty() {
if let Some(role) = self.role.as_ref() { if let Some(role) = self.role.as_ref() {
self.messages.extend(role.build_emssages(input)); self.messages.extend(role.build_messages(input));
need_add_msg = false; need_add_msg = false;
} }
} }
@ -67,15 +68,15 @@ impl Conversation {
pub fn echo_messages(&self, content: &str) -> String { pub fn echo_messages(&self, content: &str) -> String {
let messages = self.build_emssages(content); let messages = self.build_emssages(content);
serde_yaml::to_string(&messages).unwrap_or("Unable to echo message".into()) serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into())
} }
pub fn build_emssages(&self, content: &str) -> Vec<Message> { pub fn build_emssages(&self, content: &str) -> Vec<Message> {
let mut messages = self.messages.to_vec(); let mut messages = self.messages.clone();
let mut need_add_msg = true; let mut need_add_msg = true;
if messages.is_empty() { if messages.is_empty() {
if let Some(role) = self.role.as_ref() { if let Some(role) = self.role.as_ref() {
messages = role.build_emssages(content); messages = role.build_messages(content);
need_add_msg = false; need_add_msg = false;
} }
}; };

@ -17,6 +17,7 @@ impl Message {
} }
} }
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum MessageRole { pub enum MessageRole {
@ -45,6 +46,6 @@ mod tests {
assert_eq!( assert_eq!(
serde_json::to_string(&Message::new("Hello World")).unwrap(), serde_json::to_string(&Message::new("Hello World")).unwrap(),
"{\"role\":\"user\",\"content\":\"Hello World\"}" "{\"role\":\"user\",\"content\":\"Hello World\"}"
) );
} }
} }

@ -45,6 +45,7 @@ const SET_COMPLETIONS: [&str; 8] = [
".set dry_run false", ".set dry_run false",
]; ];
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
#[serde(default)] #[serde(default)]
pub struct Config { pub struct Config {
@ -109,6 +110,7 @@ impl Default for Config {
} }
} }
#[allow(clippy::module_name_repetitions)]
pub type SharedConfig = Arc<RwLock<Config>>; pub type SharedConfig = Arc<RwLock<Config>>;
impl Config { impl Config {
@ -119,7 +121,7 @@ impl Config {
create_config_file(&config_path)?; create_config_file(&config_path)?;
} }
let mut config = if api_key.is_some() && !config_path.exists() { let mut config = if api_key.is_some() && !config_path.exists() {
Default::default() Self::default()
} else { } else {
Self::load_config(&config_path)? Self::load_config(&config_path)?
}; };
@ -156,13 +158,12 @@ impl Config {
pub fn config_dir() -> Result<PathBuf> { pub fn config_dir() -> Result<PathBuf> {
let env_name = get_env_name("config_dir"); let env_name = get_env_name("config_dir");
let path = match env::var_os(env_name) { let path = if let Some(v) = env::var_os(env_name) {
Some(v) => PathBuf::from(v), PathBuf::from(v)
None => { } else {
let mut dir = dirs::config_dir().ok_or_else(|| anyhow!("Not found config dir"))?; let mut dir = dirs::config_dir().ok_or_else(|| anyhow!("Not found config dir"))?;
dir.push(env!("CARGO_CRATE_NAME")); dir.push(env!("CARGO_CRATE_NAME"));
dir dir
}
}; };
Ok(path) Ok(path)
} }
@ -186,18 +187,17 @@ impl Config {
None => { None => {
format!("# CHAT:[{timestamp}]\n{input}\n--------\n{output}\n--------\n\n",) format!("# CHAT:[{timestamp}]\n{input}\n--------\n{output}\n--------\n\n",)
} }
Some(v) if v.is_temp() => {
format!(
"# CHAT:[{timestamp}]\n{}\n{input}\n--------\n{output}\n--------\n\n",
v.prompt
)
}
Some(v) => { Some(v) => {
if v.is_temp() { format!(
format!( "# CHAT:[{timestamp}] ({})\n{input}\n--------\n{output}\n--------\n\n",
"# CHAT:[{timestamp}]\n{}\n{input}\n--------\n{output}\n--------\n\n", v.name,
v.prompt )
)
} else {
format!(
"# CHAT:[{timestamp}] ({})\n{input}\n--------\n{output}\n--------\n\n",
v.name,
)
}
} }
}; };
file.write_all(output.as_bytes()) file.write_all(output.as_bytes())
@ -216,11 +216,10 @@ impl Config {
pub fn roles_file() -> Result<PathBuf> { pub fn roles_file() -> Result<PathBuf> {
let env_name = get_env_name("roles_file"); let env_name = get_env_name("roles_file");
if let Ok(value) = env::var(env_name) { env::var(env_name).map_or_else(
Ok(PathBuf::from(value)) |_| Self::local_file(ROLES_FILE_NAME),
} else { |value| Ok(PathBuf::from(value)),
Self::local_file(ROLES_FILE_NAME) )
}
} }
pub fn history_file() -> Result<PathBuf> { pub fn history_file() -> Result<PathBuf> {
@ -237,8 +236,8 @@ impl Config {
if let Some(conversation) = self.conversation.as_mut() { if let Some(conversation) = self.conversation.as_mut() {
conversation.update_role(&role)?; conversation.update_role(&role)?;
} }
let output = let output = serde_yaml::to_string(&role)
serde_yaml::to_string(&role).unwrap_or("Unable to echo role details".into()); .unwrap_or_else(|_| "Unable to echo role details".into());
self.role = Some(role); self.role = Some(role);
Ok(output) Ok(output)
} }
@ -271,6 +270,7 @@ impl Config {
} }
pub fn echo_messages(&self, content: &str) -> String { pub fn echo_messages(&self, content: &str) -> String {
#[allow(clippy::option_if_let_else)]
if let Some(conversation) = self.conversation.as_ref() { if let Some(conversation) = self.conversation.as_ref() {
conversation.echo_messages(content) conversation.echo_messages(content)
} else if let Some(role) = self.role.as_ref() { } else if let Some(role) = self.role.as_ref() {
@ -280,7 +280,7 @@ impl Config {
} }
} }
pub fn get_connect_timeout(&self) -> Duration { pub const fn get_connect_timeout(&self) -> Duration {
Duration::from_secs(self.connect_timeout as u64) Duration::from_secs(self.connect_timeout as u64)
} }
@ -289,10 +289,11 @@ impl Config {
} }
pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> { pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> {
#[allow(clippy::option_if_let_else)]
let messages = if let Some(conversation) = self.conversation.as_ref() { let messages = if let Some(conversation) = self.conversation.as_ref() {
conversation.build_emssages(content) conversation.build_emssages(content)
} else if let Some(role) = self.role.as_ref() { } else if let Some(role) = self.role.as_ref() {
role.build_emssages(content) role.build_messages(content)
} else { } else {
let message = Message::new(content); let message = Message::new(content);
vec![message] vec![message]
@ -314,7 +315,7 @@ impl Config {
Ok(()) Ok(())
} }
pub fn get_reamind_tokens(&self) -> usize { pub const fn get_reamind_tokens(&self) -> usize {
let mut tokens = self.model.1; let mut tokens = self.model.1;
if let Some(conversation) = self.conversation.as_ref() { if let Some(conversation) = self.conversation.as_ref() {
tokens = tokens.saturating_sub(conversation.tokens); tokens = tokens.saturating_sub(conversation.tokens);
@ -330,21 +331,17 @@ impl Config {
let proxy = self let proxy = self
.proxy .proxy
.as_ref() .as_ref()
.map(|v| v.to_string()) .map_or_else(|| String::from("-"), std::string::ToString::to_string);
.unwrap_or("-".into());
let temperature = self let temperature = self
.temperature .temperature
.map(|v| v.to_string()) .map_or_else(|| String::from("-"), |v| v.to_string());
.unwrap_or("-".into());
let (api_key, organization_id) = self.get_api_key(); let (api_key, organization_id) = self.get_api_key();
let api_key = mask_text(&api_key, 3, 4); let api_key = mask_text(&api_key, 3, 4);
let organization_id = organization_id let organization_id = organization_id.map_or_else(|| "-".into(), |v| mask_text(&v, 3, 4));
.map(|v| mask_text(&v, 3, 4))
.unwrap_or("-".into());
let items = vec![ let items = vec![
("config_file", file_info(&Config::config_file()?)), ("config_file", file_info(&Self::config_file()?)),
("roles_file", file_info(&Config::roles_file()?)), ("roles_file", file_info(&Self::roles_file()?)),
("messages_file", file_info(&Config::messages_file()?)), ("messages_file", file_info(&Self::messages_file()?)),
("api_key", api_key), ("api_key", api_key),
("organization_id", organization_id), ("organization_id", organization_id),
("model", self.model.0.to_string()), ("model", self.model.0.to_string()),
@ -371,8 +368,8 @@ impl Config {
.map(|v| format!(".role {}", v.name)) .map(|v| format!(".role {}", v.name))
.collect(); .collect();
completion.extend(SET_COMPLETIONS.map(|v| v.to_string())); completion.extend(SET_COMPLETIONS.map(std::string::ToString::to_string));
completion.extend(MODELS.map(|(v, _)| format!(".model {}", v))); completion.extend(MODELS.map(|(v, _)| format!(".model {v}")));
completion completion
} }
@ -441,7 +438,7 @@ impl Config {
Ok(()) Ok(())
} }
pub fn get_render_options(&self) -> (bool, bool) { pub const fn get_render_options(&self) -> (bool, bool) {
(self.highlight, self.light_theme) (self.highlight, self.light_theme)
} }
@ -449,13 +446,14 @@ impl Config {
if self.dry_run { if self.dry_run {
if let Ok(messages) = self.build_messages(input) { if let Ok(messages) = self.build_messages(input) {
let tokens = num_tokens_from_messages(&messages); let tokens = num_tokens_from_messages(&messages);
println!(">>> The following message consumes {tokens} tokens.") println!(">>> The following message consumes {tokens} tokens.");
} }
} }
} }
#[allow(clippy::unused_self)] // TODO: do we need to take self here? it's not used in the fn
fn open_message_file(&self) -> Result<File> { fn open_message_file(&self) -> Result<File> {
let path = Config::messages_file()?; let path = Self::messages_file()?;
ensure_parent_exists(&path)?; ensure_parent_exists(&path)?;
OpenOptions::new() OpenOptions::new()
.create(true) .create(true)
@ -468,7 +466,7 @@ impl Config {
let content = read_to_string(config_path) let content = read_to_string(config_path)
.with_context(|| format!("Failed to load config at {}", config_path.display()))?; .with_context(|| format!("Failed to load config at {}", config_path.display()))?;
let config: Config = serde_yaml::from_str(&content) let config: Self = serde_yaml::from_str(&content)
.with_context(|| format!("Invalid config at {}", config_path.display()))?; .with_context(|| format!("Invalid config at {}", config_path.display()))?;
Ok(config) Ok(config)
} }

@ -55,7 +55,7 @@ impl Role {
} }
} }
pub fn build_emssages(&self, content: &str) -> Vec<Message> { pub fn build_messages(&self, content: &str) -> Vec<Message> {
if self.embeded() { if self.embeded() {
let content = merge_prompt_content(&self.prompt, content); let content = merge_prompt_content(&self.prompt, content);
vec![Message { vec![Message {

@ -36,9 +36,9 @@ fn main() -> Result<()> {
exit(0); exit(0);
} }
if cli.list_models { if cli.list_models {
config::MODELS for (name, _) in &config::MODELS {
.iter() println!("{name}");
.for_each(|(name, _)| println!("{}", name)); }
exit(0); exit(0);
} }
if cli.dry_run { if cli.dry_run {
@ -76,18 +76,18 @@ fn main() -> Result<()> {
if let Some(text) = text { if let Some(text) = text {
input = format!("{text}\n{input}"); input = format!("{text}\n{input}");
} }
start_directive(client, config, &input, no_stream) start_directive(&client, &config, &input, no_stream)
} else { } else {
match text { match text {
Some(text) => start_directive(client, config, &text, no_stream), Some(text) => start_directive(&client, &config, &text, no_stream),
None => start_interactive(client, config), None => start_interactive(client, config),
} }
} }
} }
fn start_directive( fn start_directive(
client: ChatGptClient, client: &ChatGptClient,
config: SharedConfig, config: &SharedConfig,
input: &str, input: &str,
no_stream: bool, no_stream: bool,
) -> Result<()> { ) -> Result<()> {
@ -113,7 +113,7 @@ fn start_directive(
abort_clone.set_ctrlc(); abort_clone.set_ctrlc();
}) })
.expect("Error setting Ctrl-C handler"); .expect("Error setting Ctrl-C handler");
let output = render_stream(input, &client, config.clone(), false, abort, wg.clone())?; let output = render_stream(input, client, config, false, abort, wg.clone())?;
wg.wait(); wg.wait();
output output
}; };

@ -6,10 +6,11 @@ use crate::repl::{ReplyStreamEvent, SharedAbortSignal};
use anyhow::Result; use anyhow::Result;
use crossbeam::channel::Receiver; use crossbeam::channel::Receiver;
#[allow(clippy::unnecessary_wraps, clippy::module_name_repetitions)]
pub fn cmd_render_stream( pub fn cmd_render_stream(
rx: Receiver<ReplyStreamEvent>, rx: &Receiver<ReplyStreamEvent>,
light_theme: bool, light_theme: bool,
abort: SharedAbortSignal, abort: &SharedAbortSignal,
) -> Result<()> { ) -> Result<()> {
let mut buffer = String::new(); let mut buffer = String::new();
let mut markdown_render = MarkdownRender::new(light_theme); let mut markdown_render = MarkdownRender::new(light_theme);
@ -25,7 +26,7 @@ pub fn cmd_render_stream(
let mut lines: Vec<&str> = text.split('\n').collect(); let mut lines: Vec<&str> = text.split('\n').collect();
buffer = lines.pop().unwrap_or_default().to_string(); buffer = lines.pop().unwrap_or_default().to_string();
let output = lines.join("\n"); let output = lines.join("\n");
print_now!("{}\n", markdown_render.render(&output)) print_now!("{}\n", markdown_render.render(&output));
} else { } else {
buffer = format!("{buffer}{text}"); buffer = format!("{buffer}{text}");
if !(markdown_render.is_code_block() if !(markdown_render.is_code_block()
@ -36,7 +37,7 @@ pub fn cmd_render_stream(
{ {
if let Some((output, remain)) = split_line(&buffer) { if let Some((output, remain)) = split_line(&buffer) {
print_now!("{}", markdown_render.render_line_stateless(&output)); print_now!("{}", markdown_render.render_line_stateless(&output));
buffer = remain buffer = remain;
} }
} }
} }
@ -74,8 +75,8 @@ fn split_line(line: &str) -> Option<(String, String)> {
index += 2; index += 2;
continue; continue;
} }
do_balance(&mut balance, &chars[index..index + 1]); do_balance(&mut balance, &chars[index..=index]);
index += 1 index += 1;
} }
None None
@ -101,25 +102,25 @@ impl Kind {
fn from_chars(chars: &[char]) -> Option<Self> { fn from_chars(chars: &[char]) -> Option<Self> {
let kind = match chars.len() { let kind = match chars.len() {
1 => match chars[0] { 1 => match chars[0] {
'(' => Kind::ParentheseStart, '(' => Self::ParentheseStart,
')' => Kind::ParentheseEnd, ')' => Self::ParentheseEnd,
'[' => Kind::BracketStart, '[' => Self::BracketStart,
']' => Kind::BracketEnd, ']' => Self::BracketEnd,
'*' => Kind::Asterisk, '*' => Self::Asterisk,
'\'' => Kind::SingleQuota, '\'' => Self::SingleQuota,
'"' => Kind::DoubleQuota, '"' => Self::DoubleQuota,
'~' => Kind::Tilde, '~' => Self::Tilde,
'`' => Kind::Backtick, '`' => Self::Backtick,
_ => return None, _ => return None,
}, },
2 if chars[0] == chars[1] => match chars[0] { 2 if chars[0] == chars[1] => match chars[0] {
'*' => Kind::Asterisk2, '*' => Self::Asterisk2,
'~' => Kind::Tilde2, '~' => Self::Tilde2,
_ => return None, _ => return None,
}, },
3 => { 3 => {
if chars == ['`', '`', '`'] { if chars == ['`', '`', '`'] {
Kind::Backtick3 Self::Backtick3
} else { } else {
return None; return None;
} }
@ -131,22 +132,12 @@ impl Kind {
} }
fn do_balance(balance: &mut Vec<Kind>, chars: &[char]) -> bool { fn do_balance(balance: &mut Vec<Kind>, chars: &[char]) -> bool {
if let Some(kind) = Kind::from_chars(chars) { Kind::from_chars(chars).map_or(false, |kind| {
let last = balance.last(); let last = balance.last();
match (kind, last) { match (kind, last) {
(Kind::ParentheseStart | Kind::BracketStart, _) => { (Kind::ParentheseEnd, Some(&Kind::ParentheseStart))
balance.push(kind); | (Kind::BracketEnd, Some(&Kind::BracketStart))
true | (Kind::Asterisk, Some(&Kind::Asterisk))
}
(Kind::ParentheseEnd, Some(&Kind::ParentheseStart)) => {
balance.pop();
true
}
(Kind::BracketEnd, Some(&Kind::BracketStart)) => {
balance.pop();
true
}
(Kind::Asterisk, Some(&Kind::Asterisk))
| (Kind::Asterisk2, Some(&Kind::Asterisk2)) | (Kind::Asterisk2, Some(&Kind::Asterisk2))
| (Kind::SingleQuota, Some(&Kind::SingleQuota)) | (Kind::SingleQuota, Some(&Kind::SingleQuota))
| (Kind::DoubleQuota, Some(&Kind::DoubleQuota)) | (Kind::DoubleQuota, Some(&Kind::DoubleQuota))
@ -157,22 +148,25 @@ fn do_balance(balance: &mut Vec<Kind>, chars: &[char]) -> bool {
balance.pop(); balance.pop();
true true
} }
(Kind::Asterisk, _) (
| (Kind::Asterisk2, _) Kind::ParentheseStart
| (Kind::SingleQuota, _) | Kind::BracketStart
| (Kind::DoubleQuota, _) | Kind::Asterisk
| (Kind::Tilde, _) | Kind::Asterisk2
| (Kind::Tilde2, _) | Kind::SingleQuota
| (Kind::Backtick, _) | Kind::DoubleQuota
| (Kind::Backtick3, _) => { | Kind::Tilde
| Kind::Tilde2
| Kind::Backtick
| Kind::Backtick3,
_,
) => {
balance.push(kind); balance.push(kind);
true true
} }
_ => false, _ => false,
} }
} else { })
false
}
} }
#[cfg(test)] #[cfg(test)]

@ -8,6 +8,7 @@ use syntect::{easy::HighlightLines, parsing::SyntaxReference};
/// Monokai Extended /// Monokai Extended
const MD_THEME: &[u8] = include_bytes!("../../assets/monokai-extended.theme.bin"); const MD_THEME: &[u8] = include_bytes!("../../assets/monokai-extended.theme.bin");
const MD_THEME_LIGHT: &[u8] = include_bytes!("../../assets/monokai-extended-light.theme.bin"); const MD_THEME_LIGHT: &[u8] = include_bytes!("../../assets/monokai-extended-light.theme.bin");
#[allow(clippy::doc_markdown)]
/// Comes from https://github.com/sharkdp/bat/raw/5e77ca37e89c873e4490b42ff556370dc5c6ba4f/assets/syntaxes.bin /// Comes from https://github.com/sharkdp/bat/raw/5e77ca37e89c873e4490b42ff556370dc5c6ba4f/assets/syntaxes.bin
const SYNTAXES: &[u8] = include_bytes!("../../assets/syntaxes.bin"); const SYNTAXES: &[u8] = include_bytes!("../../assets/syntaxes.bin");
@ -20,6 +21,7 @@ lazy_static! {
}; };
} }
#[allow(clippy::module_name_repetitions)]
pub struct MarkdownRender { pub struct MarkdownRender {
syntax_set: SyntaxSet, syntax_set: SyntaxSet,
md_theme: Theme, md_theme: Theme,
@ -67,7 +69,7 @@ impl MarkdownRender {
output.unwrap_or_else(|| line.to_string()) output.unwrap_or_else(|| line.to_string())
} }
pub fn is_code_block(&self) -> bool { pub const fn is_code_block(&self) -> bool {
matches!( matches!(
self.prev_line_type, self.prev_line_type,
LineType::CodeBegin | LineType::CodeInner LineType::CodeBegin | LineType::CodeInner
@ -123,13 +125,14 @@ impl MarkdownRender {
} }
fn render_code_line(&self, line: &str) -> Option<String> { fn render_code_line(&self, line: &str) -> Option<String> {
self.code_syntax self.code_syntax.as_ref().map_or_else(
.as_ref() || Some(format!("{}", line.with(self.code_color))),
.map(|syntax| self.render_line_inner(line, syntax)) |syntax| self.render_line_inner(line, syntax),
.unwrap_or_else(|| Some(format!("{}", line.with(self.code_color)))) )
} }
fn find_syntax(&self, lang: &str) -> Option<&SyntaxReference> { fn find_syntax(&self, lang: &str) -> Option<&SyntaxReference> {
#[allow(clippy::option_if_let_else)]
if let Some(new_lang) = LANGE_MAPS.get(&lang.to_ascii_lowercase()) { if let Some(new_lang) = LANGE_MAPS.get(&lang.to_ascii_lowercase()) {
self.syntax_set.find_syntax_by_name(new_lang) self.syntax_set.find_syntax_by_name(new_lang)
} else { } else {
@ -154,17 +157,17 @@ fn as_terminal_escaped(ranges: &[(Style, &str)]) -> String {
let fg = blend_fg_color(style.foreground, style.background); let fg = blend_fg_color(style.foreground, style.background);
let mut text = text.with(convert_color(fg)); let mut text = text.with(convert_color(fg));
if style.font_style.contains(FontStyle::BOLD) { if style.font_style.contains(FontStyle::BOLD) {
text = text.bold() text = text.bold();
} }
if style.font_style.contains(FontStyle::UNDERLINE) { if style.font_style.contains(FontStyle::UNDERLINE) {
text = text.underlined() text = text.underlined();
} }
output.push_str(&text.to_string()); output.push_str(&text.to_string());
} }
output output
} }
fn convert_color(c: SyntectColor) -> Color { const fn convert_color(c: SyntectColor) -> Color {
Color::Rgb { Color::Rgb {
r: c.r, r: c.r,
g: c.g, g: c.g,
@ -176,14 +179,14 @@ fn blend_fg_color(fg: SyntectColor, bg: SyntectColor) -> SyntectColor {
if fg.a == 0xff { if fg.a == 0xff {
return fg; return fg;
} }
let ratio = fg.a as u32; let ratio = u32::from(fg.a);
let r = (fg.r as u32 * ratio + bg.r as u32 * (255 - ratio)) / 255; let r = (u32::from(fg.r) * ratio + u32::from(bg.r) * (255 - ratio)) / 255;
let g = (fg.g as u32 * ratio + bg.g as u32 * (255 - ratio)) / 255; let g = (u32::from(fg.g) * ratio + u32::from(bg.g) * (255 - ratio)) / 255;
let b = (fg.b as u32 * ratio + bg.b as u32 * (255 - ratio)) / 255; let b = (u32::from(fg.b) * ratio + u32::from(bg.b) * (255 - ratio)) / 255;
SyntectColor { SyntectColor {
r: r as u8, r: u8::try_from(r).unwrap_or(u8::MAX),
g: g as u8, g: u8::try_from(g).unwrap_or(u8::MAX),
b: b as u8, b: u8::try_from(b).unwrap_or(u8::MAX),
a: 255, a: 255,
} }
} }
@ -209,8 +212,7 @@ fn get_code_color(theme: &Theme) -> Color {
}); });
scope scope
.and_then(|v| v.style.foreground) .and_then(|v| v.style.foreground)
.map(convert_color) .map_or_else(|| Color::Yellow, convert_color)
.unwrap_or_else(|| Color::Yellow)
} }
#[cfg(test)] #[cfg(test)]

@ -3,6 +3,7 @@ mod markdown;
mod repl; mod repl;
use self::cmd::cmd_render_stream; use self::cmd::cmd_render_stream;
#[allow(clippy::module_name_repetitions)]
pub use self::markdown::MarkdownRender; pub use self::markdown::MarkdownRender;
use self::repl::repl_render_stream; use self::repl::repl_render_stream;
@ -16,10 +17,11 @@ use crossbeam::channel::unbounded;
use crossbeam::sync::WaitGroup; use crossbeam::sync::WaitGroup;
use std::thread::spawn; use std::thread::spawn;
#[allow(clippy::module_name_repetitions)]
pub fn render_stream( pub fn render_stream(
input: &str, input: &str,
client: &ChatGptClient, client: &ChatGptClient,
config: SharedConfig, config: &SharedConfig,
repl: bool, repl: bool,
abort: SharedAbortSignal, abort: SharedAbortSignal,
wg: WaitGroup, wg: WaitGroup,
@ -30,9 +32,9 @@ pub fn render_stream(
let abort_clone = abort.clone(); let abort_clone = abort.clone();
spawn(move || { spawn(move || {
let err = if repl { let err = if repl {
repl_render_stream(rx, light_theme, abort) repl_render_stream(&rx, light_theme, &abort)
} else { } else {
cmd_render_stream(rx, light_theme, abort) cmd_render_stream(&rx, light_theme, &abort)
}; };
if let Err(err) = err { if let Err(err) = err {
let err = format!("{err:?}"); let err = format!("{err:?}");

@ -16,10 +16,11 @@ use std::{
}; };
use unicode_width::UnicodeWidthStr; use unicode_width::UnicodeWidthStr;
#[allow(clippy::module_name_repetitions)]
pub fn repl_render_stream( pub fn repl_render_stream(
rx: Receiver<ReplyStreamEvent>, rx: &Receiver<ReplyStreamEvent>,
light_theme: bool, light_theme: bool,
abort: SharedAbortSignal, abort: &SharedAbortSignal,
) -> Result<()> { ) -> Result<()> {
enable_raw_mode()?; enable_raw_mode()?;
let mut stdout = io::stdout(); let mut stdout = io::stdout();
@ -32,9 +33,9 @@ pub fn repl_render_stream(
} }
fn repl_render_stream_inner( fn repl_render_stream_inner(
rx: Receiver<ReplyStreamEvent>, rx: &Receiver<ReplyStreamEvent>,
light_theme: bool, light_theme: bool,
abort: SharedAbortSignal, abort: &SharedAbortSignal,
writer: &mut Stdout, writer: &mut Stdout,
) -> Result<()> { ) -> Result<()> {
let mut last_tick = Instant::now(); let mut last_tick = Instant::now();
@ -118,7 +119,8 @@ fn repl_render_stream_inner(
} }
fn recover_cursor(writer: &mut Stdout, terminal_columns: u16, buffer: &str) -> Result<()> { fn recover_cursor(writer: &mut Stdout, terminal_columns: u16, buffer: &str) -> Result<()> {
let buffer_rows = (buffer.width() as u16 + terminal_columns - 1) / terminal_columns; let buffer_rows = (u16::try_from(buffer.width()).unwrap_or(u16::MAX) + terminal_columns - 1)
/ terminal_columns;
let (_, row) = cursor::position()?; let (_, row) = cursor::position()?;
if buffer_rows == 0 { if buffer_rows == 0 {
queue!(writer, cursor::MoveTo(0, row))?; queue!(writer, cursor::MoveTo(0, row))?;

@ -5,6 +5,7 @@ use std::sync::{
pub type SharedAbortSignal = Arc<AbortSignal>; pub type SharedAbortSignal = Arc<AbortSignal>;
#[allow(clippy::module_name_repetitions)]
pub struct AbortSignal { pub struct AbortSignal {
ctrlc: AtomicBool, ctrlc: AtomicBool,
ctrld: AtomicBool, ctrld: AtomicBool,

@ -24,6 +24,7 @@ pub enum ReplCmd {
Copy, Copy,
} }
#[allow(clippy::module_name_repetitions)]
pub struct ReplCmdHandler { pub struct ReplCmdHandler {
client: ChatGptClient, client: ChatGptClient,
config: SharedConfig, config: SharedConfig,
@ -32,6 +33,7 @@ pub struct ReplCmdHandler {
} }
impl ReplCmdHandler { impl ReplCmdHandler {
#[allow(clippy::unnecessary_wraps)]
pub fn init( pub fn init(
client: ChatGptClient, client: ChatGptClient,
config: SharedConfig, config: SharedConfig,
@ -58,7 +60,7 @@ impl ReplCmdHandler {
let ret = render_stream( let ret = render_stream(
&input, &input,
&self.client, &self.client,
self.config.clone(), &self.config,
true, true,
self.abort.clone(), self.abort.clone(),
wg.clone(), wg.clone(),
@ -113,6 +115,7 @@ impl ReplCmdHandler {
} }
} }
#[allow(clippy::module_name_repetitions)]
pub struct ReplyStreamHandler { pub struct ReplyStreamHandler {
sender: Option<Sender<ReplyStreamEvent>>, sender: Option<Sender<ReplyStreamEvent>>,
buffer: String, buffer: String,
@ -154,22 +157,19 @@ impl ReplyStreamHandler {
} }
pub fn done(&mut self) -> Result<()> { pub fn done(&mut self) -> Result<()> {
match self.sender.as_ref() { if let Some(tx) = self.sender.as_ref() {
Some(tx) => { let ret = tx
let ret = tx .send(ReplyStreamEvent::Done)
.send(ReplyStreamEvent::Done) .with_context(|| "Failed to send StreamEvent:Done");
.with_context(|| "Failed to send StreamEvent:Done"); self.safe_ret(ret)?;
self.safe_ret(ret)?; } else {
if !self.buffer.ends_with('\n') {
print_now!("\n");
} }
None => { if self.repl {
if !self.buffer.ends_with('\n') { print_now!("\n");
print_now!("\n") if cfg!(macos) {
}
if self.repl {
print_now!("\n"); print_now!("\n");
if cfg!(macos) {
print_now!("\n")
}
} }
} }
} }

@ -5,6 +5,7 @@ use reedline::{Highlighter, StyledText};
const MATCH_COLOR: Color = Color::Green; const MATCH_COLOR: Color = Color::Green;
#[allow(clippy::module_name_repetitions)]
pub struct ReplHighlighter { pub struct ReplHighlighter {
external_commands: Vec<String>, external_commands: Vec<String>,
config: SharedConfig, config: SharedConfig,
@ -12,10 +13,10 @@ pub struct ReplHighlighter {
impl ReplHighlighter { impl ReplHighlighter {
/// Construct the default highlighter with a given set of extern commands/keywords to detect and highlight /// Construct the default highlighter with a given set of extern commands/keywords to detect and highlight
pub fn new(config: SharedConfig, external_commands: Vec<String>) -> ReplHighlighter { pub fn new(config: SharedConfig, external_commands: Vec<String>) -> Self {
Self { Self {
config,
external_commands, external_commands,
config,
} }
} }
} }
@ -28,9 +29,10 @@ impl Highlighter for ReplHighlighter {
} else { } else {
Color::White Color::White
}; };
let match_color = match self.config.read().highlight { let match_color = if self.config.read().highlight {
true => MATCH_COLOR, MATCH_COLOR
false => color, } else {
color
}; };
if self if self
@ -45,7 +47,7 @@ impl Highlighter for ReplHighlighter {
.filter(|c| line.contains(*c)) .filter(|c| line.contains(*c))
.map(std::ops::Deref::deref) .map(std::ops::Deref::deref)
.collect(); .collect();
let longest_match = matches.iter().fold("".to_string(), |acc, &item| { let longest_match = matches.iter().fold(String::new(), |acc, &item| {
if item.len() > acc.len() { if item.len() > acc.len() {
item.to_string() item.to_string()
} else { } else {

@ -24,7 +24,7 @@ impl Repl {
.map(|(v, _)| v.to_string()) .map(|(v, _)| v.to_string())
.collect(); .collect();
let completer = Self::create_completer(config.clone(), &commands); let completer = Self::create_completer(&config, &commands);
let highlighter = ReplHighlighter::new(config.clone(), commands); let highlighter = ReplHighlighter::new(config.clone(), commands);
let keybindings = Self::create_keybindings(); let keybindings = Self::create_keybindings();
let history = Self::create_history()?; let history = Self::create_history()?;
@ -45,7 +45,7 @@ impl Repl {
Ok(Self { editor, prompt }) Ok(Self { editor, prompt })
} }
fn create_completer(config: SharedConfig, commands: &[String]) -> DefaultCompleter { fn create_completer(config: &SharedConfig, commands: &[String]) -> DefaultCompleter {
let mut completion = commands.to_vec(); let mut completion = commands.to_vec();
completion.extend(config.read().repl_completions()); completion.extend(config.read().repl_completions());
let mut completer = let mut completer =

@ -55,7 +55,7 @@ impl Repl {
Ok(Signal::Success(line)) => { Ok(Signal::Success(line)) => {
already_ctrlc = false; already_ctrlc = false;
abort.reset(); abort.reset();
match self.handle_line(handler.clone(), line) { match self.handle_line(&handler, &line) {
Ok(quit) => { Ok(quit) => {
if quit { if quit {
break; break;
@ -69,12 +69,11 @@ impl Repl {
} }
Ok(Signal::CtrlC) => { Ok(Signal::CtrlC) => {
abort.set_ctrlc(); abort.set_ctrlc();
if !already_ctrlc { if already_ctrlc {
already_ctrlc = true;
print_now!("(To exit, press Ctrl+C again or Ctrl+D or type .exit)\n\n");
} else {
break; break;
} }
already_ctrlc = true;
print_now!("(To exit, press Ctrl+C again or Ctrl+D or type .exit)\n\n");
} }
Ok(Signal::CtrlD) => { Ok(Signal::CtrlD) => {
abort.set_ctrld(); abort.set_ctrld();
@ -86,9 +85,9 @@ impl Repl {
Ok(()) Ok(())
} }
fn handle_line(&mut self, handler: Arc<ReplCmdHandler>, line: String) -> Result<bool> { fn handle_line(&mut self, handler: &Arc<ReplCmdHandler>, line: &str) -> Result<bool> {
let line = clean_multiline_symbols(&line); let line = clean_multiline_symbols(line);
match parse_command(&line) { match parse_command(line.as_ref()) {
Some((cmd, args)) => match cmd { Some((cmd, args)) => match cmd {
".exit" => { ".exit" => {
return Ok(true); return Ok(true);
@ -176,7 +175,7 @@ Press Ctrl+C to abort readline, Ctrl+D to exit the REPL
fn clean_multiline_symbols(line: &str) -> Cow<str> { fn clean_multiline_symbols(line: &str) -> Cow<str> {
let trimed_line = line.trim(); let trimed_line = line.trim();
match trimed_line.chars().next() { match trimed_line.chars().next() {
Some('{') | Some('[') | Some('(') => trimed_line[1..trimed_line.len() - 1].into(), Some('{' | '[' | '(') => trimed_line[1..trimed_line.len() - 1].into(),
_ => Cow::Borrowed(line), _ => Cow::Borrowed(line),
} }
} }

@ -9,6 +9,7 @@ const PROMPT_MULTILINE_COLOR: nu_ansi_term::Color = nu_ansi_term::Color::LightBl
const INDICATOR_COLOR: Color = Color::Cyan; const INDICATOR_COLOR: Color = Color::Cyan;
const PROMPT_RIGHT_COLOR: Color = Color::AnsiValue(5); const PROMPT_RIGHT_COLOR: Color = Color::AnsiValue(5);
#[allow(clippy::module_name_repetitions)]
#[derive(Clone)] #[derive(Clone)]
pub struct ReplPrompt { pub struct ReplPrompt {
config: SharedConfig, config: SharedConfig,
@ -21,7 +22,7 @@ pub struct ReplPrompt {
impl ReplPrompt { impl ReplPrompt {
pub fn new(config: SharedConfig) -> Self { pub fn new(config: SharedConfig) -> Self {
let (prompt_color, prompt_multiline_color, indicator_color, prompt_right_color) = let (prompt_color, prompt_multiline_color, indicator_color, prompt_right_color) =
Self::get_colors(config.clone()); Self::get_colors(&config);
Self { Self {
config, config,
prompt_color, prompt_color,
@ -32,14 +33,14 @@ impl ReplPrompt {
} }
pub fn sync_config(&mut self) { pub fn sync_config(&mut self) {
let (prompt_color, prompt_multiline_color, indicator_color, prompt_right_color) = let (prompt_color, prompt_multiline_color, indicator_color, prompt_right_color) =
Self::get_colors(self.config.clone()); Self::get_colors(&self.config);
self.prompt_color = prompt_color; self.prompt_color = prompt_color;
self.prompt_multiline_color = prompt_multiline_color; self.prompt_multiline_color = prompt_multiline_color;
self.indicator_color = indicator_color; self.indicator_color = indicator_color;
self.prompt_right_color = prompt_right_color; self.prompt_right_color = prompt_right_color;
} }
pub fn get_colors(config: SharedConfig) -> (Color, nu_ansi_term::Color, Color, Color) { pub fn get_colors(config: &SharedConfig) -> (Color, nu_ansi_term::Color, Color, Color) {
let (highlight, light_theme) = config.read().get_render_options(); let (highlight, light_theme) = config.read().get_render_options();
if highlight { if highlight {
( (
@ -68,11 +69,11 @@ impl ReplPrompt {
impl Prompt for ReplPrompt { impl Prompt for ReplPrompt {
fn render_prompt_left(&self) -> Cow<str> { fn render_prompt_left(&self) -> Cow<str> {
if let Some(role) = self.config.read().role.as_ref() { self.config
role.name.to_string().into() .read()
} else { .role
Cow::Borrowed("") .as_ref()
} .map_or(Cow::Borrowed(""), |role| Cow::Owned(role.name.clone()))
} }
fn render_prompt_right(&self) -> Cow<str> { fn render_prompt_right(&self) -> Cow<str> {

@ -1,6 +1,7 @@
use reedline::{ValidationResult, Validator}; use reedline::{ValidationResult, Validator};
/// A default validator which checks for mismatched quotes and brackets /// A default validator which checks for mismatched quotes and brackets
#[allow(clippy::module_name_repetitions)]
pub struct ReplValidator; pub struct ReplValidator;
impl Validator for ReplValidator { impl Validator for ReplValidator {
@ -30,7 +31,7 @@ fn incomplete_brackets(line: &str) -> bool {
None => match c { None => match c {
'{' | '(' => { '{' | '(' => {
balance.push(c); balance.push(c);
symbol = Some(c) symbol = Some(c);
} }
_ => {} _ => {}
}, },

@ -14,7 +14,7 @@ macro_rules! print_now {
}; };
} }
pub fn print_now<T: ToString>(text: T) { pub fn print_now<T: ToString>(text: &T) {
print!("{}", text.to_string()); print!("{}", text.to_string());
let _ = stdout().flush(); let _ = stdout().flush();
} }

@ -1,6 +1,6 @@
//! Use tiktoken for count tokens //! Use tiktoken for count tokens
//! //!
//! Copy from https://github.com/dust-tt/dust/tree/main/core/src/providers/tiktoken //! Copy from [https://github.com/dust-tt/dust/tree/main/core/src/providers/tiktoken](https://github.com/dust-tt/dust/tree/main/core/src/providers/tiktoken)
#![allow(unused)] #![allow(unused)]
@ -26,7 +26,7 @@ pub fn text_to_tokens(text: &str) -> Vec<usize> {
} }
/// Convert tokens to plan text /// Convert tokens to plan text
pub fn tokens_to_text(tokens: Vec<usize>) -> Result<String> { pub fn tokens_to_text(tokens: &[usize]) -> Result<String> {
cl100k_base_singleton().lock().decode(tokens) cl100k_base_singleton().lock().decode(tokens)
} }
@ -43,11 +43,11 @@ pub fn cl100k_base() -> Result<CoreBPE> {
} }
let mut special_tokens = HashMap::default(); let mut special_tokens = HashMap::default();
special_tokens.insert(String::from("<|endoftext|>"), 100257); special_tokens.insert(String::from("<|endoftext|>"), 100_257);
special_tokens.insert(String::from("<|fim_prefix|>"), 100258); special_tokens.insert(String::from("<|fim_prefix|>"), 100_258);
special_tokens.insert(String::from("<|fim_middle|>"), 100259); special_tokens.insert(String::from("<|fim_middle|>"), 100_259);
special_tokens.insert(String::from("<|fim_suffix|>"), 100260); special_tokens.insert(String::from("<|fim_suffix|>"), 100_260);
special_tokens.insert(String::from("<|endofprompt|>"), 100276); special_tokens.insert(String::from("<|endofprompt|>"), 100_276);
CoreBPE::new( CoreBPE::new(
encoder, encoder,
@ -64,6 +64,7 @@ pub fn cl100k_base_singleton() -> Arc<Mutex<CoreBPE>> {
} }
fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>> { fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>> {
#[allow(clippy::range_plus_one)]
let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect(); let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect();
// If you have n parts and m merges, this does O(mn) work // If you have n parts and m merges, this does O(mn) work
@ -156,11 +157,11 @@ pub struct CoreBPE {
} }
impl CoreBPE { impl CoreBPE {
fn _get_regex(&self) -> &Regex { const fn _get_regex(&self) -> &Regex {
&self.regex &self.regex
} }
fn _get_special_regex(&self) -> &Regex { const fn _get_special_regex(&self) -> &Regex {
&self.special_regex &self.special_regex
} }
@ -262,15 +263,12 @@ impl CoreBPE {
// Here is a quick and dirty fix: // Here is a quick and dirty fix:
{ {
let token_is_all_space = |token| { let token_is_all_space = |token| {
self.decoder self.decoder.get(token).map_or(false, |token_bytes| {
.get(token) token_bytes
.map(|token_bytes| { .iter()
token_bytes .rev()
.iter() .all(|&b| [b' ', b'\n', b'\t'].contains(&b))
.rev() })
.all(|&b| [b' ', b'\n', b'\t'].contains(&b))
})
.unwrap_or(false)
}; };
if last_piece_token_len > 0 if last_piece_token_len > 0
&& token_is_all_space(&tokens[tokens.len() - last_piece_token_len]) && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
@ -342,6 +340,7 @@ impl CoreBPE {
&& self.sorted_token_bytes[point].starts_with(suffix) && self.sorted_token_bytes[point].starts_with(suffix)
{ {
let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat(); let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
#[allow(clippy::option_if_let_else)]
let encoded = match std::str::from_utf8(&possibility) { let encoded = match std::str::from_utf8(&possibility) {
// Morally, this is byte_pair_encode(&possibility, &self.encoder) // Morally, this is byte_pair_encode(&possibility, &self.encoder)
// But we might have introduced a regex split which would prevent merges. // But we might have introduced a regex split which would prevent merges.
@ -386,7 +385,7 @@ impl CoreBPE {
if unstable_bytes.len() > 1 { if unstable_bytes.len() > 1 {
let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice()); let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
if unstable_bytes.len() - last_decoded.1 > 0 if unstable_bytes.len() - last_decoded.1 > 0
&& last_decoded.0.map_or(false, |c| c.is_whitespace()) && last_decoded.0.map_or(false, char::is_whitespace)
{ {
let mut reencoded = byte_pair_encode( let mut reencoded = byte_pair_encode(
&unstable_bytes[..unstable_bytes.len() - last_decoded.1], &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
@ -413,11 +412,11 @@ impl CoreBPE {
let regex = Regex::new(pattern)?; let regex = Regex::new(pattern)?;
let special_regex = { let special_regex = {
let _parts = special_tokens_encoder let parts = special_tokens_encoder
.keys() .keys()
.map(|s| fancy_regex::escape(s)) .map(|s| fancy_regex::escape(s))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
Regex::new(&_parts.join("|"))? Regex::new(&parts.join("|"))?
}; };
let decoder: HashMap<usize, Vec<u8>> = let decoder: HashMap<usize, Vec<u8>> =
@ -434,7 +433,7 @@ impl CoreBPE {
let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect(); let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
sorted_token_bytes.sort(); sorted_token_bytes.sort();
Ok(CoreBPE { Ok(Self {
encoder, encoder,
special_tokens_encoder, special_tokens_encoder,
decoder, decoder,
@ -453,15 +452,15 @@ impl CoreBPE {
self._encode_ordinary_native(text) self._encode_ordinary_native(text)
} }
pub fn encode(&self, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> { pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> Vec<usize> {
self._encode_native(text, &allowed_special).0 self._encode_native(text, allowed_special).0
} }
pub fn encode_with_special_tokens(&self, text: &str) -> Vec<usize> { pub fn encode_with_special_tokens(&self, text: &str) -> Vec<usize> {
let allowed_special = self let allowed_special = self
.special_tokens_encoder .special_tokens_encoder
.keys() .keys()
.map(|s| s.as_str()) .map(std::string::String::as_str)
.collect(); .collect();
self._encode_native(text, &allowed_special).0 self._encode_native(text, &allowed_special).0
} }
@ -495,9 +494,9 @@ impl CoreBPE {
fn encode_with_unstable( fn encode_with_unstable(
&self, &self,
text: &str, text: &str,
allowed_special: HashSet<&str>, allowed_special: &HashSet<&str>,
) -> (Vec<usize>, HashSet<Vec<usize>>) { ) -> (Vec<usize>, HashSet<Vec<usize>>) {
self._encode_unstable_native(text, &allowed_special) self._encode_unstable_native(text, allowed_special)
} }
#[allow(dead_code)] #[allow(dead_code)]
@ -525,12 +524,12 @@ impl CoreBPE {
// Decoding // Decoding
// ==================== // ====================
pub fn decode_bytes(&self, tokens: Vec<usize>) -> Vec<u8> { pub fn decode_bytes(&self, tokens: &[usize]) -> Vec<u8> {
self._decode_native(&tokens) self._decode_native(tokens)
} }
pub fn decode(&self, tokens: Vec<usize>) -> Result<String> { pub fn decode(&self, tokens: &[usize]) -> Result<String> {
match String::from_utf8(self._decode_native(&tokens)) { match String::from_utf8(self._decode_native(tokens)) {
Ok(text) => Ok(text), Ok(text) => Ok(text),
Err(e) => Err(anyhow!("Unable to decode into a valid UTF-8 string: {}", e)), Err(e) => Err(anyhow!("Unable to decode into a valid UTF-8 string: {}", e)),
} }
@ -576,7 +575,7 @@ mod tests {
fn cl100k_base_test() { fn cl100k_base_test() {
let bpe = cl100k_base().unwrap(); let bpe = cl100k_base().unwrap();
let tokens = bpe.encode_with_special_tokens("This is a test with a lot of spaces"); let tokens = bpe.encode_with_special_tokens("This is a test with a lot of spaces");
let decoded = bpe.decode(tokens.clone()).unwrap(); let decoded = bpe.decode(&tokens).unwrap();
assert_eq!(decoded, "This is a test with a lot of spaces"); assert_eq!(decoded, "This is a test with a lot of spaces");
assert_eq!( assert_eq!(
tokens, tokens,

Loading…
Cancel
Save