mirror of
https://github.com/sigoden/aichat
synced 2024-11-16 06:15:26 +00:00
refactor: split into separate mods
This commit is contained in:
parent
2e511c1327
commit
3ffebce8bb
29
src/cli.rs
Normal file
29
src/cli.rs
Normal file
@ -0,0 +1,29 @@
|
||||
use clap::Parser;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Cli {
|
||||
/// List all roles
|
||||
#[clap(short = 'L', long)]
|
||||
pub list_roles: bool,
|
||||
/// Specify the role that the AI will play
|
||||
#[clap(short, long)]
|
||||
pub role: Option<String>,
|
||||
/// Input text, if no input text, enter interactive mode
|
||||
text: Vec<String>,
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
pub fn text(&self) -> Option<String> {
|
||||
let text = self
|
||||
.text
|
||||
.iter()
|
||||
.map(|x| x.trim().to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join(" ");
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(text)
|
||||
}
|
||||
}
|
166
src/client.rs
Normal file
166
src/client.rs
Normal file
@ -0,0 +1,166 @@
|
||||
use crate::config::Config;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::{Client, Proxy, RequestBuilder};
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
|
||||
const MODEL: &str = "gpt-3.5-turbo";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ChatGptClient {
|
||||
client: Client,
|
||||
config: Arc<Config>,
|
||||
runtime: Runtime,
|
||||
}
|
||||
|
||||
impl ChatGptClient {
|
||||
pub fn init(config: Arc<Config>) -> Result<Self> {
|
||||
let mut builder = Client::builder();
|
||||
if let Some(proxy) = config.proxy.as_ref() {
|
||||
builder = builder
|
||||
.proxy(Proxy::all(proxy).map_err(|err| anyhow!("Invalid config.proxy, {err}"))?);
|
||||
}
|
||||
let client = builder
|
||||
.connect_timeout(CONNECT_TIMEOUT)
|
||||
.build()
|
||||
.map_err(|err| anyhow!("Failed to init http client, {err}"))?;
|
||||
|
||||
let runtime = init_runtime()?;
|
||||
Ok(Self {
|
||||
client,
|
||||
config,
|
||||
runtime,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn acquire(&self, input: &str, prompt: Option<String>) -> Result<String> {
|
||||
self.runtime
|
||||
.block_on(async { self.acquire_inner(input, prompt).await })
|
||||
}
|
||||
|
||||
pub fn acquire_stream<T>(
|
||||
&self,
|
||||
input: &str,
|
||||
prompt: Option<String>,
|
||||
output: &mut String,
|
||||
handler: T,
|
||||
ctrlc: Arc<AtomicBool>,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: FnOnce(&mut String, &str) + Copy,
|
||||
{
|
||||
self.runtime.block_on(async {
|
||||
tokio::select! {
|
||||
ret = self.acquire_stream_inner(input, prompt, handler, output) => {
|
||||
ret
|
||||
}
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
ctrlc.store(true, Ordering::SeqCst);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn acquire_inner(&self, content: &str, prompt: Option<String>) -> Result<String> {
|
||||
let content = combine(content, prompt);
|
||||
if self.config.dry_run {
|
||||
return Ok(content);
|
||||
}
|
||||
let builder = self.request_builder(&content, false);
|
||||
|
||||
let data: Value = builder.send().await?.json().await?;
|
||||
|
||||
let output = data["choices"][0]["message"]["content"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("Unexpected response {data}"))?;
|
||||
|
||||
Ok(output.to_string())
|
||||
}
|
||||
|
||||
async fn acquire_stream_inner<T>(
|
||||
&self,
|
||||
content: &str,
|
||||
prompt: Option<String>,
|
||||
handler: T,
|
||||
output: &mut String,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: FnOnce(&mut String, &str) + Copy,
|
||||
{
|
||||
let content = combine(content, prompt);
|
||||
if self.config.dry_run {
|
||||
handler(output, &content);
|
||||
return Ok(());
|
||||
}
|
||||
let builder = self.request_builder(&content, true);
|
||||
let mut stream = builder.send().await?.bytes_stream().eventsource();
|
||||
let mut virgin = true;
|
||||
while let Some(part) = stream.next().await {
|
||||
let chunk = part?.data;
|
||||
if chunk == "[DONE]" {
|
||||
break;
|
||||
} else {
|
||||
let data: Value = serde_json::from_str(&chunk)?;
|
||||
let text = data["choices"][0]["delta"]["content"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if virgin {
|
||||
virgin = false;
|
||||
if text == "\n\n" {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
handler(output, text);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn request_builder(&self, content: &str, stream: bool) -> RequestBuilder {
|
||||
let mut body = json!({
|
||||
"model": MODEL,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
});
|
||||
|
||||
if let Some(v) = self.config.temperature {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("temperature".into(), json!(v)));
|
||||
}
|
||||
|
||||
if stream {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("stream".into(), json!(true)));
|
||||
}
|
||||
|
||||
self.client
|
||||
.post(API_URL)
|
||||
.bearer_auth(&self.config.api_key)
|
||||
.json(&body)
|
||||
}
|
||||
}
|
||||
|
||||
fn combine(content: &str, prompt: Option<String>) -> String {
|
||||
match prompt {
|
||||
Some(v) => format!("{v} {content}"),
|
||||
None => content.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn init_runtime() -> Result<Runtime> {
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.map_err(|err| anyhow!("Failed to init tokio, {err}"))
|
||||
}
|
@ -1,10 +1,13 @@
|
||||
use std::{
|
||||
env,
|
||||
fs::{self, read_to_string},
|
||||
fs::{create_dir_all, read_to_string, File, OpenOptions},
|
||||
io::Write,
|
||||
path::{Path, PathBuf},
|
||||
process::exit,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use inquire::{Confirm, Text};
|
||||
use serde::Deserialize;
|
||||
|
||||
const CONFIG_FILE_NAME: &str = "config.yaml";
|
||||
@ -32,15 +35,23 @@ pub struct Config {
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn init(path: &Path) -> Result<Config> {
|
||||
let content = read_to_string(path)
|
||||
.map_err(|err| anyhow!("Failed to load config at {}, {err}", path.display()))?;
|
||||
pub fn init(is_interactive: bool) -> Result<Config> {
|
||||
let config_path = Config::config_file()?;
|
||||
if is_interactive && !config_path.exists() {
|
||||
create_config_file(&config_path)?;
|
||||
}
|
||||
let content = read_to_string(&config_path)
|
||||
.map_err(|err| anyhow!("Failed to load config at {}, {err}", config_path.display()))?;
|
||||
let mut config: Config =
|
||||
serde_yaml::from_str(&content).map_err(|err| anyhow!("Invalid config, {err}"))?;
|
||||
config.load_roles()?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub fn find_role(&self, name: &str) -> Option<Role> {
|
||||
self.roles.iter().find(|v| v.name == name).cloned()
|
||||
}
|
||||
|
||||
pub fn local_file(name: &str) -> Result<PathBuf> {
|
||||
let env_name = format!(
|
||||
"{}_CONFIG_DIR",
|
||||
@ -52,7 +63,7 @@ impl Config {
|
||||
};
|
||||
path.push(env!("CARGO_CRATE_NAME"));
|
||||
if !path.exists() {
|
||||
fs::create_dir_all(&path).map_err(|err| {
|
||||
create_dir_all(&path).map_err(|err| {
|
||||
anyhow!("Failed to create config dir at {}, {err}", path.display())
|
||||
})?;
|
||||
}
|
||||
@ -60,6 +71,37 @@ impl Config {
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
pub fn open_message_file(&self) -> Result<Option<File>> {
|
||||
if !self.save {
|
||||
return Ok(None);
|
||||
}
|
||||
let path = Config::messages_file()?;
|
||||
let file: Option<File> = if self.save {
|
||||
let file = OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&path)
|
||||
.map_err(|err| anyhow!("Failed to create/append {}, {err}", path.display()))?;
|
||||
Some(file)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(file)
|
||||
}
|
||||
|
||||
pub fn save_message(file: Option<&mut File>, input: &str, output: &str) {
|
||||
if let (false, Some(file)) = (output.is_empty(), file) {
|
||||
let _ = file.write_all(
|
||||
format!(
|
||||
"AICHAT: {}\n\n--------\n{}\n--------\n\n",
|
||||
input.trim(),
|
||||
output.trim(),
|
||||
)
|
||||
.as_bytes(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn config_file() -> Result<PathBuf> {
|
||||
Self::local_file(CONFIG_FILE_NAME)
|
||||
}
|
||||
@ -72,7 +114,7 @@ impl Config {
|
||||
Self::local_file(HISTORY_FILE_NAME)
|
||||
}
|
||||
|
||||
pub fn messages_file() -> Result<PathBuf> {
|
||||
fn messages_file() -> Result<PathBuf> {
|
||||
Self::local_file(MESSAGE_FILE_NAME)
|
||||
}
|
||||
|
||||
@ -98,8 +140,39 @@ pub struct Role {
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn generate(&self, text: &str) -> String {
|
||||
format!("{} {}", self.prompt, text)
|
||||
fn create_config_file(config_path: &Path) -> Result<()> {
|
||||
let confirm_map_err = |_| anyhow!("Error with questionnaire, try again later");
|
||||
let text_map_err = |_| anyhow!("An error happened when asking for your key, try again later.");
|
||||
let ans = Confirm::new("No config file, create a new one?")
|
||||
.with_default(true)
|
||||
.prompt()
|
||||
.map_err(confirm_map_err)?;
|
||||
if !ans {
|
||||
exit(0);
|
||||
}
|
||||
let api_key = Text::new("Openai API Key:")
|
||||
.prompt()
|
||||
.map_err(text_map_err)?;
|
||||
let mut raw_config = format!("api_key: {api_key}\n");
|
||||
|
||||
let ans = Confirm::new("Use proxy?")
|
||||
.with_default(false)
|
||||
.prompt()
|
||||
.map_err(confirm_map_err)?;
|
||||
if ans {
|
||||
let proxy = Text::new("Set proxy:").prompt().map_err(text_map_err)?;
|
||||
raw_config.push_str(&format!("proxy: {proxy}\n"));
|
||||
}
|
||||
|
||||
let ans = Confirm::new("Save chat messages")
|
||||
.with_default(false)
|
||||
.prompt()
|
||||
.map_err(confirm_map_err)?;
|
||||
if ans {
|
||||
raw_config.push_str("save: true\n");
|
||||
}
|
||||
|
||||
std::fs::write(config_path, raw_config)
|
||||
.map_err(|err| anyhow!("Failed to write to config file, {err}"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
486
src/main.rs
486
src/main.rs
@ -1,39 +1,18 @@
|
||||
mod cli;
|
||||
mod client;
|
||||
mod config;
|
||||
mod repl;
|
||||
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::io::{stdout, Write};
|
||||
use std::path::Path;
|
||||
use std::process::exit;
|
||||
use std::time::Duration;
|
||||
use std::sync::Arc;
|
||||
|
||||
use cli::Cli;
|
||||
use client::ChatGptClient;
|
||||
use config::{Config, Role};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use clap::{Arg, ArgAction, Command};
|
||||
use eventsource_stream::{EventStream, Eventsource};
|
||||
use futures_util::Stream;
|
||||
use futures_util::StreamExt;
|
||||
use inquire::{Confirm, Editor, Text};
|
||||
use reedline::{
|
||||
default_emacs_keybindings, ColumnarMenu, DefaultCompleter, DefaultPrompt, DefaultPromptSegment,
|
||||
Emacs, FileBackedHistory, KeyCode, KeyModifiers, Reedline, ReedlineEvent, ReedlineMenu, Signal,
|
||||
};
|
||||
use reqwest::{Client, Proxy};
|
||||
use serde_json::{json, Value};
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
|
||||
const MODEL: &str = "gpt-3.5-turbo";
|
||||
const REPL_COMMANDS: [(&str, &str); 7] = [
|
||||
(".clear", "Clear the screen"),
|
||||
(".clear-history", "Clear the history"),
|
||||
(".exit", "Exit the REPL"),
|
||||
(".help", "Print this help message"),
|
||||
(".history", "Print the history"),
|
||||
(".role", "Specify the role that the AI will play"),
|
||||
(".view", "Use an external editor to view the AI reply"),
|
||||
];
|
||||
use clap::Parser;
|
||||
use repl::{Repl, ReplCmdHandler};
|
||||
|
||||
fn main() {
|
||||
if let Err(err) = start() {
|
||||
@ -43,430 +22,43 @@ fn main() {
|
||||
}
|
||||
|
||||
fn start() -> Result<()> {
|
||||
let matches = Command::new(env!("CARGO_CRATE_NAME"))
|
||||
.version(env!("CARGO_PKG_VERSION"))
|
||||
.author(env!("CARGO_PKG_AUTHORS"))
|
||||
.about(concat!(
|
||||
env!("CARGO_PKG_DESCRIPTION"),
|
||||
" - ",
|
||||
env!("CARGO_PKG_REPOSITORY")
|
||||
))
|
||||
.arg(
|
||||
Arg::new("list-roles")
|
||||
.short('L')
|
||||
.long("list-roles")
|
||||
.action(ArgAction::SetTrue)
|
||||
.help("List all roles"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("role")
|
||||
.short('r')
|
||||
.long("role")
|
||||
.action(ArgAction::Set)
|
||||
.help("Specify the role that the AI will play"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("text")
|
||||
.action(ArgAction::Append)
|
||||
.help("Input text"),
|
||||
)
|
||||
.get_matches();
|
||||
let mut text = matches.get_many::<String>("text").map(|v| {
|
||||
v.map(|x| x.trim().to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join(" ")
|
||||
});
|
||||
let config_path = Config::config_file()?;
|
||||
if !config_path.exists() && text.is_none() {
|
||||
create_config_file(&config_path)?;
|
||||
}
|
||||
let config = Config::init(&config_path)?;
|
||||
|
||||
let role_name = matches.get_one::<String>("role").cloned();
|
||||
if let (Some(name), Some(text_)) = (role_name.as_ref(), text.as_ref()) {
|
||||
let role = config
|
||||
.roles
|
||||
.iter()
|
||||
.find(|v| &v.name == name)
|
||||
.ok_or_else(|| anyhow!("Unknown role \"{name}\" "))?;
|
||||
text = Some(role.generate(text_));
|
||||
};
|
||||
|
||||
if matches.get_flag("list-roles") {
|
||||
let cli = Cli::parse();
|
||||
let text = cli.text();
|
||||
let config = Arc::new(Config::init(text.is_none())?);
|
||||
if cli.list_roles {
|
||||
config.roles.iter().for_each(|v| println!("{}", v.name));
|
||||
exit(1);
|
||||
}
|
||||
|
||||
let client = init_client(&config)?;
|
||||
let runtime = init_runtime()?;
|
||||
match text {
|
||||
Some(text) => {
|
||||
let output = runtime.block_on(async move { acquire(&client, &config, &text).await })?;
|
||||
println!("{}", output.trim());
|
||||
}
|
||||
None => run_repl(runtime, client, config, role_name)?,
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_repl(
|
||||
runtime: Runtime,
|
||||
client: Client,
|
||||
config: Config,
|
||||
role_name: Option<String>,
|
||||
) -> Result<()> {
|
||||
print_repl_title();
|
||||
let mut commands: Vec<String> = REPL_COMMANDS
|
||||
.into_iter()
|
||||
.map(|(v, _)| v.to_string())
|
||||
.collect();
|
||||
commands.extend(config.roles.iter().map(|v| format!(".role {}", v.name)));
|
||||
let mut completer = DefaultCompleter::with_inclusions(&['.', '-']).set_min_word_len(2);
|
||||
completer.insert(commands.clone());
|
||||
let completer = Box::new(completer);
|
||||
let completion_menu = Box::new(ColumnarMenu::default().with_name("completion_menu"));
|
||||
let mut keybindings = default_emacs_keybindings();
|
||||
keybindings.add_binding(
|
||||
KeyModifiers::NONE,
|
||||
KeyCode::Tab,
|
||||
ReedlineEvent::UntilFound(vec![
|
||||
ReedlineEvent::Menu("completion_menu".to_string()),
|
||||
ReedlineEvent::MenuNext,
|
||||
]),
|
||||
);
|
||||
let history = Box::new(
|
||||
FileBackedHistory::with_file(1000, Config::history_file()?)
|
||||
.map_err(|err| anyhow!("Failed to setup history file, {err}"))?,
|
||||
);
|
||||
let edit_mode = Box::new(Emacs::new(keybindings));
|
||||
let mut line_editor = Reedline::create()
|
||||
.with_completer(completer)
|
||||
.with_history(history)
|
||||
.with_menu(ReedlineMenu::EngineCompleter(completion_menu))
|
||||
.with_edit_mode(edit_mode);
|
||||
let prompt = DefaultPrompt::new(DefaultPromptSegment::Empty, DefaultPromptSegment::Empty);
|
||||
let mut trigged_ctrlc = false;
|
||||
let mut output = String::new();
|
||||
let mut role: Option<Role> = None;
|
||||
let mut save_file: Option<File> = if config.save {
|
||||
let file = OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(Config::messages_file()?)
|
||||
.map_err(|err| anyhow!("Failed to create/append save_file, {err}"))?;
|
||||
Some(file)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let handle_line = |line: String,
|
||||
line_editor: &mut Reedline,
|
||||
trigged_ctrlc: &mut bool,
|
||||
role: &mut Option<Role>,
|
||||
output: &mut String,
|
||||
save_file: &mut Option<File>|
|
||||
-> Result<bool> {
|
||||
if line.starts_with('.') {
|
||||
let (name, args) = match line.split_once(' ') {
|
||||
Some((head, tail)) => (head, Some(tail.trim())),
|
||||
None => (line.as_str(), None),
|
||||
};
|
||||
match name {
|
||||
".view" => {
|
||||
if output.is_empty() {
|
||||
return Ok(false);
|
||||
}
|
||||
let _ = Editor::new("view ai reply with an external editor")
|
||||
.with_file_extension(".md")
|
||||
.with_predefined_text(output)
|
||||
.prompt()?;
|
||||
dump("", 1);
|
||||
}
|
||||
".exit" => {
|
||||
return Ok(true);
|
||||
}
|
||||
".help" => {
|
||||
dump(get_repl_help(), 2);
|
||||
}
|
||||
".clear" => {
|
||||
line_editor.clear_scrollback()?;
|
||||
}
|
||||
".clear-history" => {
|
||||
let history = Box::new(line_editor.history_mut());
|
||||
history
|
||||
.clear()
|
||||
.map_err(|err| anyhow!("Failed to clear history, {err}"))?;
|
||||
}
|
||||
".history" => {
|
||||
line_editor.print_history()?;
|
||||
dump("", 1);
|
||||
}
|
||||
".role" => match args {
|
||||
Some(name) => match config.roles.iter().find(|v| v.name == name) {
|
||||
Some(role_) => {
|
||||
*role = Some(role_.clone());
|
||||
dump("", 1);
|
||||
}
|
||||
None => dump("Unknown role.", 2),
|
||||
},
|
||||
None => dump("Usage: .role <name>.", 2),
|
||||
},
|
||||
_ => unknown_command(),
|
||||
}
|
||||
} else {
|
||||
let input = if let Some(role) = role.take() {
|
||||
role.generate(&line)
|
||||
} else {
|
||||
line
|
||||
};
|
||||
output.clear();
|
||||
*trigged_ctrlc = false;
|
||||
if input.is_empty() {
|
||||
return Ok(false);
|
||||
}
|
||||
runtime.block_on(async {
|
||||
tokio::select! {
|
||||
ret = handle_input(&client, &config, &input, output) => {
|
||||
if let Err(err) = ret {
|
||||
dump(format!("error: {err}"), 2);
|
||||
}
|
||||
}
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
*trigged_ctrlc = true;
|
||||
dump(" Abort current session.", 2)
|
||||
}
|
||||
}
|
||||
});
|
||||
if !output.is_empty() {
|
||||
if let Some(file) = save_file.as_mut() {
|
||||
let _ = file.write_all(
|
||||
format!("AICHAT: {input}\n\n--------\n{output}\n--------\n\n").as_bytes(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
};
|
||||
if let Some(name) = role_name {
|
||||
handle_line(
|
||||
format!(".role {name}"),
|
||||
&mut line_editor,
|
||||
&mut trigged_ctrlc,
|
||||
&mut role,
|
||||
&mut output,
|
||||
&mut save_file,
|
||||
)?;
|
||||
}
|
||||
loop {
|
||||
let sig = line_editor.read_line(&prompt);
|
||||
match sig {
|
||||
Ok(Signal::Success(line)) => {
|
||||
let quit = handle_line(
|
||||
line,
|
||||
&mut line_editor,
|
||||
&mut trigged_ctrlc,
|
||||
&mut role,
|
||||
&mut output,
|
||||
&mut save_file,
|
||||
)?;
|
||||
if quit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Signal::CtrlC) => {
|
||||
if !trigged_ctrlc {
|
||||
trigged_ctrlc = true;
|
||||
dump("(To exit, press Ctrl+C again or Ctrl+D or type .exit)", 2);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Signal::CtrlD) => {
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("{err:?}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_input(
|
||||
client: &Client,
|
||||
config: &Config,
|
||||
input: &str,
|
||||
output: &mut String,
|
||||
) -> Result<()> {
|
||||
if config.dry_run {
|
||||
output.push_str(input);
|
||||
dump(input, 2);
|
||||
return Ok(());
|
||||
}
|
||||
let mut stream = acquire_stream(client, config, input).await?;
|
||||
let mut virgin = true;
|
||||
while let Some(part) = stream.next().await {
|
||||
let chunk = part?.data;
|
||||
if chunk == "[DONE]" {
|
||||
output.push('\n');
|
||||
dump("", 2);
|
||||
break;
|
||||
} else {
|
||||
let data: Value = serde_json::from_str(&chunk)?;
|
||||
let text = data["choices"][0]["delta"]["content"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if virgin {
|
||||
virgin = false;
|
||||
if text == "\n\n" {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
output.push_str(text);
|
||||
dump(text, 0);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn init_client(config: &Config) -> Result<Client> {
|
||||
let mut builder = Client::builder();
|
||||
if let Some(proxy) = config.proxy.as_ref() {
|
||||
builder =
|
||||
builder.proxy(Proxy::all(proxy).map_err(|err| anyhow!("Invalid config.proxy, {err}"))?);
|
||||
}
|
||||
let client = builder
|
||||
.connect_timeout(CONNECT_TIMEOUT)
|
||||
.build()
|
||||
.map_err(|err| anyhow!("Failed to init http client, {err}"))?;
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn init_runtime() -> Result<Runtime> {
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.map_err(|err| anyhow!("Failed to init tokio, {err}"))
|
||||
}
|
||||
|
||||
fn create_config_file(config_path: &Path) -> Result<()> {
|
||||
let confirm_map_err = |_| anyhow!("Error with questionnaire, try again later");
|
||||
let text_map_err = |_| anyhow!("An error happened when asking for your key, try again later.");
|
||||
let ans = Confirm::new("No config file, create a new one?")
|
||||
.with_default(true)
|
||||
.prompt()
|
||||
.map_err(confirm_map_err)?;
|
||||
if !ans {
|
||||
exit(0);
|
||||
}
|
||||
let api_key = Text::new("Openai API Key:")
|
||||
.prompt()
|
||||
.map_err(text_map_err)?;
|
||||
let mut raw_config = format!("api_key: {api_key}\n");
|
||||
|
||||
let ans = Confirm::new("Use proxy?")
|
||||
.with_default(false)
|
||||
.prompt()
|
||||
.map_err(confirm_map_err)?;
|
||||
if ans {
|
||||
let proxy = Text::new("Set proxy:").prompt().map_err(text_map_err)?;
|
||||
raw_config.push_str(&format!("proxy: {proxy}\n"));
|
||||
let role = match &cli.role {
|
||||
Some(name) => Some(
|
||||
config
|
||||
.find_role(name)
|
||||
.ok_or_else(|| anyhow!("Uknown role '{name}'"))?,
|
||||
),
|
||||
None => None,
|
||||
};
|
||||
let client = ChatGptClient::init(config.clone())?;
|
||||
match text {
|
||||
Some(text) => start_directive(client, config, role, &text),
|
||||
None => start_interactive(client, config, role),
|
||||
}
|
||||
}
|
||||
|
||||
let ans = Confirm::new("Save chat messages")
|
||||
.with_default(false)
|
||||
.prompt()
|
||||
.map_err(confirm_map_err)?;
|
||||
if ans {
|
||||
raw_config.push_str("save: true\n");
|
||||
}
|
||||
|
||||
std::fs::write(config_path, raw_config)
|
||||
.map_err(|err| anyhow!("Failed to write to config file, {err}"))?;
|
||||
fn start_directive(
|
||||
client: ChatGptClient,
|
||||
config: Arc<Config>,
|
||||
role: Option<Role>,
|
||||
input: &str,
|
||||
) -> Result<()> {
|
||||
let mut file = config.open_message_file()?;
|
||||
let output = client.acquire(input, role.map(|v| v.prompt))?;
|
||||
println!("{}", output.trim());
|
||||
Config::save_message(file.as_mut(), input, &output);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn acquire(client: &Client, config: &Config, content: &str) -> Result<String> {
|
||||
if config.dry_run {
|
||||
return Ok(content.to_string());
|
||||
}
|
||||
let mut body = json!({
|
||||
"model": MODEL,
|
||||
"messages": [{"role": "user", "content": content}]
|
||||
});
|
||||
|
||||
if let Some(v) = config.temperature {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("temperature".into(), json!(v)));
|
||||
}
|
||||
|
||||
let data: Value = client
|
||||
.post(API_URL)
|
||||
.bearer_auth(&config.api_key)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
let output = data["choices"][0]["message"]["content"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("Unexpected response {data}"))?;
|
||||
|
||||
Ok(output.to_string())
|
||||
}
|
||||
|
||||
async fn acquire_stream(
|
||||
client: &Client,
|
||||
config: &Config,
|
||||
content: &str,
|
||||
) -> Result<EventStream<impl Stream<Item = reqwest::Result<bytes::Bytes>>>> {
|
||||
let mut body = json!({
|
||||
"model": MODEL,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": true,
|
||||
});
|
||||
|
||||
if let Some(v) = config.temperature {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("temperature".into(), json!(v)));
|
||||
}
|
||||
|
||||
let stream = client
|
||||
.post(API_URL)
|
||||
.bearer_auth(&config.api_key)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?
|
||||
.bytes_stream()
|
||||
.eventsource();
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn unknown_command() {
|
||||
dump("Unknown command. Type \".help\" for more information.", 2);
|
||||
}
|
||||
|
||||
fn dump<T: ToString>(text: T, newlines: usize) {
|
||||
print!("{}{}", text.to_string(), "\n".repeat(newlines));
|
||||
stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
fn print_repl_title() {
|
||||
println!("Welcome to aichat {}", env!("CARGO_PKG_VERSION"));
|
||||
println!("Type \".help\" for more information.");
|
||||
}
|
||||
|
||||
fn get_repl_help() -> String {
|
||||
let head = REPL_COMMANDS
|
||||
.iter()
|
||||
.map(|(name, desc)| format!("{name:<15} {desc}"))
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
format!("{head}\n\nPress Ctrl+C to abort session, Ctrl+D to exit the REPL")
|
||||
fn start_interactive(client: ChatGptClient, config: Arc<Config>, role: Option<Role>) -> Result<()> {
|
||||
let mut repl = Repl::init(config.clone())?;
|
||||
let handler = ReplCmdHandler::init(client, config, role)?;
|
||||
repl.run(handler)
|
||||
}
|
||||
|
299
src/repl.rs
Normal file
299
src/repl.rs
Normal file
@ -0,0 +1,299 @@
|
||||
use crate::client::ChatGptClient;
|
||||
use crate::config::{Config, Role};
|
||||
use anyhow::{anyhow, Result};
|
||||
use inquire::Editor;
|
||||
use reedline::{
|
||||
default_emacs_keybindings, ColumnarMenu, DefaultCompleter, DefaultPrompt, DefaultPromptSegment,
|
||||
Emacs, FileBackedHistory, KeyCode, KeyModifiers, Keybindings, Reedline, ReedlineEvent,
|
||||
ReedlineMenu, Signal,
|
||||
};
|
||||
use std::cell::RefCell;
|
||||
use std::fs::File;
|
||||
use std::io::{stdout, Write};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
const REPL_COMMANDS: [(&str, &str); 8] = [
|
||||
(".clear", "Clear the screen"),
|
||||
(".clear-history", "Clear the history"),
|
||||
(".clear-role", "Clear the role status"),
|
||||
(".exit", "Exit the REPL"),
|
||||
(".help", "Print this help message"),
|
||||
(".history", "Print the history"),
|
||||
(".role", "Specify the role that the AI will play"),
|
||||
(".view", "Use an external editor to view the AI reply"),
|
||||
];
|
||||
|
||||
const MENU_NAME: &str = "completion_menu";
|
||||
|
||||
pub struct Repl {
|
||||
editor: Reedline,
|
||||
prompt: DefaultPrompt,
|
||||
}
|
||||
|
||||
impl Repl {
|
||||
pub fn init(config: Arc<Config>) -> Result<Self> {
|
||||
let completer = Self::create_completer(config);
|
||||
let keybindings = Self::create_keybindings();
|
||||
let history = Self::create_history()?;
|
||||
let menu = Self::create_menu();
|
||||
let edit_mode = Box::new(Emacs::new(keybindings));
|
||||
let editor = Reedline::create()
|
||||
.with_completer(Box::new(completer))
|
||||
.with_history(history)
|
||||
.with_menu(menu)
|
||||
.with_edit_mode(edit_mode);
|
||||
let prompt = Self::create_prompt();
|
||||
Ok(Self { editor, prompt })
|
||||
}
|
||||
|
||||
pub fn run(&mut self, handler: ReplCmdHandler) -> Result<()> {
|
||||
dump(
|
||||
format!("Welcome to aichat {}", env!("CARGO_PKG_VERSION")),
|
||||
1,
|
||||
);
|
||||
dump("Type \".help\" for more information.", 1);
|
||||
let mut current_ctrlc = false;
|
||||
let handler = Arc::new(handler);
|
||||
loop {
|
||||
if handler.ctrlc.load(Ordering::SeqCst) {
|
||||
handler.ctrlc.store(false, Ordering::SeqCst);
|
||||
current_ctrlc = true
|
||||
}
|
||||
let sig = self.editor.read_line(&self.prompt);
|
||||
match sig {
|
||||
Ok(Signal::Success(line)) => {
|
||||
current_ctrlc = false;
|
||||
match self.handle_line(handler.clone(), line) {
|
||||
Ok(quit) => {
|
||||
if quit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
dump(format!("{err:?}"), 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Signal::CtrlC) => {
|
||||
if !current_ctrlc {
|
||||
current_ctrlc = true;
|
||||
dump("(To exit, press Ctrl+C again or Ctrl+D or type .exit)", 2);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Signal::CtrlD) => {
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
dump(format!("{err:?}"), 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// tx.send(ReplCmd::Quit).unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_line(&mut self, handler: Arc<ReplCmdHandler>, line: String) -> Result<bool> {
|
||||
if line.starts_with('.') {
|
||||
let (cmd, args) = match line.split_once(' ') {
|
||||
Some((head, tail)) => (head, Some(tail.trim())),
|
||||
None => (line.as_str(), None),
|
||||
};
|
||||
match cmd {
|
||||
".view" => handler.handle(ReplCmd::View)?,
|
||||
".exit" => {
|
||||
return Ok(true);
|
||||
}
|
||||
".help" => {
|
||||
dump_repl_help();
|
||||
}
|
||||
".clear" => {
|
||||
self.editor.clear_scrollback()?;
|
||||
}
|
||||
".clear-history" => {
|
||||
let history = Box::new(self.editor.history_mut());
|
||||
history
|
||||
.clear()
|
||||
.map_err(|err| anyhow!("Failed to clear history, {err}"))?;
|
||||
dump("", 1);
|
||||
}
|
||||
".history" => {
|
||||
self.editor.print_history()?;
|
||||
dump("", 1);
|
||||
}
|
||||
".role" => match args {
|
||||
Some(name) => handler.handle(ReplCmd::SetRole(name.to_string()))?,
|
||||
None => dump("Usage: .role <name>", 2),
|
||||
},
|
||||
".clear-role" => {
|
||||
handler.handle(ReplCmd::UnsetRole)?;
|
||||
dump("", 1);
|
||||
}
|
||||
_ => dump_unknown_command(),
|
||||
}
|
||||
} else {
|
||||
handler.handle(ReplCmd::Input(line))?;
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
fn create_prompt() -> DefaultPrompt {
|
||||
DefaultPrompt::new(DefaultPromptSegment::Empty, DefaultPromptSegment::Empty)
|
||||
}
|
||||
|
||||
fn create_completer(config: Arc<Config>) -> DefaultCompleter {
|
||||
let mut commands: Vec<String> = REPL_COMMANDS
|
||||
.into_iter()
|
||||
.map(|(v, _)| v.to_string())
|
||||
.collect();
|
||||
commands.extend(config.roles.iter().map(|v| format!(".role {}", v.name)));
|
||||
let mut completer = DefaultCompleter::with_inclusions(&['.', '-']).set_min_word_len(2);
|
||||
completer.insert(commands.clone());
|
||||
completer
|
||||
}
|
||||
|
||||
fn create_keybindings() -> Keybindings {
|
||||
let mut keybindings = default_emacs_keybindings();
|
||||
keybindings.add_binding(
|
||||
KeyModifiers::NONE,
|
||||
KeyCode::Tab,
|
||||
ReedlineEvent::UntilFound(vec![
|
||||
ReedlineEvent::Menu(MENU_NAME.to_string()),
|
||||
ReedlineEvent::MenuNext,
|
||||
]),
|
||||
);
|
||||
keybindings
|
||||
}
|
||||
|
||||
fn create_menu() -> ReedlineMenu {
|
||||
let completion_menu = ColumnarMenu::default().with_name(MENU_NAME);
|
||||
ReedlineMenu::EngineCompleter(Box::new(completion_menu))
|
||||
}
|
||||
|
||||
fn create_history() -> Result<Box<FileBackedHistory>> {
|
||||
Ok(Box::new(
|
||||
FileBackedHistory::with_file(1000, Config::history_file()?)
|
||||
.map_err(|err| anyhow!("Failed to setup history file, {err}"))?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ReplCmdHandler {
|
||||
client: ChatGptClient,
|
||||
config: Arc<Config>,
|
||||
state: RefCell<ReplCmdHandlerState>,
|
||||
ctrlc: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
struct ReplCmdHandlerState {
|
||||
prompt: String,
|
||||
output: String,
|
||||
save_file: Option<File>,
|
||||
}
|
||||
|
||||
impl ReplCmdHandler {
|
||||
pub fn init(client: ChatGptClient, config: Arc<Config>, role: Option<Role>) -> Result<Self> {
|
||||
let prompt = role.map(|v| v.prompt).unwrap_or_default();
|
||||
let save_file = config.open_message_file()?;
|
||||
let ctrlc = Arc::new(AtomicBool::new(false));
|
||||
let state = RefCell::new(ReplCmdHandlerState {
|
||||
prompt,
|
||||
save_file,
|
||||
output: String::new(),
|
||||
});
|
||||
Ok(Self {
|
||||
client,
|
||||
config,
|
||||
ctrlc,
|
||||
state,
|
||||
})
|
||||
}
|
||||
fn handle(&self, cmd: ReplCmd) -> Result<()> {
|
||||
match cmd {
|
||||
ReplCmd::Input(input) => {
|
||||
let mut output = String::new();
|
||||
if input.is_empty() {
|
||||
self.state.borrow_mut().output.clear();
|
||||
return Ok(());
|
||||
}
|
||||
let prompt = self.state.borrow().prompt.to_string();
|
||||
let prompt = if prompt.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(prompt)
|
||||
};
|
||||
self.client.acquire_stream(
|
||||
&input,
|
||||
prompt,
|
||||
&mut output,
|
||||
dump_and_collect,
|
||||
self.ctrlc.clone(),
|
||||
)?;
|
||||
dump_and_collect(&mut output, "\n\n");
|
||||
Config::save_message(self.state.borrow_mut().save_file.as_mut(), &input, &output);
|
||||
self.state.borrow_mut().output = output;
|
||||
}
|
||||
ReplCmd::View => {
|
||||
let output = self.state.borrow().output.to_string();
|
||||
if output.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let _ = Editor::new("view ai reply with an external editor")
|
||||
.with_file_extension(".md")
|
||||
.with_predefined_text(&output)
|
||||
.prompt()?;
|
||||
dump("", 1);
|
||||
}
|
||||
ReplCmd::SetRole(name) => match self.config.find_role(&name) {
|
||||
Some(v) => {
|
||||
self.state.borrow_mut().prompt = v.prompt;
|
||||
dump("", 1);
|
||||
}
|
||||
None => {
|
||||
dump("Unknown role", 2);
|
||||
}
|
||||
},
|
||||
ReplCmd::UnsetRole => {
|
||||
self.state.borrow_mut().prompt = String::new();
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ReplCmd {
|
||||
View,
|
||||
UnsetRole,
|
||||
Input(String),
|
||||
SetRole(String),
|
||||
}
|
||||
|
||||
pub fn dump<T: ToString>(text: T, newlines: usize) {
|
||||
print!("{}{}", text.to_string(), "\n".repeat(newlines));
|
||||
stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
fn dump_and_collect(output: &mut String, reply: &str) {
|
||||
output.push_str(reply);
|
||||
dump(reply, 0);
|
||||
}
|
||||
|
||||
fn dump_repl_help() {
|
||||
let head = REPL_COMMANDS
|
||||
.iter()
|
||||
.map(|(name, desc)| format!("{name:<15} {desc}"))
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
dump(
|
||||
format!("{head}\n\nPress Ctrl+C to abort session, Ctrl+D to exit the REPL"),
|
||||
2,
|
||||
);
|
||||
}
|
||||
|
||||
fn dump_unknown_command() {
|
||||
dump("Unknown command. Type \".help\" for more information.", 2);
|
||||
}
|
Loading…
Reference in New Issue
Block a user