diff --git a/src/config.rs b/src/config.rs index 83238a9..3dd9530 100644 --- a/src/config.rs +++ b/src/config.rs @@ -245,7 +245,9 @@ impl Connection { DatabaseType::Sqlite => { let path = self.path.as_ref().map_or( Err(anyhow::anyhow!("type sqlite needs the path field")), - |path| expand_tilde(path).ok_or(anyhow::anyhow!("cannot expand file path")), + |path| { + expand_path(path).ok_or_else(|| anyhow::anyhow!("cannot expand file path")) + }, )?; Ok(format!("sqlite://{path}", path = path.to_str().unwrap())) @@ -275,40 +277,85 @@ pub fn get_app_config_path() -> anyhow::Result { Ok(path) } -fn expand_tilde>(path_user_input: P) -> Option { - let p = path_user_input.as_ref(); - if !p.starts_with("~") { - return Some(p.to_path_buf()); +fn expand_path(path: &Path) -> Option { + let mut expanded_path = PathBuf::new(); + let mut path_iter = path.iter(); + if path.starts_with("~") { + path_iter.next()?; + expanded_path = expanded_path.join(dirs_next::home_dir()?); } - if p == Path::new("~") { - return dirs_next::home_dir(); - } - dirs_next::home_dir().map(|mut h| { - if h == Path::new("/") { - p.strip_prefix("~").unwrap().to_path_buf() + for path in path_iter { + let path = path.to_str()?; + expanded_path = if cfg!(unix) && path.starts_with('$') { + expanded_path.join(std::env::var(path.strip_prefix('$')?).unwrap_or_default()) + } else if cfg!(winddows) && path.starts_with('%') && path.ends_with('%') { + expanded_path + .join(std::env::var(path.strip_prefix('%')?.strip_suffix('%')?).unwrap_or_default()) } else { - h.push(p.strip_prefix("~/").unwrap()); - h + expanded_path.join(path) } - }) + } + Some(expanded_path) } #[cfg(test)] mod test { - use super::{expand_tilde, PathBuf}; + use super::{expand_path, Path, PathBuf}; + use std::env; #[test] - fn test_expand_tilde() { - #[cfg(unix)] - let home = std::env::var("HOME").unwrap(); - #[cfg(windows)] + #[cfg(unix)] + fn test_expand_path() { + let home = env::var("HOME").unwrap(); + let test_env = "baz"; + env::set_var("TEST", test_env); + + assert_eq!( + expand_path(&Path::new("$HOME/foo")), + Some(PathBuf::from(&home).join("foo")) + ); + + assert_eq!( + expand_path(&Path::new("$HOME/foo/$TEST/bar")), + Some(PathBuf::from(&home).join("foo").join(test_env).join("bar")) + ); + + assert_eq!( + expand_path(&Path::new("~/foo")), + Some(PathBuf::from(&home).join("foo")) + ); + + assert_eq!( + expand_path(&Path::new("~/foo/~/bar")), + Some(PathBuf::from(&home).join("foo").join("~").join("bar")) + ); + } + + #[test] + #[cfg(windows)] + fn test_expand_path() { let home = std::env::var("APPDATA").unwrap(); - let projects = PathBuf::from(home).join("Projects"); - assert_eq!(expand_tilde("~/Projects"), Some(projects)); - assert_eq!(expand_tilde("/foo/bar"), Some("/foo/bar".into())); + let test_env = "baz"; + env::set_var("TEST", test_env); + + assert_eq!( + expand_path(Path::new("%APPDATA%/foo").to_path_buf()), + Some(PathBuf::from(home).join("foo")) + ); + + assert_eq!( + expand_path(Path::new("%APPDATA%/foo/%TEST%/bar").to_path_buf()), + Some(PathBuf::from(&home).join("foo").join(test_env).join("bar")) + ); + + assert_eq!( + expand_path(Path::new("~/foo").to_path_buf()), + Some(PathBuf::from(&home).join("foo")) + ); + assert_eq!( - expand_tilde("~alice/projects"), - Some("~alice/projects".into()) + expand_path(Path::new("~/foo/~/bar").to_path_buf()), + Some(PathBuf::from(&home).join("foo").join("~").join("bar")) ); } }