feat: `--file/.file` can load dirs (#693)

pull/694/head
sigoden authored 2 months ago, committed by GitHub
parent a9268b600f
commit 0afe5fa24b
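The hunks below route every `--file/.file` argument through a shared `load_paths` helper and move glob/directory expansion into a new `utils::path` module, so directories and `**` globs can be passed alongside plain files and URLs. A minimal sketch of the resulting call shape (prompt text and paths are illustrative, not from the diff):

// Hedged sketch, not part of the commit:
let input = Input::from_files(
    &config,
    "summarize these documents",
    vec!["./docs".to_string(), "https://example.com/spec.html".to_string()],
    None,
)
.await?;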

@@ -10,12 +10,7 @@ use crate::utils::{base64_encode, sha256, AbortSignal};
use anyhow::{bail, Context, Result};
use fancy_regex::Regex;
use lazy_static::lazy_static;
use std::{
collections::HashMap,
fs::File,
io::Read,
path::{Path, PathBuf},
};
use std::{collections::HashMap, fs::File, io::Read, path::Path};
use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};
const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"];
@@ -62,58 +57,29 @@ impl Input {
pub async fn from_files(
config: &GlobalConfig,
text: &str,
files: Vec<String>,
paths: Vec<String>,
role: Option<Role>,
) -> Result<Self> {
let mut texts = vec![];
if !text.is_empty() {
texts.push(text.to_string());
};
let mut medias = vec![];
let mut data_urls = HashMap::new();
let files: Vec<_> = files
.iter()
.map(|f| (f, is_image_ext(Path::new(f))))
.collect();
let multi_files = files.iter().filter(|(_, is_image)| !*is_image).count() > 1;
let loaders = config.read().document_loaders.clone();
let spinner = create_spinner("Loading files").await;
for (file_item, is_image) in files {
match resolve_local_file(file_item) {
Some(file_path) => {
if is_image {
let data_url = read_media_to_data_url(&file_path)
.with_context(|| format!("Unable to read media file '{file_item}'"))?;
data_urls.insert(sha256(&data_url), file_path.display().to_string());
medias.push(data_url)
} else {
let text = read_file(&file_path)
.with_context(|| format!("Unable to read file '{file_item}'"))?;
if multi_files {
texts.push(format!("`{file_item}`:\n~~~~~~\n{text}\n~~~~~~"));
} else {
texts.push(text);
}
}
}
None => {
if is_image {
medias.push(file_item.to_string())
} else {
let (text, _) = fetch(&loaders, file_item)
.await
.with_context(|| format!("Failed to load '{file_item}'"))?;
if multi_files {
texts.push(format!("`{file_item}`:\n~~~~~~\n{text}\n~~~~~~"));
} else {
texts.push(text);
}
}
}
let ret = load_paths(config, paths).await;
spinner.stop();
let (files, medias, data_urls) = ret?;
let files_len = files.len();
if files_len > 0 {
texts.push(String::new());
}
let is_multi_files = files_len > 1;
for (path, contents) in files {
if is_multi_files {
texts.push(format!("`{path}`:\n\n{contents}\n\n"));
} else {
texts.push(contents);
}
}
spinner.stop();
let (role, with_session, with_agent) = resolve_role(&config.read(), role);
Ok(Self {
config: config.clone(),
@@ -383,6 +349,48 @@ fn resolve_role(config: &Config, role: Option<Role>) -> (Role, bool, bool) {
}
}
async fn load_paths(
config: &GlobalConfig,
paths: Vec<String>,
) -> Result<(Vec<(String, String)>, Vec<String>, HashMap<String, String>)> {
let mut files = vec![];
let mut medias = vec![];
let mut data_urls = HashMap::new();
let loaders = config.read().document_loaders.clone();
let mut local_paths = vec![];
let mut remote_urls = vec![];
for path in paths {
match resolve_local_path(&path) {
Some(v) => local_paths.push(v),
None => remote_urls.push(path),
}
}
let local_files = expand_glob_paths(&local_paths).await?;
for file_path in local_files {
if is_image(&file_path) {
let data_url = read_media_to_data_url(&file_path)
.with_context(|| format!("Unable to read media file '{file_path}'"))?;
data_urls.insert(sha256(&data_url), file_path);
medias.push(data_url)
} else {
let text = read_file(&file_path)
.with_context(|| format!("Unable to read file '{file_path}'"))?;
files.push((file_path, text));
}
}
for file_url in remote_urls {
if is_image(&file_url) {
medias.push(file_url)
} else {
let (text, _) = fetch(&loaders, &file_url)
.await
.with_context(|| format!("Failed to load url '{file_url}'"))?;
files.push((file_url, text));
}
}
Ok((files, medias, data_urls))
}
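A hedged sketch of how `load_paths` partitions its input (file names illustrative, and the local paths are assumed to exist): local text files and fetched remote documents come back as `(path, contents)` pairs, local images become data URLs indexed by their sha256 so `resolve_data_url` can map them back to a path, and remote image URLs pass through untouched.

// Illustrative only, not part of the commit:
async fn demo(config: &GlobalConfig) -> Result<()> {
    let paths = vec![
        "notes/plan.md".to_string(),                 // local text -> files
        "assets/logo.png".to_string(),               // local image -> data URL in medias
        "https://example.com/spec.html".to_string(), // remote document -> fetched into files
    ];
    let (files, medias, data_urls) = load_paths(config, paths).await?;
    assert_eq!(files.len(), 2);     // plan.md contents + fetched spec.html
    assert_eq!(medias.len(), 1);    // the data URL built from logo.png
    assert_eq!(data_urls.len(), 1); // sha256(data URL) -> resolved path of logo.png
    Ok(())
}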
pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
if data_url.starts_with("data:") {
let hash = sha256(&data_url);
@@ -395,25 +403,21 @@ pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
}
}
fn resolve_local_file(file: &str) -> Option<PathBuf> {
if let Ok(true) = URL_RE.is_match(file) {
fn resolve_local_path(path: &str) -> Option<String> {
if let Ok(true) = URL_RE.is_match(path) {
return None;
}
let path = if let (Some(file), Some(home)) = (file.strip_prefix("~/"), dirs::home_dir()) {
let new_path = if let (Some(file), Some(home)) = (path.strip_prefix("~/"), dirs::home_dir()) {
home.join(file)
} else {
std::env::current_dir().ok()?.join(file)
std::env::current_dir().ok()?.join(path)
};
Some(path)
Some(new_path.display().to_string())
}
fn is_image_ext(path: &Path) -> bool {
path.extension()
.map(|v| {
IMAGE_EXTS
.iter()
.any(|ext| *ext == v.to_string_lossy().to_lowercase())
})
fn is_image(path: &str) -> bool {
path_extension(path)
.map(|v| IMAGE_EXTS.contains(&v.as_str()))
.unwrap_or_default()
}
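Taken together, the renamed helpers classify each argument before loading (illustrative values, not tests from the diff): `resolve_local_path` returns `None` for URLs so they take the remote branch, expands `~/`, and anchors relative paths at the current directory, while `is_image` keys off the lowercased extension.

// Illustrative expectations only:
assert_eq!(resolve_local_path("https://example.com/a.md"), None); // treated as remote
// "~/notes.md" -> Some("/home/<user>/notes.md"); "notes.md" -> Some("<cwd>/notes.md")
assert!(is_image("assets/Logo.PNG"));  // extension is lowercased before matching
assert!(!is_image("notes/plan.md"));
assert!(!is_image("Makefile"));        // no extension -> not an image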

@@ -1,8 +1,7 @@
use super::*;
use anyhow::{bail, Context, Result};
use async_recursion::async_recursion;
use std::{collections::HashMap, path::Path};
use anyhow::{Context, Result};
use std::collections::HashMap;
pub const EXTENSION_METADATA: &str = "__extension__";
pub const PATH_METADATA: &str = "__path__";
@@ -40,14 +39,7 @@ pub async fn load_path(
loaders: &HashMap<String, String>,
path: &str,
) -> Result<Vec<(String, RagMetadata)>> {
let (path_str, suffixes) = parse_glob(path)?;
let suffixes = if suffixes.is_empty() {
None
} else {
Some(&suffixes)
};
let mut file_paths = vec![];
list_files(&mut file_paths, Path::new(&path_str), suffixes).await?;
let file_paths = expand_glob_paths(&[path]).await?;
let mut output = vec![];
let file_paths_len = file_paths.len();
match file_paths_len {
@@ -68,7 +60,7 @@ pub async fn load_file(
loaders: &HashMap<String, String>,
path: &str,
) -> Result<(String, RagMetadata)> {
let extension = get_extension(path);
let extension = path_extension(path).unwrap_or_else(|| DEFAULT_EXTENSION.into());
match loaders.get(&extension) {
Some(loader_command) => load_with_command(path, &extension, loader_command),
None => load_plain(path, &extension).await,
@@ -105,110 +97,3 @@ fn load_with_command(
metadata.insert(EXTENSION_METADATA.into(), DEFAULT_EXTENSION.to_string());
Ok((contents, metadata))
}
pub fn parse_glob(path_str: &str) -> Result<(String, Vec<String>)> {
if let Some(start) = path_str.find("/**/*.").or_else(|| path_str.find(r"\**\*.")) {
let base_path = path_str[..start].to_string();
if let Some(curly_brace_end) = path_str[start..].find('}') {
let end = start + curly_brace_end;
let extensions_str = &path_str[start + 6..end + 1];
let extensions = if extensions_str.starts_with('{') && extensions_str.ends_with('}') {
extensions_str[1..extensions_str.len() - 1]
.split(',')
.map(|s| s.to_string())
.collect::<Vec<String>>()
} else {
bail!("Invalid path '{path_str}'");
};
Ok((base_path, extensions))
} else {
let extensions_str = &path_str[start + 6..];
let extensions = vec![extensions_str.to_string()];
Ok((base_path, extensions))
}
} else if path_str.ends_with("/**") || path_str.ends_with(r"\**") {
Ok((path_str[0..path_str.len() - 3].to_string(), vec![]))
} else {
Ok((path_str.to_string(), vec![]))
}
}
#[async_recursion]
pub async fn list_files(
files: &mut Vec<String>,
entry_path: &Path,
suffixes: Option<&Vec<String>>,
) -> Result<()> {
if !entry_path.exists() {
bail!("Not found: {:?}", entry_path);
}
if entry_path.is_file() {
add_file(files, suffixes, entry_path);
return Ok(());
}
if !entry_path.is_dir() {
bail!("Not a directory: {:?}", entry_path);
}
let mut reader = tokio::fs::read_dir(entry_path).await?;
while let Some(entry) = reader.next_entry().await? {
let path = entry.path();
if path.is_file() {
add_file(files, suffixes, &path);
} else if path.is_dir() {
list_files(files, &path, suffixes).await?;
}
}
Ok(())
}
fn add_file(files: &mut Vec<String>, suffixes: Option<&Vec<String>>, path: &Path) {
if is_valid_extension(suffixes, path) {
files.push(path.display().to_string());
}
}
fn is_valid_extension(suffixes: Option<&Vec<String>>, path: &Path) -> bool {
if let Some(suffixes) = suffixes {
if !suffixes.is_empty() {
if let Some(extension) = path.extension().map(|v| v.to_string_lossy().to_string()) {
return suffixes.contains(&extension);
}
return false;
}
}
true
}
fn get_extension(path: &str) -> String {
Path::new(&path)
.extension()
.map(|v| v.to_string_lossy().to_lowercase())
.unwrap_or_else(|| DEFAULT_EXTENSION.to_string())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_glob() {
assert_eq!(parse_glob("dir").unwrap(), ("dir".into(), vec![]));
assert_eq!(parse_glob("dir/**").unwrap(), ("dir".into(), vec![]));
assert_eq!(
parse_glob("dir/file.md").unwrap(),
("dir/file.md".into(), vec![])
);
assert_eq!(
parse_glob("dir/**/*.md").unwrap(),
("dir".into(), vec!["md".into()])
);
assert_eq!(
parse_glob("dir/**/*.{md,txt}").unwrap(),
("dir".into(), vec!["md".into(), "txt".into()])
);
assert_eq!(
parse_glob("C:\\dir\\**\\*.{md,txt}").unwrap(),
("C:\\dir".into(), vec!["md".into(), "txt".into()])
);
}
}

@@ -2,6 +2,7 @@ mod abort_signal;
mod clipboard;
mod command;
mod crypto;
mod path;
mod prompt_input;
mod render_prompt;
mod request;
@@ -11,6 +12,7 @@ pub use self::abort_signal::*;
pub use self::clipboard::set_text;
pub use self::command::*;
pub use self::crypto::*;
pub use self::path::*;
pub use self::prompt_input::*;
pub use self::render_prompt::render_prompt;
pub use self::request::*;
@@ -20,10 +22,7 @@ use anyhow::{Context, Result};
use fancy_regex::Regex;
use is_terminal::IsTerminal;
use lazy_static::lazy_static;
use std::{
env,
path::{self, Path, PathBuf},
};
use std::{env, path::PathBuf};
lazy_static! {
pub static ref CODE_BLOCK_RE: Regex = Regex::new(r"(?ms)```\w*(.*)```").unwrap();
@@ -157,32 +156,6 @@ pub fn dimmed_text(input: &str) -> String {
nu_ansi_term::Style::new().dimmed().paint(input).to_string()
}
pub fn safe_join_path<T1: AsRef<Path>, T2: AsRef<Path>>(
base_path: T1,
sub_path: T2,
) -> Option<PathBuf> {
let base_path = base_path.as_ref();
let sub_path = sub_path.as_ref();
if sub_path.is_absolute() {
return None;
}
let mut joined_path = PathBuf::from(base_path);
for component in sub_path.components() {
if path::Component::ParentDir == component {
return None;
}
joined_path.push(component);
}
if joined_path.starts_with(base_path) {
Some(joined_path)
} else {
None
}
}
pub fn temp_file(prefix: &str, suffix: &str) -> PathBuf {
env::temp_dir().join(format!(
"{}{prefix}{}{suffix}",

@@ -0,0 +1,152 @@
use std::path::{Component, Path, PathBuf};
use anyhow::{bail, Result};
pub fn safe_join_path<T1: AsRef<Path>, T2: AsRef<Path>>(
base_path: T1,
sub_path: T2,
) -> Option<PathBuf> {
let base_path = base_path.as_ref();
let sub_path = sub_path.as_ref();
if sub_path.is_absolute() {
return None;
}
let mut joined_path = PathBuf::from(base_path);
for component in sub_path.components() {
if Component::ParentDir == component {
return None;
}
joined_path.push(component);
}
if joined_path.starts_with(base_path) {
Some(joined_path)
} else {
None
}
}
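`safe_join_path` is moved into this module unchanged; as a reminder of its contract (Unix-style paths for illustration), it refuses absolute sub-paths and any `..` component and only succeeds when the joined path stays under the base:

// Illustrative expectations only:
assert_eq!(
    safe_join_path("/data", "notes/a.md"),
    Some(PathBuf::from("/data/notes/a.md"))
);
assert_eq!(safe_join_path("/data", "../etc/passwd"), None); // parent traversal rejected
assert_eq!(safe_join_path("/data", "/etc/passwd"), None);   // absolute sub-path rejected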
pub async fn expand_glob_paths<T: AsRef<str>>(paths: &[T]) -> Result<Vec<String>> {
let mut new_paths = vec![];
for path in paths {
let (path_str, suffixes) = parse_glob(path.as_ref())?;
let suffixes = if suffixes.is_empty() {
None
} else {
Some(&suffixes)
};
list_files(&mut new_paths, Path::new(&path_str), suffixes).await?;
}
Ok(new_paths)
}
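A hedged usage sketch of the new entry point (paths illustrative and assumed to exist, since `list_files` bails on missing entries): a bare directory is walked recursively, a `**` glob restricts by extension via `parse_glob`, a plain file passes through, and `add_file` drops duplicates.

// Illustrative only, not part of the commit:
async fn demo() -> Result<()> {
    let files = expand_glob_paths(&["docs", "src/**/*.{rs,md}", "README.md"]).await?;
    for path in &files {
        println!("{path}"); // flat, de-duplicated list of path strings
    }
    Ok(())
}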
pub fn path_extension(path: &str) -> Option<String> {
Path::new(&path)
.extension()
.map(|v| v.to_string_lossy().to_lowercase())
}
fn parse_glob(path_str: &str) -> Result<(String, Vec<String>)> {
if let Some(start) = path_str.find("/**/*.").or_else(|| path_str.find(r"\**\*.")) {
let base_path = path_str[..start].to_string();
if let Some(curly_brace_end) = path_str[start..].find('}') {
let end = start + curly_brace_end;
let extensions_str = &path_str[start + 6..end + 1];
let extensions = if extensions_str.starts_with('{') && extensions_str.ends_with('}') {
extensions_str[1..extensions_str.len() - 1]
.split(',')
.map(|s| s.to_string())
.collect::<Vec<String>>()
} else {
bail!("Invalid path '{path_str}'");
};
Ok((base_path, extensions))
} else {
let extensions_str = &path_str[start + 6..];
let extensions = vec![extensions_str.to_string()];
Ok((base_path, extensions))
}
} else if path_str.ends_with("/**") || path_str.ends_with(r"\**") {
Ok((path_str[0..path_str.len() - 3].to_string(), vec![]))
} else {
Ok((path_str.to_string(), vec![]))
}
}
#[async_recursion::async_recursion]
async fn list_files(
files: &mut Vec<String>,
entry_path: &Path,
suffixes: Option<&Vec<String>>,
) -> Result<()> {
if !entry_path.exists() {
bail!("Not found: {}", entry_path.display());
}
if entry_path.is_file() {
add_file(files, suffixes, entry_path);
return Ok(());
}
if !entry_path.is_dir() {
bail!("Not a directory: {:?}", entry_path);
}
let mut reader = tokio::fs::read_dir(entry_path).await?;
while let Some(entry) = reader.next_entry().await? {
let path = entry.path();
if path.is_file() {
add_file(files, suffixes, &path);
} else if path.is_dir() {
list_files(files, &path, suffixes).await?;
}
}
Ok(())
}
fn add_file(files: &mut Vec<String>, suffixes: Option<&Vec<String>>, path: &Path) {
if is_valid_extension(suffixes, path) {
let path = path.display().to_string();
if !files.contains(&path) {
files.push(path);
}
}
}
fn is_valid_extension(suffixes: Option<&Vec<String>>, path: &Path) -> bool {
if let Some(suffixes) = suffixes {
if !suffixes.is_empty() {
if let Some(extension) = path.extension().map(|v| v.to_string_lossy().to_string()) {
return suffixes.contains(&extension);
}
return false;
}
}
true
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_glob() {
assert_eq!(parse_glob("dir").unwrap(), ("dir".into(), vec![]));
assert_eq!(parse_glob("dir/**").unwrap(), ("dir".into(), vec![]));
assert_eq!(
parse_glob("dir/file.md").unwrap(),
("dir/file.md".into(), vec![])
);
assert_eq!(
parse_glob("dir/**/*.md").unwrap(),
("dir".into(), vec!["md".into()])
);
assert_eq!(
parse_glob("dir/**/*.{md,txt}").unwrap(),
("dir".into(), vec!["md".into(), "txt".into()])
);
assert_eq!(
parse_glob("C:\\dir\\**\\*.{md,txt}").unwrap(),
("C:\\dir".into(), vec!["md".into(), "txt".into()])
);
}
}