diff --git a/Cargo.lock b/Cargo.lock index 49825b4..b534392 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -49,6 +49,7 @@ dependencies = [ "is-terminal", "lazy_static", "log", + "mime_guess", "nu-ansi-term", "parking_lot", "reedline", @@ -58,6 +59,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "sha2", "shell-words", "simplelog", "syntect", @@ -246,6 +248,15 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bstr" version = "1.7.0" @@ -377,6 +388,15 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +[[package]] +name = "cpufeatures" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce420fe07aecd3e67c5f910618fe65e94158f6dcc0adf44e00d69ce2bdfe0fd0" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.3.2" @@ -495,6 +515,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "deranged" version = "0.3.9" @@ -504,6 +534,16 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "dirs" version = "5.0.1" @@ -696,6 +736,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "gethostname" version = "0.2.3" @@ -883,9 +933,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.0.2" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8adf3ddd720272c6ea8bf59463c04e0f93d0bbf7c5439b691bca2987e0270897" +checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" dependencies = [ "equivalent", "hashbrown 0.14.2", @@ -1030,6 +1080,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1623,7 +1683,7 @@ version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ - "indexmap 2.0.2", + "indexmap 2.1.0", "itoa", "ryu", "serde", @@ -1647,13 +1707,24 @@ version = "0.9.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3cc7a1570e38322cfe4154732e5110f887ea57e22b76f4bfd32b5bdd3368666c" dependencies = [ - "indexmap 2.0.2", + "indexmap 2.1.0", "itoa", "ryu", "serde", "unsafe-libyaml", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shell-words" version = "1.1.0" @@ -2032,6 +2103,21 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "unicase" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89" +dependencies = [ + "version_check", +] + [[package]] name = "unicode-bidi" version = "0.3.13" @@ -2100,6 +2186,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + [[package]] name = "vte" version = "0.10.1" diff --git a/Cargo.toml b/Cargo.toml index 09661ee..01f2db9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,8 @@ reqwest-eventsource = "0.5.0" simplelog = "0.12.1" log = "0.4.20" shell-words = "1.1.0" +mime_guess = "2.0.4" +sha2 = "0.10.8" [dependencies.reqwest] version = "0.11.14" diff --git a/README.md b/README.md index fe6f0dd..8a9173f 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ Download it from [GitHub Releases](https://github.com/sigoden/aichat/releases), - Support chat and command modes - Use [Roles](#roles) - Powerful [Chat REPL](#chat-repl) +- Support vision - Context-aware conversation/session - Syntax highlighting markdown and 200 other languages - Stream output with hand-typing effect @@ -147,9 +148,9 @@ The Chat REPL supports: .session Start a context-aware chat session .info session Show session info .exit session End the current session +.file Attach files to the message and then submit it .set Modify the configuration parameters .copy Copy the last reply to the clipboard -.read Read files into the message and submit .exit Exit the REPL Type ::: to begin multi-line editing, type ::: to end it. @@ -255,6 +256,17 @@ The prompt on the right side is about the current usage of tokens and the propor compared to the maximum number of tokens allowed by the model. +### `.file` - attach files to the message + +``` +Usage: .file ... [-- text...] + +.file message.txt +.file config.yaml -- convert to toml +.file a.jpg b.jpg -- What’s in these images? +.file https://ibb.co/a.png https://ibb.co/b.png -- what is the difference? +``` + ### `.set` - modify the configuration temporarily ``` @@ -277,6 +289,7 @@ Options: -m, --model Choose a LLM model -r, --role Choose a role -s, --session [] Create or reuse a session + -f, --file ... Attach files to the message to be sent -H, --no-highlight Disable syntax highlighting -S, --no-stream No stream output -w, --wrap Specify the text-wrapping mode (no*, auto, ) @@ -306,6 +319,9 @@ cat config.json | aichat convert to yaml # Read stdin cat config.json | aichat -r convert:yaml # Read stdin with a role cat config.json | aichat -s i18n # Read stdin with a session +aichat --file a.png b.png -- diff images # Attach files +aichat --file screenshot.png -r ocr # Attach files with a role + aichat --list-models # List all available models aichat --list-roles # List all available roles aichat --list-sessions # List all available models diff --git a/src/cli.rs b/src/cli.rs index 629266c..7e1229d 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -12,6 +12,9 @@ pub struct Cli { /// Create or reuse a session #[clap(short = 's', long)] pub session: Option>, + /// Attach files to the message to be sent. + #[clap(short = 'f', long, num_args = 1.., value_name = "FILE")] + pub file: Option>, /// Disable syntax highlighting #[clap(short = 'H', long)] pub no_highlight: bool, diff --git a/src/client/common.rs b/src/client/common.rs index cf5ba9b..2716d87 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -1,7 +1,7 @@ use super::{openai::OpenAIConfig, ClientConfig, Message}; use crate::{ - config::GlobalConfig, + config::{GlobalConfig, Input}, render::ReplyHandler, utils::{ init_tokio_runtime, prompt_input_integer, prompt_input_string, tokenize, AbortSignal, @@ -50,7 +50,7 @@ macro_rules! register_client { } impl $client { - pub const NAME: &str = $name; + pub const NAME: &'static str = $name; pub fn init(global_config: &$crate::config::GlobalConfig) -> Option> { let model = global_config.read().model.clone(); @@ -186,22 +186,22 @@ pub trait Client { Ok(client) } - fn send_message(&self, content: &str) -> Result { + fn send_message(&self, input: Input) -> Result { init_tokio_runtime()?.block_on(async { let global_config = self.config().0; if global_config.read().dry_run { - let content = global_config.read().echo_messages(content); + let content = global_config.read().echo_messages(&input); return Ok(content); } let client = self.build_client()?; - let data = global_config.read().prepare_send_data(content, false)?; + let data = global_config.read().prepare_send_data(&input, false)?; self.send_message_inner(&client, data) .await .with_context(|| "Failed to get answer") }) } - fn send_message_streaming(&self, content: &str, handler: &mut ReplyHandler) -> Result<()> { + fn send_message_streaming(&self, input: &Input, handler: &mut ReplyHandler) -> Result<()> { async fn watch_abort(abort: AbortSignal) { loop { if abort.aborted() { @@ -211,12 +211,13 @@ pub trait Client { } } let abort = handler.get_abort(); - init_tokio_runtime()?.block_on(async { + let input = input.clone(); + init_tokio_runtime()?.block_on(async move { tokio::select! { ret = async { let global_config = self.config().0; if global_config.read().dry_run { - let content = global_config.read().echo_messages(content); + let content = global_config.read().echo_messages(&input); let tokens = tokenize(&content); for token in tokens { tokio::time::sleep(Duration::from_millis(10)).await; @@ -225,7 +226,7 @@ pub trait Client { return Ok(()); } let client = self.build_client()?; - let data = global_config.read().prepare_send_data(content, true)?; + let data = global_config.read().prepare_send_data(&input, true)?; self.send_message_streaming_inner(&client, handler, data).await } => { handler.done()?; diff --git a/src/client/ernie.rs b/src/client/ernie.rs index 200433c..4bb3435 100644 --- a/src/client/ernie.rs +++ b/src/client/ernie.rs @@ -1,4 +1,4 @@ -use super::{ErnieClient, Client, ExtraConfig, PromptType, SendData, Model}; +use super::{ErnieClient, Client, ExtraConfig, PromptType, SendData, Model, MessageContent}; use crate::{ config::GlobalConfig, @@ -198,8 +198,10 @@ fn build_body(data: SendData, _model: String) -> Value { if messages[0].role.is_system() { let system_message = messages.remove(0); - if let Some(message) = messages.get_mut(0) { - message.content = format!("{}\n\n{}", system_message.content, message.content) + if let (Some(message), MessageContent::Text(system_text)) = (messages.get_mut(0), system_message.content) { + if let MessageContent::Text(text) = message.content.clone() { + message.content = MessageContent::Text(format!("{}\n\n{}", system_text, text)) + } } } diff --git a/src/client/message.rs b/src/client/message.rs index 55b2663..dc8c3e1 100644 --- a/src/client/message.rs +++ b/src/client/message.rs @@ -1,16 +1,18 @@ +use crate::config::Input; + use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Message { pub role: MessageRole, - pub content: String, + pub content: MessageContent, } impl Message { - pub fn new(content: &str) -> Self { + pub fn new(input: &Input) -> Self { Self { role: MessageRole::User, - content: content.to_string(), + content: input.to_message_content(), } } } @@ -38,6 +40,65 @@ impl MessageRole { } } +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum MessageContent { + Text(String), + Array(Vec), +} + +impl MessageContent { + pub fn render_input(&self, resolve_url_fn: impl Fn(&str) -> String) -> String { + match self { + MessageContent::Text(text) => text.to_string(), + MessageContent::Array(list) => { + let (mut concated_text, mut files) = (String::new(), vec![]); + for item in list { + match item { + MessageContentPart::Text { text } => { + concated_text = format!("{concated_text} {text}") + } + MessageContentPart::ImageUrl { image_url } => { + files.push(resolve_url_fn(&image_url.url)) + } + } + } + if !concated_text.is_empty() { + concated_text = format!(" -- {concated_text}") + } + format!(".file {}{}", files.join(" "), concated_text) + } + } + } + + pub fn merge_prompt(&mut self, replace_fn: impl Fn(&str) -> String) { + match self { + MessageContent::Text(text) => *text = replace_fn(text), + MessageContent::Array(list) => { + if list.is_empty() { + list.push(MessageContentPart::Text { + text: replace_fn(""), + }) + } else if let Some(MessageContentPart::Text { text }) = list.get_mut(0) { + *text = replace_fn(text) + } + } + } + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum MessageContentPart { + Text { text: String }, + ImageUrl { image_url: ImageUrl }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ImageUrl { + pub url: String, +} + #[cfg(test)] mod tests { use super::*; @@ -45,7 +106,7 @@ mod tests { #[test] fn test_serde() { assert_eq!( - serde_json::to_string(&Message::new("Hello World")).unwrap(), + serde_json::to_string(&Message::new(&Input::from_str("Hello World"))).unwrap(), "{\"role\":\"user\",\"content\":\"Hello World\"}" ); } diff --git a/src/client/model.rs b/src/client/model.rs index 16fe087..130489d 100644 --- a/src/client/model.rs +++ b/src/client/model.rs @@ -1,4 +1,4 @@ -use super::message::Message; +use super::message::{Message, MessageContent}; use crate::utils::count_tokens; @@ -79,7 +79,15 @@ impl Model { } pub fn messages_tokens(&self, messages: &[Message]) -> usize { - messages.iter().map(|v| count_tokens(&v.content)).sum() + messages + .iter() + .map(|v| { + match &v.content { + MessageContent::Text(text) => count_tokens(text), + MessageContent::Array(_) => 0, // TODO + } + }) + .sum() } pub fn total_tokens(&self, messages: &[Message]) -> usize { diff --git a/src/client/openai.rs b/src/client/openai.rs index dbe2661..928c6b3 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -19,13 +19,14 @@ use std::env; const API_BASE: &str = "https://api.openai.com/v1"; -const MODELS: [(&str, usize); 6] = [ +const MODELS: [(&str, usize); 7] = [ ("gpt-3.5-turbo", 4096), ("gpt-3.5-turbo-16k", 16385), ("gpt-3.5-turbo-1106", 16385), + ("gpt-4-1106-preview", 128000), + ("gpt-4-vision-preview", 128000), ("gpt-4", 8192), ("gpt-4-32k", 32768), - ("gpt-4-1106-preview", 128000), ]; pub const OPENAI_TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2); @@ -145,6 +146,12 @@ pub fn openai_build_body(data: SendData, model: String) -> Value { "model": model, "messages": messages, }); + + // The default max_tokens of gpt-4-vision-preview is only 16, we need to make it larger + if model == "gpt-4-vision-preview" { + body["max_tokens"] = json!(4096); + } + if let Some(v) = temperature { body["temperature"] = v.into(); } diff --git a/src/client/palm.rs b/src/client/palm.rs index a2aec8a..37ed0ae 100644 --- a/src/client/palm.rs +++ b/src/client/palm.rs @@ -1,4 +1,4 @@ -use super::{PaLMClient, Client, ExtraConfig, Model, PromptType, SendData, TokensCountFactors, send_message_as_streaming}; +use super::{PaLMClient, Client, ExtraConfig, Model, PromptType, SendData, TokensCountFactors, send_message_as_streaming, MessageContent}; use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind}; @@ -115,8 +115,10 @@ fn build_body(data: SendData, _model: String) -> Value { if messages[0].role.is_system() { let system_message = messages.remove(0); - if let Some(message) = messages.get_mut(0) { - message.content = format!("{}\n\n{}", system_message.content, message.content) + if let (Some(message), MessageContent::Text(system_text)) = (messages.get_mut(0), system_message.content) { + if let MessageContent::Text(text) = message.content.clone() { + message.content = MessageContent::Text(format!("{}\n\n{}", system_text, text)) + } } } diff --git a/src/config/input.rs b/src/config/input.rs new file mode 100644 index 0000000..3997929 --- /dev/null +++ b/src/config/input.rs @@ -0,0 +1,162 @@ +use crate::client::{ImageUrl, MessageContent, MessageContentPart}; +use crate::utils::sha256sum; + +use anyhow::{bail, Context, Result}; +use base64::{self, engine::general_purpose::STANDARD, Engine}; +use mime_guess::from_path; +use std::{ + collections::HashMap, + fs::{self, File}, + io::Read, + path::{Path, PathBuf}, +}; + +const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"]; + +#[derive(Debug, Clone)] +pub struct Input { + text: String, + medias: Vec, + data_urls: HashMap, +} + +impl Input { + pub fn from_str(text: &str) -> Self { + Self { + text: text.to_string(), + medias: Default::default(), + data_urls: Default::default(), + } + } + + pub fn new(text: &str, files: Vec) -> Result { + let mut texts = vec![text.to_string()]; + let mut medias = vec![]; + let mut data_urls = HashMap::new(); + for file_item in files.into_iter() { + match resolve_path(&file_item) { + Some(file_path) => { + let file_path = fs::canonicalize(file_path) + .with_context(|| format!("Unable to use file '{file_item}"))?; + if is_image_ext(&file_path) { + let data_url = read_media_to_data_url(&file_path)?; + data_urls.insert(sha256sum(&data_url), file_path.display().to_string()); + medias.push(data_url) + } else { + let mut text = String::new(); + let mut file = File::open(&file_path) + .with_context(|| format!("Unable to open file '{file_item}'"))?; + file.read_to_string(&mut text) + .with_context(|| format!("Unable to read file '{file_item}'"))?; + texts.push(text); + } + } + None => { + if is_image_ext(Path::new(&file_item)) { + medias.push(file_item) + } else { + bail!("Unable to use file '{file_item}"); + } + } + } + } + + Ok(Self { + text: texts.join("\n"), + medias, + data_urls, + }) + } + + pub fn data_urls(&self) -> HashMap { + self.data_urls.clone() + } + + pub fn render(&self) -> String { + if self.medias.is_empty() { + return self.text.clone(); + } + let text = if self.text.is_empty() { + self.text.to_string() + } else { + format!(" -- {}", self.text) + }; + let files: Vec = self + .medias + .iter() + .cloned() + .map(|url| resolve_data_url(&self.data_urls, url)) + .collect(); + format!(".file {}{}", files.join(" "), text) + } + + pub fn to_message_content(&self) -> MessageContent { + if self.medias.is_empty() { + MessageContent::Text(self.text.clone()) + } else { + let mut list: Vec = self + .medias + .iter() + .cloned() + .map(|url| MessageContentPart::ImageUrl { + image_url: ImageUrl { url }, + }) + .collect(); + if !self.text.is_empty() { + list.insert( + 0, + MessageContentPart::Text { + text: self.text.clone(), + }, + ); + } + MessageContent::Array(list) + } + } +} + +pub fn resolve_data_url(data_urls: &HashMap, data_url: String) -> String { + if data_url.starts_with("data:") { + let hash = sha256sum(&data_url); + if let Some(path) = data_urls.get(&hash) { + return path.to_string(); + } + data_url + } else { + data_url + } +} + +fn resolve_path(file: &str) -> Option { + if ["https://", "http://", "data:"] + .iter() + .any(|v| file.starts_with(v)) + { + return None; + } + let path = if let (Some(file), Some(home)) = (file.strip_prefix('~'), dirs::home_dir()) { + home.join(file) + } else { + std::env::current_dir().ok()?.join(file) + }; + Some(path) +} + +fn is_image_ext(path: &Path) -> bool { + path.extension() + .map(|v| IMAGE_EXTS.iter().any(|ext| *ext == v.to_string_lossy())) + .unwrap_or_default() +} + +fn read_media_to_data_url>(image_path: P) -> Result { + let mime_type = from_path(&image_path).first_or_octet_stream().to_string(); + + let mut file = File::open(image_path)?; + let mut buffer = Vec::new(); + file.read_to_end(&mut buffer)?; + + let encoded_image = STANDARD.encode(buffer); + let data_url = format!("data:{};base64,{}", mime_type, encoded_image); + + Ok(data_url) +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 9c86e4c..08509be 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,6 +1,8 @@ +mod input; mod role; mod session; +pub use self::input::Input; use self::role::Role; use self::session::{Session, TEMP_SESSION_NAME}; @@ -78,7 +80,7 @@ pub struct Config { #[serde(skip)] pub model: Model, #[serde(skip)] - pub last_message: Option<(String, String)>, + pub last_message: Option<(Input, String)>, #[serde(skip)] pub temperature: Option, } @@ -200,15 +202,15 @@ impl Config { Ok(path) } - pub fn save_message(&mut self, input: &str, output: &str) -> Result<()> { - self.last_message = Some((input.to_string(), output.to_string())); + pub fn save_message(&mut self, input: Input, output: &str) -> Result<()> { + self.last_message = Some((input.clone(), output.to_string())); if self.dry_run { return Ok(()); } if let Some(session) = self.session.as_mut() { - session.add_message(input, output)?; + session.add_message(&input, output)?; return Ok(()); } @@ -220,13 +222,14 @@ impl Config { return Ok(()); } let timestamp = now(); + let input_markdown = input.render(); let output = match self.role.as_ref() { None => { - format!("# CHAT:[{timestamp}]\n{input}\n--------\n{output}\n--------\n\n",) + format!("# CHAT:[{timestamp}]\n{input_markdown}\n--------\n{output}\n--------\n\n",) } Some(v) => { format!( - "# CHAT:[{timestamp}] ({})\n{input}\n--------\n{output}\n--------\n\n", + "# CHAT:[{timestamp}] ({})\n{input_markdown}\n--------\n{output}\n--------\n\n", v.name, ) } @@ -292,23 +295,23 @@ impl Config { Ok(()) } - pub fn echo_messages(&self, content: &str) -> String { + pub fn echo_messages(&self, input: &Input) -> String { if let Some(session) = self.session.as_ref() { - session.echo_messages(content) + session.echo_messages(input) } else if let Some(role) = self.role.as_ref() { - role.echo_messages(content) + role.echo_messages(input) } else { - content.to_string() + input.render() } } - pub fn build_messages(&self, content: &str) -> Result> { + pub fn build_messages(&self, input: &Input) -> Result> { let messages = if let Some(session) = self.session.as_ref() { - session.build_emssages(content) + session.build_emssages(input) } else if let Some(role) = self.role.as_ref() { - role.build_messages(content) + role.build_messages(input) } else { - let message = Message::new(content); + let message = Message::new(input); vec![message] }; Ok(messages) @@ -586,7 +589,7 @@ impl Config { Ok(dir) => dir, Err(_) => return vec![], }; - match read_dir(&sessions_dir) { + match read_dir(sessions_dir) { Ok(rd) => { let mut names = vec![]; for entry in rd.flatten() { @@ -643,8 +646,8 @@ impl Config { } } - pub fn prepare_send_data(&self, content: &str, stream: bool) -> Result { - let messages = self.build_messages(content)?; + pub fn prepare_send_data(&self, input: &Input, stream: bool) -> Result { + let messages = self.build_messages(input)?; self.model.max_tokens_limit(&messages)?; Ok(SendData { messages, @@ -653,7 +656,7 @@ impl Config { }) } - pub fn maybe_print_send_tokens(&self, input: &str) { + pub fn maybe_print_send_tokens(&self, input: &Input) { if self.dry_run { if let Ok(messages) = self.build_messages(input) { let tokens = self.model.total_tokens(&messages); diff --git a/src/config/role.rs b/src/config/role.rs index 2b8fea1..bd7216c 100644 --- a/src/config/role.rs +++ b/src/config/role.rs @@ -1,8 +1,10 @@ -use crate::client::{Message, MessageRole}; +use crate::client::{Message, MessageContent, MessageRole}; use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; +use super::Input; + const INPUT_PLACEHOLDER: &str = "__INPUT__"; #[derive(Debug, Clone, Deserialize, Serialize)] @@ -41,17 +43,20 @@ impl Role { } } - pub fn echo_messages(&self, content: &str) -> String { + pub fn echo_messages(&self, input: &Input) -> String { + let input_markdown = input.render(); if self.embedded() { - merge_prompt_content(&self.prompt, content) + self.prompt.replace(INPUT_PLACEHOLDER, &input_markdown) } else { - format!("{}\n\n{content}", self.prompt) + format!("{}\n\n{}", self.prompt, input.render()) } } - pub fn build_messages(&self, content: &str) -> Vec { + pub fn build_messages(&self, input: &Input) -> Vec { + let mut content = input.to_message_content(); + if self.embedded() { - let content = merge_prompt_content(&self.prompt, content); + content.merge_prompt(|v: &str| self.prompt.replace(INPUT_PLACEHOLDER, v)); vec![Message { role: MessageRole::User, content, @@ -60,21 +65,17 @@ impl Role { vec![ Message { role: MessageRole::System, - content: self.prompt.clone(), + content: MessageContent::Text(self.prompt.clone()), }, Message { role: MessageRole::User, - content: content.to_string(), + content, }, ] } } } -fn merge_prompt_content(prompt: &str, content: &str) -> String { - prompt.replace(INPUT_PLACEHOLDER, content) -} - fn complete_prompt_args(prompt: &str, name: &str) -> String { let mut prompt = prompt.trim().to_string(); for (i, arg) in name.split(':').skip(1).enumerate() { diff --git a/src/config/session.rs b/src/config/session.rs index 1aebd64..644255f 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -1,12 +1,14 @@ +use super::input::resolve_data_url; use super::role::Role; -use super::Model; +use super::{Input, Model}; -use crate::client::{Message, MessageRole}; +use crate::client::{Message, MessageContent, MessageRole}; use crate::render::MarkdownRender; use anyhow::{bail, Context, Result}; use serde::{Deserialize, Serialize}; use serde_json::json; +use std::collections::HashMap; use std::fs::{self, read_to_string}; use std::path::Path; @@ -18,6 +20,7 @@ pub struct Session { model_id: String, temperature: Option, messages: Vec, + data_urls: HashMap, #[serde(skip)] pub name: String, #[serde(skip)] @@ -37,6 +40,7 @@ impl Session { model_id: model.id(), temperature, messages: vec![], + data_urls: Default::default(), name: name.to_string(), path: None, dirty: false, @@ -121,6 +125,7 @@ impl Session { if !self.is_empty() { lines.push("".into()); + let resolve_url_fn = |url: &str| resolve_data_url(&self.data_urls, url.to_string()); for message in &self.messages { match message.role { @@ -128,11 +133,17 @@ impl Session { continue; } MessageRole::Assistant => { - lines.push(render.render(&message.content)); + if let MessageContent::Text(text) = &message.content { + lines.push(render.render(text)); + } lines.push("".into()); } MessageRole::User => { - lines.push(format!("{}){}", self.name, message.content)); + lines.push(format!( + "{}){}", + self.name, + message.content.render_input(resolve_url_fn) + )); } } } @@ -218,7 +229,7 @@ impl Session { self.messages.is_empty() } - pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> { + pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> { let mut need_add_msg = true; if self.messages.is_empty() { if let Some(role) = self.role.as_ref() { @@ -229,35 +240,36 @@ impl Session { if need_add_msg { self.messages.push(Message { role: MessageRole::User, - content: input.to_string(), + content: input.to_message_content(), }); } + self.data_urls.extend(input.data_urls()); self.messages.push(Message { role: MessageRole::Assistant, - content: output.to_string(), + content: MessageContent::Text(output.to_string()), }); self.dirty = true; Ok(()) } - pub fn echo_messages(&self, content: &str) -> String { - let messages = self.build_emssages(content); + pub fn echo_messages(&self, input: &Input) -> String { + let messages = self.build_emssages(input); serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into()) } - pub fn build_emssages(&self, content: &str) -> Vec { + pub fn build_emssages(&self, input: &Input) -> Vec { let mut messages = self.messages.clone(); let mut need_add_msg = true; if messages.is_empty() { if let Some(role) = self.role.as_ref() { - messages = role.build_messages(content); + messages = role.build_messages(input); need_add_msg = false; } }; if need_add_msg { messages.push(Message { role: MessageRole::User, - content: content.into(), + content: input.to_message_content(), }); } messages diff --git a/src/main.rs b/src/main.rs index bb6c8a0..4d199aa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,6 +15,7 @@ use crate::config::{Config, GlobalConfig}; use anyhow::Result; use clap::Parser; use client::{init_client, list_models}; +use config::Input; use is_terminal::IsTerminal; use parking_lot::RwLock; use render::{render_error, render_stream, MarkdownRender}; @@ -75,18 +76,22 @@ fn main() -> Result<()> { return Ok(()); } config.write().onstart()?; - let no_stream = cli.no_stream; - if let Err(err) = start(&config, text, no_stream) { + if let Err(err) = start(&config, text, cli.file, cli.no_stream) { let highlight = stderr().is_terminal() && config.read().highlight; render_error(err, highlight) } Ok(()) } -fn start(config: &GlobalConfig, text: Option, no_stream: bool) -> Result<()> { +fn start( + config: &GlobalConfig, + text: Option, + include: Option>, + no_stream: bool, +) -> Result<()> { if stdin().is_terminal() { match text { - Some(text) => start_directive(config, &text, no_stream), + Some(text) => start_directive(config, &text, include, no_stream), None => start_interactive(config), } } else { @@ -95,18 +100,24 @@ fn start(config: &GlobalConfig, text: Option, no_stream: bool) -> Result if let Some(text) = text { input = format!("{text}\n{input}"); } - start_directive(config, &input, no_stream) + start_directive(config, &input, include, no_stream) } } -fn start_directive(config: &GlobalConfig, input: &str, no_stream: bool) -> Result<()> { +fn start_directive( + config: &GlobalConfig, + text: &str, + include: Option>, + no_stream: bool, +) -> Result<()> { if let Some(session) = &config.read().session { session.guard_save()?; } + let input = Input::new(text, include.unwrap_or_default())?; let client = init_client(config)?; - config.read().maybe_print_send_tokens(input); + config.read().maybe_print_send_tokens(&input); let output = if no_stream { - let output = client.send_message(input)?; + let output = client.send_message(input.clone())?; if stdout().is_terminal() { let render_options = config.read().get_render_options()?; let mut markdown_render = MarkdownRender::init(render_options)?; @@ -117,7 +128,7 @@ fn start_directive(config: &GlobalConfig, input: &str, no_stream: bool) -> Resul output } else { let abort = create_abort_signal(); - render_stream(input, client.as_ref(), config, abort)? + render_stream(&input, client.as_ref(), config, abort)? }; config.write().save_message(input, &output) } diff --git a/src/render/mod.rs b/src/render/mod.rs index 1e20ccc..dc4c081 100644 --- a/src/render/mod.rs +++ b/src/render/mod.rs @@ -5,7 +5,7 @@ pub use self::markdown::{MarkdownRender, RenderOptions}; use self::stream::{markdown_stream, raw_stream}; use crate::client::Client; -use crate::config::GlobalConfig; +use crate::config::{GlobalConfig, Input}; use crate::utils::AbortSignal; use anyhow::{Context, Result}; @@ -17,7 +17,7 @@ use std::io::stdout; use std::thread::spawn; pub fn render_stream( - input: &str, + input: &Input, client: &dyn Client, config: &GlobalConfig, abort: AbortSignal, diff --git a/src/render/stream.rs b/src/render/stream.rs index bc56db8..2c05119 100644 --- a/src/render/stream.rs +++ b/src/render/stream.rs @@ -167,7 +167,7 @@ struct Spinner { } impl Spinner { - const DATA: [&str; 10] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]; + const DATA: [&'static str; 10] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]; fn new(message: &str) -> Self { Spinner { diff --git a/src/repl/mod.rs b/src/repl/mod.rs index df4d1c2..b8580c5 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -7,7 +7,7 @@ use self::highlighter::ReplHighlighter; use self::prompt::ReplPrompt; use crate::client::init_client; -use crate::config::GlobalConfig; +use crate::config::{GlobalConfig, Input}; use crate::render::{render_error, render_stream}; use crate::utils::{create_abort_signal, set_text, AbortSignal}; @@ -20,7 +20,6 @@ use reedline::{ ColumnarMenu, EditMode, Emacs, KeyCode, KeyModifiers, Keybindings, Reedline, ReedlineEvent, ReedlineMenu, ValidationResult, Validator, Vi, }; -use std::io::Read; const MENU_NAME: &str = "completion_menu"; @@ -34,9 +33,9 @@ const REPL_COMMANDS: [(&str, &str); 13] = [ (".session", "Start a context-aware chat session"), (".info session", "Show session info"), (".exit session", "End the current session"), + (".file", "Attach files to the message and then submit it"), (".set", "Modify the configuration parameters"), (".copy", "Copy the last reply to the clipboard"), - (".read", "Read files into the message and submit"), (".exit", "Exit the REPL"), ]; @@ -159,7 +158,7 @@ impl Repl { let old_role = self.config.read().role.as_ref().map(|v| v.name.to_string()); self.config.write().set_role(name)?; - self.ask(text)?; + self.ask(text, vec![])?; match old_role { Some(old_role) => self.config.write().set_role(&old_role)?, None => self.config.write().clear_role()?, @@ -184,29 +183,19 @@ impl Repl { self.copy(config.last_reply()) .with_context(|| "Failed to copy the last output")?; } - ".read" => match args { + ".read" => { + println!(r#"Deprecated. Use '.read' instead."#); + } + ".file" => match args { Some(args) => { let (files, text) = match args.split_once(" -- ") { Some((files, text)) => (files.trim(), text.trim()), None => (args, ""), }; - let files = shell_words::split(files).with_context(|| "Invalid files")?; - let mut texts = vec![]; - if !text.is_empty() { - texts.push(text.to_string()); - } - for file_path in files.into_iter() { - let mut text = String::new(); - let mut file = std::fs::File::open(&file_path) - .with_context(|| format!("Unable to open file '{file_path}'"))?; - file.read_to_string(&mut text) - .with_context(|| format!("Unable to read file '{file_path}'"))?; - texts.push(text); - } - let content = texts.join("\n"); - self.ask(&content)?; + let files = shell_words::split(files).with_context(|| "Invalid args")?; + self.ask(text, files)?; } - None => println!("Usage: .read ...[ -- ...]"), + None => println!("Usage: .file ...[ -- ...]"), }, ".exit" => match args { Some("role") => { @@ -233,7 +222,7 @@ impl Repl { _ => unknown_command()?, }, None => { - self.ask(line)?; + self.ask(line, vec![])?; } } @@ -242,13 +231,18 @@ impl Repl { Ok(false) } - fn ask(&self, input: &str) -> Result<()> { - if input.is_empty() { + fn ask(&self, text: &str, files: Vec) -> Result<()> { + if text.is_empty() && files.is_empty() { return Ok(()); } - self.config.read().maybe_print_send_tokens(input); + let input = if files.is_empty() { + Input::from_str(text) + } else { + Input::new(text, files)? + }; + self.config.read().maybe_print_send_tokens(&input); let client = init_client(&self.config)?; - let output = render_stream(input, client.as_ref(), &self.config, self.abort.clone())?; + let output = render_stream(&input, client.as_ref(), &self.config, self.abort.clone())?; self.config.write().save_message(input, &output)?; if self.config.read().auto_copy { let _ = self.copy(&output); diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 0039298..8a80b23 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -8,6 +8,8 @@ pub use self::clipboard::set_text; pub use self::prompt_input::*; pub use self::tiktoken::cl100k_base_singleton; +use sha2::{Digest, Sha256}; + pub fn now() -> String { let now = chrono::Local::now(); now.to_rfc3339_opts(chrono::SecondsFormat::Secs, false) @@ -76,6 +78,13 @@ pub fn init_tokio_runtime() -> anyhow::Result { .with_context(|| "Failed to init tokio") } +pub fn sha256sum(input: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(input); + let result = hasher.finalize(); + format!("{:x}", result) +} + #[cfg(test)] mod tests { use super::*;