diff --git a/src/client.rs b/src/client.rs index 361f897..29dc592 100644 --- a/src/client.rs +++ b/src/client.rs @@ -155,7 +155,7 @@ impl ChatGptClient { let builder = self .build_client()? .post(API_URL) - .bearer_auth(&self.config.lock().api_key) + .bearer_auth(self.config.lock().get_api_key()) .json(&body); Ok(builder) diff --git a/src/config/mod.rs b/src/config/mod.rs index b438687..a670d00 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -38,24 +38,21 @@ const SET_COMPLETIONS: [&str; 9] = [ ]; #[derive(Debug, Clone, Deserialize)] +#[serde(default)] pub struct Config { /// Openai api key - pub api_key: String, + pub api_key: Option, /// What sampling temperature to use, between 0 and 2 pub temperature: Option, /// Whether to persistently save chat messages - #[serde(default)] pub save: bool, /// Whether to disable highlight - #[serde(default = "highlight_value")] pub highlight: bool, /// Set proxy pub proxy: Option, /// Used only for debugging - #[serde(default)] pub dry_run: bool, /// If set ture, start a conversation immediately upon repl - #[serde(default)] pub conversation_first: bool, /// Predefined roles #[serde(skip)] @@ -68,18 +65,44 @@ pub struct Config { pub conversation: Option, } +impl Default for Config { + fn default() -> Self { + Self { + api_key: None, + temperature: None, + save: false, + highlight: true, + proxy: None, + dry_run: false, + conversation_first: false, + roles: vec![], + role: None, + conversation: None, + } + } +} + pub type SharedConfig = Arc>; impl Config { - pub fn init(is_interactive: bool) -> Result { - let config_path = Config::config_file()?; - if is_interactive && !config_path.exists() { + pub fn init(is_interactive: bool) -> Result { + let api_key = env::var(get_env_name("api_key")).ok(); + let config_path = Self::config_file()?; + if is_interactive && api_key.is_none() && !config_path.exists() { create_config_file(&config_path)?; } - let content = read_to_string(&config_path) - .with_context(|| format!("Failed to load config at {}", config_path.display()))?; - let mut config: Config = serde_yaml::from_str(&content) - .with_context(|| format!("Invalid config at {}", config_path.display()))?; + let mut config = if api_key.is_some() && !config_path.exists() { + Default::default() + } else { + Self::load_config(&config_path)? + }; + if api_key.is_some() { + config.api_key = api_key; + } + if config.api_key.is_none() { + bail!("api_key not set"); + } + config.merge_env_vars(); config.load_roles()?; Ok(config) @@ -97,23 +120,15 @@ impl Config { } pub fn config_dir() -> Result { - let env_name = format!( - "{}_CONFIG_DIR", - env!("CARGO_CRATE_NAME").to_ascii_uppercase() - ); - let path = match env::var(env_name) { - Ok(v) => PathBuf::from(v), - Err(_) => { + let env_name = get_env_name("config_dir"); + let path = match env::var_os(env_name) { + Some(v) => PathBuf::from(v), + None => { let mut dir = dirs::config_dir().ok_or_else(|| anyhow!("Not found config dir"))?; dir.push(env!("CARGO_CRATE_NAME")); dir } }; - if !path.exists() { - create_dir_all(&path).map_err(|err| { - anyhow!("Failed to create config dir at {}, {err}", path.display()) - })?; - } Ok(path) } @@ -158,8 +173,17 @@ impl Config { Self::local_file(CONFIG_FILE_NAME) } + pub fn get_api_key(&self) -> &String { + self.api_key.as_ref().expect("api_key not set") + } + pub fn roles_file() -> Result { - Self::local_file(ROLES_FILE_NAME) + let env_name = get_env_name("roles_file"); + if let Ok(value) = env::var(env_name) { + Ok(PathBuf::from(value)) + } else { + Self::local_file(ROLES_FILE_NAME) + } } pub fn history_file() -> Result { @@ -251,7 +275,7 @@ impl Config { ("config_file", file_info(&Config::config_file()?)), ("roles_file", file_info(&Config::roles_file()?)), ("messages_file", file_info(&Config::messages_file()?)), - ("api_key", self.api_key.clone()), + ("api_key", self.get_api_key().to_string()), ("temperature", temperature), ("save", self.save.to_string()), ("highlight", self.highlight.to_string()), @@ -290,7 +314,7 @@ impl Config { if unset { bail!("Error: Not allowed"); } else { - self.api_key = value.to_string(); + self.api_key = Some(value.to_string()); } } "temperature" => { @@ -353,6 +377,7 @@ impl Config { fn open_message_file(&self) -> Result { let path = Config::messages_file()?; + ensure_parent_exists(&path)?; OpenOptions::new() .create(true) .append(true) @@ -360,6 +385,15 @@ impl Config { .with_context(|| format!("Failed to create/append {}", path.display())) } + fn load_config(config_path: &Path) -> Result { + let content = read_to_string(config_path) + .with_context(|| format!("Failed to load config at {}", config_path.display()))?; + + let config: Config = serde_yaml::from_str(&content) + .with_context(|| format!("Invalid config at {}", config_path.display()))?; + Ok(config) + } + fn load_roles(&mut self) -> Result<()> { let path = Self::roles_file()?; if !path.exists() { @@ -372,6 +406,16 @@ impl Config { self.roles = roles; Ok(()) } + + fn merge_env_vars(&mut self) { + if let Ok(value) = env::var(get_env_name("dry_run")) { + match value.as_str() { + "1" | "true" => self.dry_run = true, + "0" | "false" => self.dry_run = false, + _ => {} + } + } + } } fn create_config_file(config_path: &Path) -> Result<()> { @@ -405,11 +449,33 @@ fn create_config_file(config_path: &Path) -> Result<()> { if ans { raw_config.push_str("save: true\n"); } - + ensure_parent_exists(config_path)?; std::fs::write(config_path, raw_config).with_context(|| "Failed to write to config file")?; Ok(()) } -fn highlight_value() -> bool { - true +fn ensure_parent_exists(path: &Path) -> Result<()> { + if path.exists() { + return Ok(()); + } + let parent = path + .parent() + .ok_or_else(|| anyhow!("Failed to write to {}, No parent path", path.display()))?; + if !parent.exists() { + create_dir_all(parent).with_context(|| { + format!( + "Failed to write {}, Cannot create parent directory", + path.display() + ) + })?; + } + Ok(()) +} + +fn get_env_name(key: &str) -> String { + format!( + "{}_{}", + env!("CARGO_CRATE_NAME").to_ascii_uppercase(), + key.to_ascii_uppercase(), + ) }