mirror of https://github.com/sigoden/aichat
refactor: clients/* and config.rs (#193)
- add register_clients macro to make it easier to add a new client - no create_client_config, just add const PROMPTS - move ModelInfo from clients/ to config/ - model's max_tokens are optional - improve code quanity on config/mod.rs - add/use macro config_get_fnpull/194/head
parent
64202758ec
commit
7f2210dbca
@ -0,0 +1,314 @@
|
||||
use crate::{
|
||||
config::{Message, SharedConfig},
|
||||
repl::{ReplyStreamHandler, SharedAbortSignal},
|
||||
utils::{init_tokio_runtime, prompt_input_integer, prompt_input_string, tokenize, PromptKind},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::{env, time::Duration};
|
||||
use tokio::time::sleep;
|
||||
|
||||
use super::{openai::OpenAIConfig, ClientConfig};
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! register_role {
|
||||
(
|
||||
$(($name:literal, $config_key:ident, $config:ident, $client:ident),)+
|
||||
) => {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ClientConfig {
|
||||
$(
|
||||
#[serde(rename = $name)]
|
||||
$config_key($config),
|
||||
)+
|
||||
#[serde(other)]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
|
||||
$(
|
||||
#[derive(Debug)]
|
||||
pub struct $client {
|
||||
global_config: SharedConfig,
|
||||
config: $config,
|
||||
model_info: ModelInfo,
|
||||
}
|
||||
|
||||
impl $client {
|
||||
pub const NAME: &str = $name;
|
||||
|
||||
pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
|
||||
let model_info = global_config.read().model_info.clone();
|
||||
let config = {
|
||||
if let ClientConfig::$config_key(c) = &global_config.read().clients[model_info.index] {
|
||||
c.clone()
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
};
|
||||
Some(Box::new(Self {
|
||||
global_config,
|
||||
config,
|
||||
model_info,
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn name(local_config: &$config) -> &str {
|
||||
local_config.name.as_deref().unwrap_or(Self::NAME)
|
||||
}
|
||||
}
|
||||
|
||||
)+
|
||||
|
||||
pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
|
||||
None
|
||||
$(.or_else(|| $client::init(config.clone())))+
|
||||
.ok_or_else(|| {
|
||||
let model_info = config.read().model_info.clone();
|
||||
anyhow!(
|
||||
"Unknown client {} at config.clients[{}]",
|
||||
&model_info.client,
|
||||
&model_info.index
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn list_client_types() -> Vec<&'static str> {
|
||||
vec![$($client::NAME,)+]
|
||||
}
|
||||
|
||||
pub fn create_client_config(client: &str) -> Result<Value> {
|
||||
$(
|
||||
if client == $client::NAME {
|
||||
return create_config(&$client::PROMPTS, $client::NAME)
|
||||
}
|
||||
)+
|
||||
bail!("Unknown client {}", client)
|
||||
}
|
||||
|
||||
pub fn all_models(config: &Config) -> Vec<ModelInfo> {
|
||||
config
|
||||
.clients
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(|(i, v)| match v {
|
||||
$(ClientConfig::$config_key(c) => $client::list_models(c, i),)+
|
||||
ClientConfig::Unknown => vec![],
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! config_get_fn {
|
||||
($field_name:ident, $fn_name:ident) => {
|
||||
fn $fn_name(&self) -> Result<String> {
|
||||
let api_key = self.config.$field_name.clone();
|
||||
api_key
|
||||
.or_else(|| {
|
||||
let env_prefix = Self::name(&self.config);
|
||||
let env_name =
|
||||
format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase();
|
||||
env::var(&env_name).ok()
|
||||
})
|
||||
.ok_or_else(|| anyhow::anyhow!("Miss {}", stringify!($field_name)))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Client {
|
||||
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>);
|
||||
|
||||
fn build_client(&self) -> Result<ReqwestClient> {
|
||||
let mut builder = ReqwestClient::builder();
|
||||
let options = self.config().1;
|
||||
let timeout = options
|
||||
.as_ref()
|
||||
.and_then(|v| v.connect_timeout)
|
||||
.unwrap_or(10);
|
||||
let proxy = options.as_ref().and_then(|v| v.proxy.clone());
|
||||
builder = set_proxy(builder, &proxy)?;
|
||||
let client = builder
|
||||
.connect_timeout(Duration::from_secs(timeout))
|
||||
.build()
|
||||
.with_context(|| "Failed to build client")?;
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn send_message(&self, content: &str) -> Result<String> {
|
||||
init_tokio_runtime()?.block_on(async {
|
||||
let global_config = self.config().0;
|
||||
if global_config.read().dry_run {
|
||||
let content = global_config.read().echo_messages(content);
|
||||
return Ok(content);
|
||||
}
|
||||
let client = self.build_client()?;
|
||||
let data = global_config.read().prepare_send_data(content, false)?;
|
||||
self.send_message_inner(&client, data)
|
||||
.await
|
||||
.with_context(|| "Failed to get awswer")
|
||||
})
|
||||
}
|
||||
|
||||
fn send_message_streaming(
|
||||
&self,
|
||||
content: &str,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
) -> Result<()> {
|
||||
async fn watch_abort(abort: SharedAbortSignal) {
|
||||
loop {
|
||||
if abort.aborted() {
|
||||
break;
|
||||
}
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
let abort = handler.get_abort();
|
||||
init_tokio_runtime()?.block_on(async {
|
||||
tokio::select! {
|
||||
ret = async {
|
||||
let global_config = self.config().0;
|
||||
if global_config.read().dry_run {
|
||||
let content = global_config.read().echo_messages(content);
|
||||
let tokens = tokenize(&content);
|
||||
for token in tokens {
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
handler.text(&token)?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
let client = self.build_client()?;
|
||||
let data = global_config.read().prepare_send_data(content, true)?;
|
||||
self.send_message_streaming_inner(&client, handler, data).await
|
||||
} => {
|
||||
handler.done()?;
|
||||
ret.with_context(|| "Failed to get awswer")
|
||||
}
|
||||
_ = watch_abort(abort.clone()) => {
|
||||
handler.done()?;
|
||||
Ok(())
|
||||
},
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
abort.set_ctrlc();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String>;
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
data: SendData,
|
||||
) -> Result<()>;
|
||||
}
|
||||
|
||||
impl Default for ClientConfig {
|
||||
fn default() -> Self {
|
||||
Self::OpenAI(OpenAIConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct ExtraConfig {
|
||||
pub proxy: Option<String>,
|
||||
pub connect_timeout: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SendData {
|
||||
pub messages: Vec<Message>,
|
||||
pub temperature: Option<f64>,
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
pub type PromptType<'a> = (&'a str, &'a str, bool, PromptKind);
|
||||
|
||||
pub fn create_config(list: &[PromptType], client: &str) -> Result<Value> {
|
||||
let mut config = json!({
|
||||
"type": client,
|
||||
});
|
||||
for (path, desc, required, kind) in list {
|
||||
match kind {
|
||||
PromptKind::String => {
|
||||
let value = prompt_input_string(desc, *required)?;
|
||||
set_config_value(&mut config, path, kind, &value);
|
||||
}
|
||||
PromptKind::Integer => {
|
||||
let value = prompt_input_integer(desc, *required)?;
|
||||
set_config_value(&mut config, path, kind, &value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let clients = json!(vec![config]);
|
||||
Ok(clients)
|
||||
}
|
||||
|
||||
fn set_config_value(json: &mut Value, path: &str, kind: &PromptKind, value: &str) {
|
||||
let segs: Vec<&str> = path.split('.').collect();
|
||||
match segs.as_slice() {
|
||||
[name] => json[name] = to_json(kind, value),
|
||||
[scope, name] => match scope.split_once('[') {
|
||||
None => {
|
||||
if json.get(scope).is_none() {
|
||||
let mut obj = json!({});
|
||||
obj[name] = to_json(kind, value);
|
||||
json[scope] = obj;
|
||||
} else {
|
||||
json[scope][name] = to_json(kind, value);
|
||||
}
|
||||
}
|
||||
Some((scope, _)) => {
|
||||
if json.get(scope).is_none() {
|
||||
let mut obj = json!({});
|
||||
obj[name] = to_json(kind, value);
|
||||
json[scope] = json!([obj]);
|
||||
} else {
|
||||
json[scope][0][name] = to_json(kind, value);
|
||||
}
|
||||
}
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn to_json(kind: &PromptKind, value: &str) -> Value {
|
||||
if value.is_empty() {
|
||||
return Value::Null;
|
||||
}
|
||||
match kind {
|
||||
PromptKind::String => value.into(),
|
||||
PromptKind::Integer => match value.parse::<i32>() {
|
||||
Ok(value) => value.into(),
|
||||
Err(_) => value.into(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn set_proxy(builder: ClientBuilder, proxy: &Option<String>) -> Result<ClientBuilder> {
|
||||
let proxy = if let Some(proxy) = proxy {
|
||||
if proxy.is_empty() || proxy == "false" || proxy == "-" {
|
||||
return Ok(builder);
|
||||
}
|
||||
proxy.clone()
|
||||
} else if let Ok(proxy) = env::var("HTTPS_PROXY").or_else(|_| env::var("ALL_PROXY")) {
|
||||
proxy
|
||||
} else {
|
||||
return Ok(builder);
|
||||
};
|
||||
let builder =
|
||||
builder.proxy(Proxy::all(&proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
|
||||
Ok(builder)
|
||||
}
|
@ -1,272 +1,32 @@
|
||||
#[macro_use]
|
||||
mod common;
|
||||
|
||||
pub mod azure_openai;
|
||||
pub mod localai;
|
||||
pub mod openai;
|
||||
|
||||
use self::{
|
||||
azure_openai::{AzureOpenAIClient, AzureOpenAIConfig},
|
||||
localai::LocalAIConfig,
|
||||
openai::{OpenAIClient, OpenAIConfig},
|
||||
};
|
||||
pub use common::*;
|
||||
|
||||
use self::azure_openai::AzureOpenAIConfig;
|
||||
use self::localai::LocalAIConfig;
|
||||
use self::openai::OpenAIConfig;
|
||||
|
||||
use crate::{
|
||||
client::localai::LocalAIClient,
|
||||
config::{Config, Message, SharedConfig},
|
||||
repl::{ReplyStreamHandler, SharedAbortSignal},
|
||||
utils::tokenize,
|
||||
config::{Config, ModelInfo, SharedConfig},
|
||||
utils::PromptKind,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use inquire::{required, Text};
|
||||
use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy};
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use serde::Deserialize;
|
||||
use std::{env, time::Duration};
|
||||
use tokio::time::sleep;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ClientConfig {
|
||||
#[serde(rename = "openai")]
|
||||
OpenAI(OpenAIConfig),
|
||||
#[serde(rename = "localai")]
|
||||
LocalAI(LocalAIConfig),
|
||||
#[serde(rename = "azure-openai")]
|
||||
AzureOpenAI(AzureOpenAIConfig),
|
||||
}
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelInfo {
|
||||
pub client: String,
|
||||
pub name: String,
|
||||
pub max_tokens: usize,
|
||||
pub index: usize,
|
||||
}
|
||||
|
||||
impl Default for ModelInfo {
|
||||
fn default() -> Self {
|
||||
OpenAIClient::list_models(&OpenAIConfig::default(), 0)[0].clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelInfo {
|
||||
pub fn new(client: &str, name: &str, max_tokens: usize, index: usize) -> Self {
|
||||
Self {
|
||||
client: client.into(),
|
||||
name: name.into(),
|
||||
max_tokens,
|
||||
index,
|
||||
}
|
||||
}
|
||||
pub fn stringify(&self) -> String {
|
||||
format!("{}:{}", self.client, self.name)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SendData {
|
||||
pub messages: Vec<Message>,
|
||||
pub temperature: Option<f64>,
|
||||
pub stream: bool,
|
||||
}
|
||||
#[async_trait]
|
||||
pub trait Client {
|
||||
fn config(&self) -> &SharedConfig;
|
||||
|
||||
fn extra_config(&self) -> &Option<ExtraConfig>;
|
||||
|
||||
fn build_client(&self) -> Result<ReqwestClient> {
|
||||
let mut builder = ReqwestClient::builder();
|
||||
let options = self.extra_config();
|
||||
let timeout = options
|
||||
.as_ref()
|
||||
.and_then(|v| v.connect_timeout)
|
||||
.unwrap_or(10);
|
||||
let proxy = options.as_ref().and_then(|v| v.proxy.clone());
|
||||
builder = set_proxy(builder, &proxy)?;
|
||||
let client = builder
|
||||
.connect_timeout(Duration::from_secs(timeout))
|
||||
.build()
|
||||
.with_context(|| "Failed to build client")?;
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn send_message(&self, content: &str) -> Result<String> {
|
||||
init_tokio_runtime()?.block_on(async {
|
||||
if self.config().read().dry_run {
|
||||
let content = self.config().read().echo_messages(content);
|
||||
return Ok(content);
|
||||
}
|
||||
let client = self.build_client()?;
|
||||
let data = self.config().read().prepare_send_data(content, false)?;
|
||||
self.send_message_inner(&client, data)
|
||||
.await
|
||||
.with_context(|| "Failed to fetch")
|
||||
})
|
||||
}
|
||||
|
||||
fn send_message_streaming(
|
||||
&self,
|
||||
content: &str,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
) -> Result<()> {
|
||||
async fn watch_abort(abort: SharedAbortSignal) {
|
||||
loop {
|
||||
if abort.aborted() {
|
||||
break;
|
||||
}
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
let abort = handler.get_abort();
|
||||
init_tokio_runtime()?.block_on(async {
|
||||
tokio::select! {
|
||||
ret = async {
|
||||
if self.config().read().dry_run {
|
||||
let content = self.config().read().echo_messages(content);
|
||||
let tokens = tokenize(&content);
|
||||
for token in tokens {
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
handler.text(&token)?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
let client = self.build_client()?;
|
||||
let data = self.config().read().prepare_send_data(content, true)?;
|
||||
self.send_message_streaming_inner(&client, handler, data).await
|
||||
} => {
|
||||
handler.done()?;
|
||||
ret.with_context(|| "Failed to fetch stream")
|
||||
}
|
||||
_ = watch_abort(abort.clone()) => {
|
||||
handler.done()?;
|
||||
Ok(())
|
||||
},
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
abort.set_ctrlc();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String>;
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
data: SendData,
|
||||
) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct ExtraConfig {
|
||||
pub proxy: Option<String>,
|
||||
pub connect_timeout: Option<u64>,
|
||||
}
|
||||
|
||||
pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
|
||||
OpenAIClient::init(config.clone())
|
||||
.or_else(|| LocalAIClient::init(config.clone()))
|
||||
.or_else(|| AzureOpenAIClient::init(config.clone()))
|
||||
.ok_or_else(|| {
|
||||
let model_info = config.read().model_info.clone();
|
||||
anyhow!(
|
||||
"Unknown client {} at config.clients[{}]",
|
||||
&model_info.client,
|
||||
&model_info.index
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn list_client_types() -> Vec<&'static str> {
|
||||
vec![
|
||||
OpenAIClient::NAME,
|
||||
LocalAIClient::NAME,
|
||||
AzureOpenAIClient::NAME,
|
||||
]
|
||||
}
|
||||
|
||||
pub fn create_client_config(client: &str) -> Result<String> {
|
||||
if client == OpenAIClient::NAME {
|
||||
OpenAIClient::create_config()
|
||||
} else if client == LocalAIClient::NAME {
|
||||
LocalAIClient::create_config()
|
||||
} else if client == AzureOpenAIClient::NAME {
|
||||
AzureOpenAIClient::create_config()
|
||||
} else {
|
||||
bail!("Unknown client {}", &client)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn list_models(config: &Config) -> Vec<ModelInfo> {
|
||||
config
|
||||
.clients
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(|(i, v)| match v {
|
||||
ClientConfig::OpenAI(c) => OpenAIClient::list_models(c, i),
|
||||
ClientConfig::LocalAI(c) => LocalAIClient::list_models(c, i),
|
||||
ClientConfig::AzureOpenAI(c) => AzureOpenAIClient::list_models(c, i),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn init_tokio_runtime() -> Result<tokio::runtime::Runtime> {
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.with_context(|| "Failed to init tokio")
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_input_api_base() -> Result<String> {
|
||||
Text::new("API Base:")
|
||||
.with_validator(required!("This field is required"))
|
||||
.prompt()
|
||||
.map_err(prompt_op_err)
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_input_api_key() -> Result<String> {
|
||||
Text::new("API Key:")
|
||||
.with_validator(required!("This field is required"))
|
||||
.prompt()
|
||||
.map_err(prompt_op_err)
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_input_api_key_optional() -> Result<String> {
|
||||
Text::new("API Key:").prompt().map_err(prompt_op_err)
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_input_model_name() -> Result<String> {
|
||||
Text::new("Model Name:")
|
||||
.with_validator(required!("This field is required"))
|
||||
.prompt()
|
||||
.map_err(prompt_op_err)
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_input_max_token() -> Result<String> {
|
||||
Text::new("Max tokens:")
|
||||
.with_default("4096")
|
||||
.with_validator(required!("This field is required"))
|
||||
.prompt()
|
||||
.map_err(prompt_op_err)
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_op_err<T>(_: T) -> anyhow::Error {
|
||||
anyhow!("An error happened, try again later.")
|
||||
}
|
||||
|
||||
fn set_proxy(builder: ClientBuilder, proxy: &Option<String>) -> Result<ClientBuilder> {
|
||||
let proxy = if let Some(proxy) = proxy {
|
||||
if proxy.is_empty() || proxy == "false" || proxy == "-" {
|
||||
return Ok(builder);
|
||||
}
|
||||
proxy.clone()
|
||||
} else if let Ok(proxy) = env::var("HTTPS_PROXY").or_else(|_| env::var("ALL_PROXY")) {
|
||||
proxy
|
||||
} else {
|
||||
return Ok(builder);
|
||||
};
|
||||
let builder =
|
||||
builder.proxy(Proxy::all(&proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
|
||||
Ok(builder)
|
||||
}
|
||||
use serde_json::Value;
|
||||
|
||||
register_role!(
|
||||
("openai", OpenAI, OpenAIConfig, OpenAIClient),
|
||||
("localai", LocalAI, LocalAIConfig, LocalAIClient),
|
||||
(
|
||||
"azure-openai",
|
||||
AzureOpenAI,
|
||||
AzureOpenAIConfig,
|
||||
AzureOpenAIClient
|
||||
),
|
||||
);
|
||||
|
@ -0,0 +1,27 @@
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelInfo {
|
||||
pub client: String,
|
||||
pub name: String,
|
||||
pub max_tokens: Option<usize>,
|
||||
pub index: usize,
|
||||
}
|
||||
|
||||
impl Default for ModelInfo {
|
||||
fn default() -> Self {
|
||||
ModelInfo::new("", "", None, 0)
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelInfo {
|
||||
pub fn new(client: &str, name: &str, max_tokens: Option<usize>, index: usize) -> Self {
|
||||
Self {
|
||||
client: client.into(),
|
||||
name: name.into(),
|
||||
max_tokens,
|
||||
index,
|
||||
}
|
||||
}
|
||||
pub fn stringify(&self) -> String {
|
||||
format!("{}:{}", self.client, self.name)
|
||||
}
|
||||
}
|
@ -0,0 +1,58 @@
|
||||
use inquire::{required, validator::Validation, Text};
|
||||
|
||||
const MSG_REQUIRED: &str = "This field is required";
|
||||
const MSG_OPTIONAL: &str = "Optional field - Press ↵ to skip";
|
||||
|
||||
pub fn prompt_input_string(desc: &str, required: bool) -> anyhow::Result<String> {
|
||||
let mut text = Text::new(desc);
|
||||
if required {
|
||||
text = text.with_validator(required!(MSG_REQUIRED))
|
||||
} else {
|
||||
text = text.with_help_message(MSG_OPTIONAL)
|
||||
}
|
||||
text.prompt().map_err(prompt_op_err)
|
||||
}
|
||||
|
||||
pub fn prompt_input_integer(desc: &str, required: bool) -> anyhow::Result<String> {
|
||||
let mut text = Text::new(desc);
|
||||
if required {
|
||||
text = text.with_validator(|text: &str| {
|
||||
let out = if text.is_empty() {
|
||||
Validation::Invalid(MSG_REQUIRED.into())
|
||||
} else {
|
||||
validate_integer(text)
|
||||
};
|
||||
Ok(out)
|
||||
})
|
||||
} else {
|
||||
text = text
|
||||
.with_validator(|text: &str| {
|
||||
let out = if text.is_empty() {
|
||||
Validation::Valid
|
||||
} else {
|
||||
validate_integer(text)
|
||||
};
|
||||
Ok(out)
|
||||
})
|
||||
.with_help_message(MSG_OPTIONAL)
|
||||
}
|
||||
text.prompt().map_err(prompt_op_err)
|
||||
}
|
||||
|
||||
pub fn prompt_op_err<T>(_: T) -> anyhow::Error {
|
||||
anyhow::anyhow!("Not finish questionnaire, try again later!")
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum PromptKind {
|
||||
String,
|
||||
Integer,
|
||||
}
|
||||
|
||||
fn validate_integer(text: &str) -> Validation {
|
||||
if text.parse::<i32>().is_err() {
|
||||
Validation::Invalid("Must be a integer".into())
|
||||
} else {
|
||||
Validation::Valid
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue