From 0afe5fa24b8990991635c7023f0fb89403f55d90 Mon Sep 17 00:00:00 2001
From: sigoden
Date: Wed, 10 Jul 2024 07:27:46 +0800
Subject: [PATCH] feat: `--file/.file` can load dirs (#693)

---
 src/config/input.rs | 126 ++++++++++++++++++------------------
 src/rag/loader.rs   | 123 ++---------------------------------
 src/utils/mod.rs    |  33 +---------
 src/utils/path.rs   | 152 ++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 224 insertions(+), 210 deletions(-)
 create mode 100644 src/utils/path.rs

diff --git a/src/config/input.rs b/src/config/input.rs
index bfc3ea9..ee85c2b 100644
--- a/src/config/input.rs
+++ b/src/config/input.rs
@@ -10,12 +10,7 @@ use crate::utils::{base64_encode, sha256, AbortSignal};
 use anyhow::{bail, Context, Result};
 use fancy_regex::Regex;
 use lazy_static::lazy_static;
-use std::{
-    collections::HashMap,
-    fs::File,
-    io::Read,
-    path::{Path, PathBuf},
-};
+use std::{collections::HashMap, fs::File, io::Read, path::Path};
 use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};
 
 const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"];
@@ -62,58 +57,29 @@ impl Input {
     pub async fn from_files(
         config: &GlobalConfig,
         text: &str,
-        files: Vec<String>,
+        paths: Vec<String>,
         role: Option<Role>,
     ) -> Result<Self> {
         let mut texts = vec![];
         if !text.is_empty() {
             texts.push(text.to_string());
         };
-        let mut medias = vec![];
-        let mut data_urls = HashMap::new();
-        let files: Vec<_> = files
-            .iter()
-            .map(|f| (f, is_image_ext(Path::new(f))))
-            .collect();
-        let multi_files = files.iter().filter(|(_, is_image)| !*is_image).count() > 1;
-        let loaders = config.read().document_loaders.clone();
         let spinner = create_spinner("Loading files").await;
-        for (file_item, is_image) in files {
-            match resolve_local_file(file_item) {
-                Some(file_path) => {
-                    if is_image {
-                        let data_url = read_media_to_data_url(&file_path)
-                            .with_context(|| format!("Unable to read media file '{file_item}'"))?;
-                        data_urls.insert(sha256(&data_url), file_path.display().to_string());
-                        medias.push(data_url)
-                    } else {
-                        let text = read_file(&file_path)
-                            .with_context(|| format!("Unable to read file '{file_item}'"))?;
-                        if multi_files {
-                            texts.push(format!("`{file_item}`:\n~~~~~~\n{text}\n~~~~~~"));
-                        } else {
-                            texts.push(text);
-                        }
-                    }
-                }
-                None => {
-                    if is_image {
-                        medias.push(file_item.to_string())
-                    } else {
-                        let (text, _) = fetch(&loaders, file_item)
-                            .await
-                            .with_context(|| format!("Failed to load '{file_item}'"))?;
-                        if multi_files {
-                            texts.push(format!("`{file_item}`:\n~~~~~~\n{text}\n~~~~~~"));
-                        } else {
-                            texts.push(text);
-                        }
-                    }
-                }
+        let ret = load_paths(config, paths).await;
+        spinner.stop();
+        let (files, medias, data_urls) = ret?;
+        let files_len = files.len();
+        if files_len > 0 {
+            texts.push(String::new());
+        }
+        let is_multi_files = files_len > 1;
+        for (path, contents) in files {
+            if is_multi_files {
+                texts.push(format!("`{path}`:\n\n{contents}\n\n"));
+            } else {
+                texts.push(contents);
             }
         }
-        spinner.stop();
-
         let (role, with_session, with_agent) = resolve_role(&config.read(), role);
         Ok(Self {
             config: config.clone(),
@@ -383,6 +349,48 @@ fn resolve_role(config: &Config, role: Option<Role>) -> (Role, bool, bool) {
     }
 }
 
+async fn load_paths(
+    config: &GlobalConfig,
+    paths: Vec<String>,
+) -> Result<(Vec<(String, String)>, Vec<String>, HashMap<String, String>)> {
+    let mut files = vec![];
+    let mut medias = vec![];
+    let mut data_urls = HashMap::new();
+    let loaders = config.read().document_loaders.clone();
+    let mut local_paths = vec![];
+    let mut remote_urls = vec![];
+    for path in paths {
+        match resolve_local_path(&path) {
+            Some(v) => local_paths.push(v),
+            None => remote_urls.push(path),
+        }
+    }
+    let local_files = expand_glob_paths(&local_paths).await?;
+    for file_path in local_files {
+        if is_image(&file_path) {
+            let data_url = read_media_to_data_url(&file_path)
+                .with_context(|| format!("Unable to read media file '{file_path}'"))?;
+            data_urls.insert(sha256(&data_url), file_path);
+            medias.push(data_url)
+        } else {
+            let text = read_file(&file_path)
+                .with_context(|| format!("Unable to read file '{file_path}'"))?;
+            files.push((file_path, text));
+        }
+    }
+    for file_url in remote_urls {
+        if is_image(&file_url) {
+            medias.push(file_url)
+        } else {
+            let (text, _) = fetch(&loaders, &file_url)
+                .await
+                .with_context(|| format!("Failed to load url '{file_url}'"))?;
+            files.push((file_url, text));
+        }
+    }
+    Ok((files, medias, data_urls))
+}
+
 pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
     if data_url.starts_with("data:") {
         let hash = sha256(&data_url);
@@ -395,25 +403,21 @@ pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -
     }
 }
 
-fn resolve_local_file(file: &str) -> Option<PathBuf> {
-    if let Ok(true) = URL_RE.is_match(file) {
+fn resolve_local_path(path: &str) -> Option<String> {
+    if let Ok(true) = URL_RE.is_match(path) {
         return None;
     }
-    let path = if let (Some(file), Some(home)) = (file.strip_prefix("~/"), dirs::home_dir()) {
+    let new_path = if let (Some(file), Some(home)) = (path.strip_prefix("~/"), dirs::home_dir()) {
         home.join(file)
     } else {
-        std::env::current_dir().ok()?.join(file)
+        std::env::current_dir().ok()?.join(path)
     };
-    Some(path)
+    Some(new_path.display().to_string())
 }
 
-fn is_image_ext(path: &Path) -> bool {
-    path.extension()
-        .map(|v| {
-            IMAGE_EXTS
-                .iter()
-                .any(|ext| *ext == v.to_string_lossy().to_lowercase())
-        })
+fn is_image(path: &str) -> bool {
+    path_extension(path)
+        .map(|v| IMAGE_EXTS.contains(&v.as_str()))
         .unwrap_or_default()
 }
 
diff --git a/src/rag/loader.rs b/src/rag/loader.rs
index 7fa1fd6..0b0d935 100644
--- a/src/rag/loader.rs
+++ b/src/rag/loader.rs
@@ -1,8 +1,7 @@
 use super::*;
 
-use anyhow::{bail, Context, Result};
-use async_recursion::async_recursion;
-use std::{collections::HashMap, path::Path};
+use anyhow::{Context, Result};
+use std::collections::HashMap;
 
 pub const EXTENSION_METADATA: &str = "__extension__";
 pub const PATH_METADATA: &str = "__path__";
@@ -40,14 +39,7 @@ pub async fn load_path(
     loaders: &HashMap<String, String>,
     path: &str,
 ) -> Result<Vec<(String, RagMetadata)>> {
-    let (path_str, suffixes) = parse_glob(path)?;
-    let suffixes = if suffixes.is_empty() {
-        None
-    } else {
-        Some(&suffixes)
-    };
-    let mut file_paths = vec![];
-    list_files(&mut file_paths, Path::new(&path_str), suffixes).await?;
+    let file_paths = expand_glob_paths(&[path]).await?;
     let mut output = vec![];
     let file_paths_len = file_paths.len();
     match file_paths_len {
@@ -68,7 +60,7 @@ pub async fn load_file(
     loaders: &HashMap<String, String>,
     path: &str,
 ) -> Result<(String, RagMetadata)> {
-    let extension = get_extension(path);
+    let extension = path_extension(path).unwrap_or_else(|| DEFAULT_EXTENSION.into());
     match loaders.get(&extension) {
         Some(loader_command) => load_with_command(path, &extension, loader_command),
         None => load_plain(path, &extension).await,
@@ -105,110 +97,3 @@ fn load_with_command(
     metadata.insert(EXTENSION_METADATA.into(), DEFAULT_EXTENSION.to_string());
     Ok((contents, metadata))
 }
-
-pub fn parse_glob(path_str: &str) -> Result<(String, Vec<String>)> {
-    if let Some(start) = path_str.find("/**/*.").or_else(|| path_str.find(r"\**\*.")) {
-        let base_path = path_str[..start].to_string();
-        if let Some(curly_brace_end) = path_str[start..].find('}') {
-            let end = start + curly_brace_end;
-            let extensions_str = &path_str[start + 6..end + 1];
-            let extensions = if extensions_str.starts_with('{') && extensions_str.ends_with('}') {
-                extensions_str[1..extensions_str.len() - 1]
-                    .split(',')
-                    .map(|s| s.to_string())
-                    .collect::<Vec<String>>()
-            } else {
-                bail!("Invalid path '{path_str}'");
-            };
-            Ok((base_path, extensions))
-        } else {
-            let extensions_str = &path_str[start + 6..];
-            let extensions = vec![extensions_str.to_string()];
-            Ok((base_path, extensions))
-        }
-    } else if path_str.ends_with("/**") || path_str.ends_with(r"\**") {
-        Ok((path_str[0..path_str.len() - 3].to_string(), vec![]))
-    } else {
-        Ok((path_str.to_string(), vec![]))
-    }
-}
-
-#[async_recursion]
-pub async fn list_files(
-    files: &mut Vec<String>,
-    entry_path: &Path,
-    suffixes: Option<&Vec<String>>,
-) -> Result<()> {
-    if !entry_path.exists() {
-        bail!("Not found: {:?}", entry_path);
-    }
-    if entry_path.is_file() {
-        add_file(files, suffixes, entry_path);
-        return Ok(());
-    }
-    if !entry_path.is_dir() {
-        bail!("Not a directory: {:?}", entry_path);
-    }
-    let mut reader = tokio::fs::read_dir(entry_path).await?;
-    while let Some(entry) = reader.next_entry().await? {
-        let path = entry.path();
-        if path.is_file() {
-            add_file(files, suffixes, &path);
-        } else if path.is_dir() {
-            list_files(files, &path, suffixes).await?;
-        }
-    }
-    Ok(())
-}
-
-fn add_file(files: &mut Vec<String>, suffixes: Option<&Vec<String>>, path: &Path) {
-    if is_valid_extension(suffixes, path) {
-        files.push(path.display().to_string());
-    }
-}
-
-fn is_valid_extension(suffixes: Option<&Vec<String>>, path: &Path) -> bool {
-    if let Some(suffixes) = suffixes {
-        if !suffixes.is_empty() {
-            if let Some(extension) = path.extension().map(|v| v.to_string_lossy().to_string()) {
-                return suffixes.contains(&extension);
-            }
-            return false;
-        }
-    }
-    true
-}
-
-fn get_extension(path: &str) -> String {
-    Path::new(&path)
-        .extension()
-        .map(|v| v.to_string_lossy().to_lowercase())
-        .unwrap_or_else(|| DEFAULT_EXTENSION.to_string())
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_parse_glob() {
-        assert_eq!(parse_glob("dir").unwrap(), ("dir".into(), vec![]));
-        assert_eq!(parse_glob("dir/**").unwrap(), ("dir".into(), vec![]));
-        assert_eq!(
-            parse_glob("dir/file.md").unwrap(),
-            ("dir/file.md".into(), vec![])
-        );
-        assert_eq!(
-            parse_glob("dir/**/*.md").unwrap(),
-            ("dir".into(), vec!["md".into()])
-        );
-        assert_eq!(
-            parse_glob("dir/**/*.{md,txt}").unwrap(),
-            ("dir".into(), vec!["md".into(), "txt".into()])
-        );
-        assert_eq!(
-            parse_glob("C:\\dir\\**\\*.{md,txt}").unwrap(),
-            ("C:\\dir".into(), vec!["md".into(), "txt".into()])
-        );
-    }
-}
diff --git a/src/utils/mod.rs b/src/utils/mod.rs
index 548c6f1..40967ab 100644
--- a/src/utils/mod.rs
+++ b/src/utils/mod.rs
@@ -2,6 +2,7 @@ mod abort_signal;
 mod clipboard;
 mod command;
 mod crypto;
+mod path;
 mod prompt_input;
 mod render_prompt;
 mod request;
@@ -11,6 +12,7 @@ pub use self::abort_signal::*;
 pub use self::clipboard::set_text;
 pub use self::command::*;
 pub use self::crypto::*;
+pub use self::path::*;
 pub use self::prompt_input::*;
 pub use self::render_prompt::render_prompt;
 pub use self::request::*;
@@ -20,10 +22,7 @@ use anyhow::{Context, Result};
 use fancy_regex::Regex;
 use is_terminal::IsTerminal;
 use lazy_static::lazy_static;
-use std::{
-    env,
-    path::{self, Path, PathBuf},
-};
+use std::{env, path::PathBuf};
 
 lazy_static! {
     pub static ref CODE_BLOCK_RE: Regex = Regex::new(r"(?ms)```\w*(.*)```").unwrap();
@@ -157,32 +156,6 @@ pub fn dimmed_text(input: &str) -> String {
     nu_ansi_term::Style::new().dimmed().paint(input).to_string()
 }
 
-pub fn safe_join_path<T1: AsRef<Path>, T2: AsRef<Path>>(
-    base_path: T1,
-    sub_path: T2,
-) -> Option<PathBuf> {
-    let base_path = base_path.as_ref();
-    let sub_path = sub_path.as_ref();
-    if sub_path.is_absolute() {
-        return None;
-    }
-
-    let mut joined_path = PathBuf::from(base_path);
-
-    for component in sub_path.components() {
-        if path::Component::ParentDir == component {
-            return None;
-        }
-        joined_path.push(component);
-    }
-
-    if joined_path.starts_with(base_path) {
-        Some(joined_path)
-    } else {
-        None
-    }
-}
-
 pub fn temp_file(prefix: &str, suffix: &str) -> PathBuf {
     env::temp_dir().join(format!(
         "{}{prefix}{}{suffix}",
diff --git a/src/utils/path.rs b/src/utils/path.rs
new file mode 100644
index 0000000..4e66337
--- /dev/null
+++ b/src/utils/path.rs
@@ -0,0 +1,152 @@
+use std::path::{Component, Path, PathBuf};
+
+use anyhow::{bail, Result};
+
+pub fn safe_join_path<T1: AsRef<Path>, T2: AsRef<Path>>(
+    base_path: T1,
+    sub_path: T2,
+) -> Option<PathBuf> {
+    let base_path = base_path.as_ref();
+    let sub_path = sub_path.as_ref();
+    if sub_path.is_absolute() {
+        return None;
+    }
+
+    let mut joined_path = PathBuf::from(base_path);
+
+    for component in sub_path.components() {
+        if Component::ParentDir == component {
+            return None;
+        }
+        joined_path.push(component);
+    }
+
+    if joined_path.starts_with(base_path) {
+        Some(joined_path)
+    } else {
+        None
+    }
+}
+
+pub async fn expand_glob_paths<T: AsRef<str>>(paths: &[T]) -> Result<Vec<String>> {
+    let mut new_paths = vec![];
+    for path in paths {
+        let (path_str, suffixes) = parse_glob(path.as_ref())?;
+        let suffixes = if suffixes.is_empty() {
+            None
+        } else {
+            Some(&suffixes)
+        };
+        list_files(&mut new_paths, Path::new(&path_str), suffixes).await?;
+    }
+    Ok(new_paths)
+}
+
+pub fn path_extension(path: &str) -> Option<String> {
+    Path::new(&path)
+        .extension()
+        .map(|v| v.to_string_lossy().to_lowercase())
+}
+
+fn parse_glob(path_str: &str) -> Result<(String, Vec<String>)> {
+    if let Some(start) = path_str.find("/**/*.").or_else(|| path_str.find(r"\**\*.")) {
+        let base_path = path_str[..start].to_string();
+        if let Some(curly_brace_end) = path_str[start..].find('}') {
+            let end = start + curly_brace_end;
+            let extensions_str = &path_str[start + 6..end + 1];
+            let extensions = if extensions_str.starts_with('{') && extensions_str.ends_with('}') {
+                extensions_str[1..extensions_str.len() - 1]
+                    .split(',')
+                    .map(|s| s.to_string())
+                    .collect::<Vec<String>>()
+            } else {
+                bail!("Invalid path '{path_str}'");
+            };
+            Ok((base_path, extensions))
+        } else {
+            let extensions_str = &path_str[start + 6..];
+            let extensions = vec![extensions_str.to_string()];
+            Ok((base_path, extensions))
+        }
+    } else if path_str.ends_with("/**") || path_str.ends_with(r"\**") {
+        Ok((path_str[0..path_str.len() - 3].to_string(), vec![]))
+    } else {
+        Ok((path_str.to_string(), vec![]))
+    }
+}
+
+#[async_recursion::async_recursion]
+async fn list_files(
+    files: &mut Vec<String>,
+    entry_path: &Path,
+    suffixes: Option<&Vec<String>>,
+) -> Result<()> {
+    if !entry_path.exists() {
+        bail!("Not found: {}", entry_path.display());
+    }
+    if entry_path.is_file() {
+        add_file(files, suffixes, entry_path);
+        return Ok(());
+    }
+    if !entry_path.is_dir() {
+        bail!("Not a directory: {:?}", entry_path);
+    }
+    let mut reader = tokio::fs::read_dir(entry_path).await?;
+    while let Some(entry) = reader.next_entry().await? {
+        let path = entry.path();
+        if path.is_file() {
+            add_file(files, suffixes, &path);
+        } else if path.is_dir() {
+            list_files(files, &path, suffixes).await?;
+        }
+    }
+    Ok(())
+}
+
+fn add_file(files: &mut Vec<String>, suffixes: Option<&Vec<String>>, path: &Path) {
+    if is_valid_extension(suffixes, path) {
+        let path = path.display().to_string();
+        if !files.contains(&path) {
+            files.push(path);
+        }
+    }
+}
+
+fn is_valid_extension(suffixes: Option<&Vec<String>>, path: &Path) -> bool {
+    if let Some(suffixes) = suffixes {
+        if !suffixes.is_empty() {
+            if let Some(extension) = path.extension().map(|v| v.to_string_lossy().to_string()) {
+                return suffixes.contains(&extension);
+            }
+            return false;
+        }
+    }
+    true
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_glob() {
+        assert_eq!(parse_glob("dir").unwrap(), ("dir".into(), vec![]));
+        assert_eq!(parse_glob("dir/**").unwrap(), ("dir".into(), vec![]));
+        assert_eq!(
+            parse_glob("dir/file.md").unwrap(),
+            ("dir/file.md".into(), vec![])
+        );
+        assert_eq!(
+            parse_glob("dir/**/*.md").unwrap(),
+            ("dir".into(), vec!["md".into()])
+        );
+        assert_eq!(
+            parse_glob("dir/**/*.{md,txt}").unwrap(),
+            ("dir".into(), vec!["md".into(), "txt".into()])
+        );
+        assert_eq!(
+            parse_glob("C:\\dir\\**\\*.{md,txt}").unwrap(),
+            ("C:\\dir".into(), vec!["md".into(), "txt".into()])
+        );
+    }
+}
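
Illustrative usage sketch (not part of the patch): `expand_glob_paths` from src/utils/path.rs above accepts plain files, directories, and `**` globs, walks directories recursively, filters globs by extension via `parse_glob`, and de-duplicates results in `add_file`. The test name, the `tokio::test` macro, and the `docs`/`src` paths below are assumptions for illustration only.

    // Hypothetical test sketch exercising the new expand_glob_paths() helper.
    #[tokio::test]
    async fn expands_dirs_and_globs() -> anyhow::Result<()> {
        use crate::utils::expand_glob_paths;
        // A directory is walked recursively; "src/**/*.rs" is filtered by extension.
        let files = expand_glob_paths(&["docs", "src/**/*.rs"]).await?;
        // Every returned entry is a concrete file path; directories are expanded away.
        assert!(files.iter().all(|p| std::path::Path::new(p).is_file()));
        Ok(())
    }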