refactor: split into separate mods

This commit is contained in:
sigoden 2023-03-03 17:53:29 +08:00
parent 2e511c1327
commit 3ffebce8bb
5 changed files with 615 additions and 456 deletions

29
src/cli.rs Normal file
View File

@ -0,0 +1,29 @@
use clap::Parser;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Cli {
/// List all roles
#[clap(short = 'L', long)]
pub list_roles: bool,
/// Specify the role that the AI will play
#[clap(short, long)]
pub role: Option<String>,
/// Input text, if no input text, enter interactive mode
text: Vec<String>,
}
impl Cli {
pub fn text(&self) -> Option<String> {
let text = self
.text
.iter()
.map(|x| x.trim().to_string())
.collect::<Vec<String>>()
.join(" ");
if text.is_empty() {
return None;
}
Some(text)
}
}

166
src/client.rs Normal file
View File

@ -0,0 +1,166 @@
use crate::config::Config;
use anyhow::{anyhow, Result};
use eventsource_stream::Eventsource;
use futures_util::StreamExt;
use reqwest::{Client, Proxy, RequestBuilder};
use serde_json::{json, Value};
use std::sync::atomic::{AtomicBool, Ordering};
use std::{sync::Arc, time::Duration};
use tokio::runtime::Runtime;
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
const MODEL: &str = "gpt-3.5-turbo";
#[derive(Debug)]
pub struct ChatGptClient {
client: Client,
config: Arc<Config>,
runtime: Runtime,
}
impl ChatGptClient {
pub fn init(config: Arc<Config>) -> Result<Self> {
let mut builder = Client::builder();
if let Some(proxy) = config.proxy.as_ref() {
builder = builder
.proxy(Proxy::all(proxy).map_err(|err| anyhow!("Invalid config.proxy, {err}"))?);
}
let client = builder
.connect_timeout(CONNECT_TIMEOUT)
.build()
.map_err(|err| anyhow!("Failed to init http client, {err}"))?;
let runtime = init_runtime()?;
Ok(Self {
client,
config,
runtime,
})
}
pub fn acquire(&self, input: &str, prompt: Option<String>) -> Result<String> {
self.runtime
.block_on(async { self.acquire_inner(input, prompt).await })
}
pub fn acquire_stream<T>(
&self,
input: &str,
prompt: Option<String>,
output: &mut String,
handler: T,
ctrlc: Arc<AtomicBool>,
) -> Result<()>
where
T: FnOnce(&mut String, &str) + Copy,
{
self.runtime.block_on(async {
tokio::select! {
ret = self.acquire_stream_inner(input, prompt, handler, output) => {
ret
}
_ = tokio::signal::ctrl_c() => {
ctrlc.store(true, Ordering::SeqCst);
Ok(())
}
}
})
}
async fn acquire_inner(&self, content: &str, prompt: Option<String>) -> Result<String> {
let content = combine(content, prompt);
if self.config.dry_run {
return Ok(content);
}
let builder = self.request_builder(&content, false);
let data: Value = builder.send().await?.json().await?;
let output = data["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Unexpected response {data}"))?;
Ok(output.to_string())
}
async fn acquire_stream_inner<T>(
&self,
content: &str,
prompt: Option<String>,
handler: T,
output: &mut String,
) -> Result<()>
where
T: FnOnce(&mut String, &str) + Copy,
{
let content = combine(content, prompt);
if self.config.dry_run {
handler(output, &content);
return Ok(());
}
let builder = self.request_builder(&content, true);
let mut stream = builder.send().await?.bytes_stream().eventsource();
let mut virgin = true;
while let Some(part) = stream.next().await {
let chunk = part?.data;
if chunk == "[DONE]" {
break;
} else {
let data: Value = serde_json::from_str(&chunk)?;
let text = data["choices"][0]["delta"]["content"]
.as_str()
.unwrap_or_default();
if text.is_empty() {
continue;
}
if virgin {
virgin = false;
if text == "\n\n" {
continue;
}
}
handler(output, text);
}
}
Ok(())
}
fn request_builder(&self, content: &str, stream: bool) -> RequestBuilder {
let mut body = json!({
"model": MODEL,
"messages": [{"role": "user", "content": content}],
});
if let Some(v) = self.config.temperature {
body.as_object_mut()
.and_then(|m| m.insert("temperature".into(), json!(v)));
}
if stream {
body.as_object_mut()
.and_then(|m| m.insert("stream".into(), json!(true)));
}
self.client
.post(API_URL)
.bearer_auth(&self.config.api_key)
.json(&body)
}
}
fn combine(content: &str, prompt: Option<String>) -> String {
match prompt {
Some(v) => format!("{v} {content}"),
None => content.to_string(),
}
}
fn init_runtime() -> Result<Runtime> {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|err| anyhow!("Failed to init tokio, {err}"))
}

View File

@ -1,10 +1,13 @@
use std::{
env,
fs::{self, read_to_string},
fs::{create_dir_all, read_to_string, File, OpenOptions},
io::Write,
path::{Path, PathBuf},
process::exit,
};
use anyhow::{anyhow, Result};
use inquire::{Confirm, Text};
use serde::Deserialize;
const CONFIG_FILE_NAME: &str = "config.yaml";
@ -32,15 +35,23 @@ pub struct Config {
}
impl Config {
pub fn init(path: &Path) -> Result<Config> {
let content = read_to_string(path)
.map_err(|err| anyhow!("Failed to load config at {}, {err}", path.display()))?;
pub fn init(is_interactive: bool) -> Result<Config> {
let config_path = Config::config_file()?;
if is_interactive && !config_path.exists() {
create_config_file(&config_path)?;
}
let content = read_to_string(&config_path)
.map_err(|err| anyhow!("Failed to load config at {}, {err}", config_path.display()))?;
let mut config: Config =
serde_yaml::from_str(&content).map_err(|err| anyhow!("Invalid config, {err}"))?;
config.load_roles()?;
Ok(config)
}
pub fn find_role(&self, name: &str) -> Option<Role> {
self.roles.iter().find(|v| v.name == name).cloned()
}
pub fn local_file(name: &str) -> Result<PathBuf> {
let env_name = format!(
"{}_CONFIG_DIR",
@ -52,7 +63,7 @@ impl Config {
};
path.push(env!("CARGO_CRATE_NAME"));
if !path.exists() {
fs::create_dir_all(&path).map_err(|err| {
create_dir_all(&path).map_err(|err| {
anyhow!("Failed to create config dir at {}, {err}", path.display())
})?;
}
@ -60,6 +71,37 @@ impl Config {
Ok(path)
}
pub fn open_message_file(&self) -> Result<Option<File>> {
if !self.save {
return Ok(None);
}
let path = Config::messages_file()?;
let file: Option<File> = if self.save {
let file = OpenOptions::new()
.create(true)
.append(true)
.open(&path)
.map_err(|err| anyhow!("Failed to create/append {}, {err}", path.display()))?;
Some(file)
} else {
None
};
Ok(file)
}
pub fn save_message(file: Option<&mut File>, input: &str, output: &str) {
if let (false, Some(file)) = (output.is_empty(), file) {
let _ = file.write_all(
format!(
"AICHAT: {}\n\n--------\n{}\n--------\n\n",
input.trim(),
output.trim(),
)
.as_bytes(),
);
}
}
pub fn config_file() -> Result<PathBuf> {
Self::local_file(CONFIG_FILE_NAME)
}
@ -72,7 +114,7 @@ impl Config {
Self::local_file(HISTORY_FILE_NAME)
}
pub fn messages_file() -> Result<PathBuf> {
fn messages_file() -> Result<PathBuf> {
Self::local_file(MESSAGE_FILE_NAME)
}
@ -98,8 +140,39 @@ pub struct Role {
pub prompt: String,
}
impl Role {
pub fn generate(&self, text: &str) -> String {
format!("{} {}", self.prompt, text)
fn create_config_file(config_path: &Path) -> Result<()> {
let confirm_map_err = |_| anyhow!("Error with questionnaire, try again later");
let text_map_err = |_| anyhow!("An error happened when asking for your key, try again later.");
let ans = Confirm::new("No config file, create a new one?")
.with_default(true)
.prompt()
.map_err(confirm_map_err)?;
if !ans {
exit(0);
}
let api_key = Text::new("Openai API Key:")
.prompt()
.map_err(text_map_err)?;
let mut raw_config = format!("api_key: {api_key}\n");
let ans = Confirm::new("Use proxy?")
.with_default(false)
.prompt()
.map_err(confirm_map_err)?;
if ans {
let proxy = Text::new("Set proxy:").prompt().map_err(text_map_err)?;
raw_config.push_str(&format!("proxy: {proxy}\n"));
}
let ans = Confirm::new("Save chat messages")
.with_default(false)
.prompt()
.map_err(confirm_map_err)?;
if ans {
raw_config.push_str("save: true\n");
}
std::fs::write(config_path, raw_config)
.map_err(|err| anyhow!("Failed to write to config file, {err}"))?;
Ok(())
}

View File

@ -1,39 +1,18 @@
mod cli;
mod client;
mod config;
mod repl;
use std::fs::{File, OpenOptions};
use std::io::{stdout, Write};
use std::path::Path;
use std::process::exit;
use std::time::Duration;
use std::sync::Arc;
use cli::Cli;
use client::ChatGptClient;
use config::{Config, Role};
use anyhow::{anyhow, Result};
use clap::{Arg, ArgAction, Command};
use eventsource_stream::{EventStream, Eventsource};
use futures_util::Stream;
use futures_util::StreamExt;
use inquire::{Confirm, Editor, Text};
use reedline::{
default_emacs_keybindings, ColumnarMenu, DefaultCompleter, DefaultPrompt, DefaultPromptSegment,
Emacs, FileBackedHistory, KeyCode, KeyModifiers, Reedline, ReedlineEvent, ReedlineMenu, Signal,
};
use reqwest::{Client, Proxy};
use serde_json::{json, Value};
use tokio::runtime::Runtime;
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
const MODEL: &str = "gpt-3.5-turbo";
const REPL_COMMANDS: [(&str, &str); 7] = [
(".clear", "Clear the screen"),
(".clear-history", "Clear the history"),
(".exit", "Exit the REPL"),
(".help", "Print this help message"),
(".history", "Print the history"),
(".role", "Specify the role that the AI will play"),
(".view", "Use an external editor to view the AI reply"),
];
use clap::Parser;
use repl::{Repl, ReplCmdHandler};
fn main() {
if let Err(err) = start() {
@ -43,430 +22,43 @@ fn main() {
}
fn start() -> Result<()> {
let matches = Command::new(env!("CARGO_CRATE_NAME"))
.version(env!("CARGO_PKG_VERSION"))
.author(env!("CARGO_PKG_AUTHORS"))
.about(concat!(
env!("CARGO_PKG_DESCRIPTION"),
" - ",
env!("CARGO_PKG_REPOSITORY")
))
.arg(
Arg::new("list-roles")
.short('L')
.long("list-roles")
.action(ArgAction::SetTrue)
.help("List all roles"),
)
.arg(
Arg::new("role")
.short('r')
.long("role")
.action(ArgAction::Set)
.help("Specify the role that the AI will play"),
)
.arg(
Arg::new("text")
.action(ArgAction::Append)
.help("Input text"),
)
.get_matches();
let mut text = matches.get_many::<String>("text").map(|v| {
v.map(|x| x.trim().to_string())
.collect::<Vec<String>>()
.join(" ")
});
let config_path = Config::config_file()?;
if !config_path.exists() && text.is_none() {
create_config_file(&config_path)?;
}
let config = Config::init(&config_path)?;
let role_name = matches.get_one::<String>("role").cloned();
if let (Some(name), Some(text_)) = (role_name.as_ref(), text.as_ref()) {
let role = config
.roles
.iter()
.find(|v| &v.name == name)
.ok_or_else(|| anyhow!("Unknown role \"{name}\" "))?;
text = Some(role.generate(text_));
};
if matches.get_flag("list-roles") {
let cli = Cli::parse();
let text = cli.text();
let config = Arc::new(Config::init(text.is_none())?);
if cli.list_roles {
config.roles.iter().for_each(|v| println!("{}", v.name));
exit(1);
}
let client = init_client(&config)?;
let runtime = init_runtime()?;
match text {
Some(text) => {
let output = runtime.block_on(async move { acquire(&client, &config, &text).await })?;
println!("{}", output.trim());
}
None => run_repl(runtime, client, config, role_name)?,
}
Ok(())
}
fn run_repl(
runtime: Runtime,
client: Client,
config: Config,
role_name: Option<String>,
) -> Result<()> {
print_repl_title();
let mut commands: Vec<String> = REPL_COMMANDS
.into_iter()
.map(|(v, _)| v.to_string())
.collect();
commands.extend(config.roles.iter().map(|v| format!(".role {}", v.name)));
let mut completer = DefaultCompleter::with_inclusions(&['.', '-']).set_min_word_len(2);
completer.insert(commands.clone());
let completer = Box::new(completer);
let completion_menu = Box::new(ColumnarMenu::default().with_name("completion_menu"));
let mut keybindings = default_emacs_keybindings();
keybindings.add_binding(
KeyModifiers::NONE,
KeyCode::Tab,
ReedlineEvent::UntilFound(vec![
ReedlineEvent::Menu("completion_menu".to_string()),
ReedlineEvent::MenuNext,
]),
);
let history = Box::new(
FileBackedHistory::with_file(1000, Config::history_file()?)
.map_err(|err| anyhow!("Failed to setup history file, {err}"))?,
);
let edit_mode = Box::new(Emacs::new(keybindings));
let mut line_editor = Reedline::create()
.with_completer(completer)
.with_history(history)
.with_menu(ReedlineMenu::EngineCompleter(completion_menu))
.with_edit_mode(edit_mode);
let prompt = DefaultPrompt::new(DefaultPromptSegment::Empty, DefaultPromptSegment::Empty);
let mut trigged_ctrlc = false;
let mut output = String::new();
let mut role: Option<Role> = None;
let mut save_file: Option<File> = if config.save {
let file = OpenOptions::new()
.create(true)
.append(true)
.open(Config::messages_file()?)
.map_err(|err| anyhow!("Failed to create/append save_file, {err}"))?;
Some(file)
} else {
None
};
let handle_line = |line: String,
line_editor: &mut Reedline,
trigged_ctrlc: &mut bool,
role: &mut Option<Role>,
output: &mut String,
save_file: &mut Option<File>|
-> Result<bool> {
if line.starts_with('.') {
let (name, args) = match line.split_once(' ') {
Some((head, tail)) => (head, Some(tail.trim())),
None => (line.as_str(), None),
};
match name {
".view" => {
if output.is_empty() {
return Ok(false);
}
let _ = Editor::new("view ai reply with an external editor")
.with_file_extension(".md")
.with_predefined_text(output)
.prompt()?;
dump("", 1);
}
".exit" => {
return Ok(true);
}
".help" => {
dump(get_repl_help(), 2);
}
".clear" => {
line_editor.clear_scrollback()?;
}
".clear-history" => {
let history = Box::new(line_editor.history_mut());
history
.clear()
.map_err(|err| anyhow!("Failed to clear history, {err}"))?;
}
".history" => {
line_editor.print_history()?;
dump("", 1);
}
".role" => match args {
Some(name) => match config.roles.iter().find(|v| v.name == name) {
Some(role_) => {
*role = Some(role_.clone());
dump("", 1);
}
None => dump("Unknown role.", 2),
},
None => dump("Usage: .role <name>.", 2),
},
_ => unknown_command(),
}
} else {
let input = if let Some(role) = role.take() {
role.generate(&line)
} else {
line
};
output.clear();
*trigged_ctrlc = false;
if input.is_empty() {
return Ok(false);
}
runtime.block_on(async {
tokio::select! {
ret = handle_input(&client, &config, &input, output) => {
if let Err(err) = ret {
dump(format!("error: {err}"), 2);
}
}
_ = tokio::signal::ctrl_c() => {
*trigged_ctrlc = true;
dump(" Abort current session.", 2)
}
}
});
if !output.is_empty() {
if let Some(file) = save_file.as_mut() {
let _ = file.write_all(
format!("AICHAT: {input}\n\n--------\n{output}\n--------\n\n").as_bytes(),
);
}
}
}
Ok(false)
};
if let Some(name) = role_name {
handle_line(
format!(".role {name}"),
&mut line_editor,
&mut trigged_ctrlc,
&mut role,
&mut output,
&mut save_file,
)?;
}
loop {
let sig = line_editor.read_line(&prompt);
match sig {
Ok(Signal::Success(line)) => {
let quit = handle_line(
line,
&mut line_editor,
&mut trigged_ctrlc,
&mut role,
&mut output,
&mut save_file,
)?;
if quit {
break;
}
}
Ok(Signal::CtrlC) => {
if !trigged_ctrlc {
trigged_ctrlc = true;
dump("(To exit, press Ctrl+C again or Ctrl+D or type .exit)", 2);
} else {
break;
}
}
Ok(Signal::CtrlD) => {
break;
}
Err(err) => {
eprintln!("{err:?}");
break;
}
}
}
Ok(())
}
async fn handle_input(
client: &Client,
config: &Config,
input: &str,
output: &mut String,
) -> Result<()> {
if config.dry_run {
output.push_str(input);
dump(input, 2);
return Ok(());
}
let mut stream = acquire_stream(client, config, input).await?;
let mut virgin = true;
while let Some(part) = stream.next().await {
let chunk = part?.data;
if chunk == "[DONE]" {
output.push('\n');
dump("", 2);
break;
} else {
let data: Value = serde_json::from_str(&chunk)?;
let text = data["choices"][0]["delta"]["content"]
.as_str()
.unwrap_or_default();
if text.is_empty() {
continue;
}
if virgin {
virgin = false;
if text == "\n\n" {
continue;
}
}
output.push_str(text);
dump(text, 0);
}
}
Ok(())
}
fn init_client(config: &Config) -> Result<Client> {
let mut builder = Client::builder();
if let Some(proxy) = config.proxy.as_ref() {
builder =
builder.proxy(Proxy::all(proxy).map_err(|err| anyhow!("Invalid config.proxy, {err}"))?);
}
let client = builder
.connect_timeout(CONNECT_TIMEOUT)
.build()
.map_err(|err| anyhow!("Failed to init http client, {err}"))?;
Ok(client)
}
fn init_runtime() -> Result<Runtime> {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|err| anyhow!("Failed to init tokio, {err}"))
}
fn create_config_file(config_path: &Path) -> Result<()> {
let confirm_map_err = |_| anyhow!("Error with questionnaire, try again later");
let text_map_err = |_| anyhow!("An error happened when asking for your key, try again later.");
let ans = Confirm::new("No config file, create a new one?")
.with_default(true)
.prompt()
.map_err(confirm_map_err)?;
if !ans {
exit(0);
}
let api_key = Text::new("Openai API Key:")
.prompt()
.map_err(text_map_err)?;
let mut raw_config = format!("api_key: {api_key}\n");
let ans = Confirm::new("Use proxy?")
.with_default(false)
.prompt()
.map_err(confirm_map_err)?;
if ans {
let proxy = Text::new("Set proxy:").prompt().map_err(text_map_err)?;
raw_config.push_str(&format!("proxy: {proxy}\n"));
let role = match &cli.role {
Some(name) => Some(
config
.find_role(name)
.ok_or_else(|| anyhow!("Uknown role '{name}'"))?,
),
None => None,
};
let client = ChatGptClient::init(config.clone())?;
match text {
Some(text) => start_directive(client, config, role, &text),
None => start_interactive(client, config, role),
}
}
let ans = Confirm::new("Save chat messages")
.with_default(false)
.prompt()
.map_err(confirm_map_err)?;
if ans {
raw_config.push_str("save: true\n");
}
std::fs::write(config_path, raw_config)
.map_err(|err| anyhow!("Failed to write to config file, {err}"))?;
fn start_directive(
client: ChatGptClient,
config: Arc<Config>,
role: Option<Role>,
input: &str,
) -> Result<()> {
let mut file = config.open_message_file()?;
let output = client.acquire(input, role.map(|v| v.prompt))?;
println!("{}", output.trim());
Config::save_message(file.as_mut(), input, &output);
Ok(())
}
async fn acquire(client: &Client, config: &Config, content: &str) -> Result<String> {
if config.dry_run {
return Ok(content.to_string());
}
let mut body = json!({
"model": MODEL,
"messages": [{"role": "user", "content": content}]
});
if let Some(v) = config.temperature {
body.as_object_mut()
.and_then(|m| m.insert("temperature".into(), json!(v)));
}
let data: Value = client
.post(API_URL)
.bearer_auth(&config.api_key)
.json(&body)
.send()
.await?
.json()
.await?;
let output = data["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Unexpected response {data}"))?;
Ok(output.to_string())
}
async fn acquire_stream(
client: &Client,
config: &Config,
content: &str,
) -> Result<EventStream<impl Stream<Item = reqwest::Result<bytes::Bytes>>>> {
let mut body = json!({
"model": MODEL,
"messages": [{"role": "user", "content": content}],
"stream": true,
});
if let Some(v) = config.temperature {
body.as_object_mut()
.and_then(|m| m.insert("temperature".into(), json!(v)));
}
let stream = client
.post(API_URL)
.bearer_auth(&config.api_key)
.json(&body)
.send()
.await?
.bytes_stream()
.eventsource();
Ok(stream)
}
fn unknown_command() {
dump("Unknown command. Type \".help\" for more information.", 2);
}
fn dump<T: ToString>(text: T, newlines: usize) {
print!("{}{}", text.to_string(), "\n".repeat(newlines));
stdout().flush().unwrap();
}
fn print_repl_title() {
println!("Welcome to aichat {}", env!("CARGO_PKG_VERSION"));
println!("Type \".help\" for more information.");
}
fn get_repl_help() -> String {
let head = REPL_COMMANDS
.iter()
.map(|(name, desc)| format!("{name:<15} {desc}"))
.collect::<Vec<String>>()
.join("\n");
format!("{head}\n\nPress Ctrl+C to abort session, Ctrl+D to exit the REPL")
fn start_interactive(client: ChatGptClient, config: Arc<Config>, role: Option<Role>) -> Result<()> {
let mut repl = Repl::init(config.clone())?;
let handler = ReplCmdHandler::init(client, config, role)?;
repl.run(handler)
}

299
src/repl.rs Normal file
View File

@ -0,0 +1,299 @@
use crate::client::ChatGptClient;
use crate::config::{Config, Role};
use anyhow::{anyhow, Result};
use inquire::Editor;
use reedline::{
default_emacs_keybindings, ColumnarMenu, DefaultCompleter, DefaultPrompt, DefaultPromptSegment,
Emacs, FileBackedHistory, KeyCode, KeyModifiers, Keybindings, Reedline, ReedlineEvent,
ReedlineMenu, Signal,
};
use std::cell::RefCell;
use std::fs::File;
use std::io::{stdout, Write};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
const REPL_COMMANDS: [(&str, &str); 8] = [
(".clear", "Clear the screen"),
(".clear-history", "Clear the history"),
(".clear-role", "Clear the role status"),
(".exit", "Exit the REPL"),
(".help", "Print this help message"),
(".history", "Print the history"),
(".role", "Specify the role that the AI will play"),
(".view", "Use an external editor to view the AI reply"),
];
const MENU_NAME: &str = "completion_menu";
pub struct Repl {
editor: Reedline,
prompt: DefaultPrompt,
}
impl Repl {
pub fn init(config: Arc<Config>) -> Result<Self> {
let completer = Self::create_completer(config);
let keybindings = Self::create_keybindings();
let history = Self::create_history()?;
let menu = Self::create_menu();
let edit_mode = Box::new(Emacs::new(keybindings));
let editor = Reedline::create()
.with_completer(Box::new(completer))
.with_history(history)
.with_menu(menu)
.with_edit_mode(edit_mode);
let prompt = Self::create_prompt();
Ok(Self { editor, prompt })
}
pub fn run(&mut self, handler: ReplCmdHandler) -> Result<()> {
dump(
format!("Welcome to aichat {}", env!("CARGO_PKG_VERSION")),
1,
);
dump("Type \".help\" for more information.", 1);
let mut current_ctrlc = false;
let handler = Arc::new(handler);
loop {
if handler.ctrlc.load(Ordering::SeqCst) {
handler.ctrlc.store(false, Ordering::SeqCst);
current_ctrlc = true
}
let sig = self.editor.read_line(&self.prompt);
match sig {
Ok(Signal::Success(line)) => {
current_ctrlc = false;
match self.handle_line(handler.clone(), line) {
Ok(quit) => {
if quit {
break;
}
}
Err(err) => {
dump(format!("{err:?}"), 1);
}
}
}
Ok(Signal::CtrlC) => {
if !current_ctrlc {
current_ctrlc = true;
dump("(To exit, press Ctrl+C again or Ctrl+D or type .exit)", 2);
} else {
break;
}
}
Ok(Signal::CtrlD) => {
break;
}
Err(err) => {
dump(format!("{err:?}"), 1);
break;
}
}
}
// tx.send(ReplCmd::Quit).unwrap();
Ok(())
}
fn handle_line(&mut self, handler: Arc<ReplCmdHandler>, line: String) -> Result<bool> {
if line.starts_with('.') {
let (cmd, args) = match line.split_once(' ') {
Some((head, tail)) => (head, Some(tail.trim())),
None => (line.as_str(), None),
};
match cmd {
".view" => handler.handle(ReplCmd::View)?,
".exit" => {
return Ok(true);
}
".help" => {
dump_repl_help();
}
".clear" => {
self.editor.clear_scrollback()?;
}
".clear-history" => {
let history = Box::new(self.editor.history_mut());
history
.clear()
.map_err(|err| anyhow!("Failed to clear history, {err}"))?;
dump("", 1);
}
".history" => {
self.editor.print_history()?;
dump("", 1);
}
".role" => match args {
Some(name) => handler.handle(ReplCmd::SetRole(name.to_string()))?,
None => dump("Usage: .role <name>", 2),
},
".clear-role" => {
handler.handle(ReplCmd::UnsetRole)?;
dump("", 1);
}
_ => dump_unknown_command(),
}
} else {
handler.handle(ReplCmd::Input(line))?;
}
Ok(false)
}
fn create_prompt() -> DefaultPrompt {
DefaultPrompt::new(DefaultPromptSegment::Empty, DefaultPromptSegment::Empty)
}
fn create_completer(config: Arc<Config>) -> DefaultCompleter {
let mut commands: Vec<String> = REPL_COMMANDS
.into_iter()
.map(|(v, _)| v.to_string())
.collect();
commands.extend(config.roles.iter().map(|v| format!(".role {}", v.name)));
let mut completer = DefaultCompleter::with_inclusions(&['.', '-']).set_min_word_len(2);
completer.insert(commands.clone());
completer
}
fn create_keybindings() -> Keybindings {
let mut keybindings = default_emacs_keybindings();
keybindings.add_binding(
KeyModifiers::NONE,
KeyCode::Tab,
ReedlineEvent::UntilFound(vec![
ReedlineEvent::Menu(MENU_NAME.to_string()),
ReedlineEvent::MenuNext,
]),
);
keybindings
}
fn create_menu() -> ReedlineMenu {
let completion_menu = ColumnarMenu::default().with_name(MENU_NAME);
ReedlineMenu::EngineCompleter(Box::new(completion_menu))
}
fn create_history() -> Result<Box<FileBackedHistory>> {
Ok(Box::new(
FileBackedHistory::with_file(1000, Config::history_file()?)
.map_err(|err| anyhow!("Failed to setup history file, {err}"))?,
))
}
}
pub struct ReplCmdHandler {
client: ChatGptClient,
config: Arc<Config>,
state: RefCell<ReplCmdHandlerState>,
ctrlc: Arc<AtomicBool>,
}
struct ReplCmdHandlerState {
prompt: String,
output: String,
save_file: Option<File>,
}
impl ReplCmdHandler {
pub fn init(client: ChatGptClient, config: Arc<Config>, role: Option<Role>) -> Result<Self> {
let prompt = role.map(|v| v.prompt).unwrap_or_default();
let save_file = config.open_message_file()?;
let ctrlc = Arc::new(AtomicBool::new(false));
let state = RefCell::new(ReplCmdHandlerState {
prompt,
save_file,
output: String::new(),
});
Ok(Self {
client,
config,
ctrlc,
state,
})
}
fn handle(&self, cmd: ReplCmd) -> Result<()> {
match cmd {
ReplCmd::Input(input) => {
let mut output = String::new();
if input.is_empty() {
self.state.borrow_mut().output.clear();
return Ok(());
}
let prompt = self.state.borrow().prompt.to_string();
let prompt = if prompt.is_empty() {
None
} else {
Some(prompt)
};
self.client.acquire_stream(
&input,
prompt,
&mut output,
dump_and_collect,
self.ctrlc.clone(),
)?;
dump_and_collect(&mut output, "\n\n");
Config::save_message(self.state.borrow_mut().save_file.as_mut(), &input, &output);
self.state.borrow_mut().output = output;
}
ReplCmd::View => {
let output = self.state.borrow().output.to_string();
if output.is_empty() {
return Ok(());
}
let _ = Editor::new("view ai reply with an external editor")
.with_file_extension(".md")
.with_predefined_text(&output)
.prompt()?;
dump("", 1);
}
ReplCmd::SetRole(name) => match self.config.find_role(&name) {
Some(v) => {
self.state.borrow_mut().prompt = v.prompt;
dump("", 1);
}
None => {
dump("Unknown role", 2);
}
},
ReplCmd::UnsetRole => {
self.state.borrow_mut().prompt = String::new();
}
}
Ok(())
}
}
pub enum ReplCmd {
View,
UnsetRole,
Input(String),
SetRole(String),
}
pub fn dump<T: ToString>(text: T, newlines: usize) {
print!("{}{}", text.to_string(), "\n".repeat(newlines));
stdout().flush().unwrap();
}
fn dump_and_collect(output: &mut String, reply: &str) {
output.push_str(reply);
dump(reply, 0);
}
fn dump_repl_help() {
let head = REPL_COMMANDS
.iter()
.map(|(name, desc)| format!("{name:<15} {desc}"))
.collect::<Vec<String>>()
.join("\n");
dump(
format!("{head}\n\nPress Ctrl+C to abort session, Ctrl+D to exit the REPL"),
2,
);
}
fn dump_unknown_command() {
dump("Unknown command. Type \".help\" for more information.", 2);
}