feat: support vision (#249)

* feat: support vision

* clippy

* implement vision

* resolve data url to local file

* add model openai:gpt-4-vision-preview

* use newline to concate embeded text files

* set max_tokens for gpt-4-vision-preview
pull/250/head
sigoden 7 months ago committed by GitHub
parent 5bfe95d311
commit 35c75506e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

100
Cargo.lock generated

@ -49,6 +49,7 @@ dependencies = [
"is-terminal",
"lazy_static",
"log",
"mime_guess",
"nu-ansi-term",
"parking_lot",
"reedline",
@ -58,6 +59,7 @@ dependencies = [
"serde",
"serde_json",
"serde_yaml",
"sha2",
"shell-words",
"simplelog",
"syntect",
@ -246,6 +248,15 @@ version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a"
[[package]]
name = "block-buffer"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
dependencies = [
"generic-array",
]
[[package]]
name = "bstr"
version = "1.7.0"
@ -377,6 +388,15 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
[[package]]
name = "cpufeatures"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce420fe07aecd3e67c5f910618fe65e94158f6dcc0adf44e00d69ce2bdfe0fd0"
dependencies = [
"libc",
]
[[package]]
name = "crc32fast"
version = "1.3.2"
@ -495,6 +515,16 @@ dependencies = [
"winapi",
]
[[package]]
name = "crypto-common"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
"generic-array",
"typenum",
]
[[package]]
name = "deranged"
version = "0.3.9"
@ -504,6 +534,16 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "digest"
version = "0.10.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
]
[[package]]
name = "dirs"
version = "5.0.1"
@ -696,6 +736,16 @@ dependencies = [
"slab",
]
[[package]]
name = "generic-array"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
dependencies = [
"typenum",
"version_check",
]
[[package]]
name = "gethostname"
version = "0.2.3"
@ -883,9 +933,9 @@ dependencies = [
[[package]]
name = "indexmap"
version = "2.0.2"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8adf3ddd720272c6ea8bf59463c04e0f93d0bbf7c5439b691bca2987e0270897"
checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f"
dependencies = [
"equivalent",
"hashbrown 0.14.2",
@ -1030,6 +1080,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@ -1623,7 +1683,7 @@ version = "1.0.107"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65"
dependencies = [
"indexmap 2.0.2",
"indexmap 2.1.0",
"itoa",
"ryu",
"serde",
@ -1647,13 +1707,24 @@ version = "0.9.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cc7a1570e38322cfe4154732e5110f887ea57e22b76f4bfd32b5bdd3368666c"
dependencies = [
"indexmap 2.0.2",
"indexmap 2.1.0",
"itoa",
"ryu",
"serde",
"unsafe-libyaml",
]
[[package]]
name = "sha2"
version = "0.10.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "shell-words"
version = "1.1.0"
@ -2032,6 +2103,21 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]]
name = "unicase"
version = "2.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89"
dependencies = [
"version_check",
]
[[package]]
name = "unicode-bidi"
version = "0.3.13"
@ -2100,6 +2186,12 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
[[package]]
name = "version_check"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "vte"
version = "0.10.1"

@ -41,6 +41,8 @@ reqwest-eventsource = "0.5.0"
simplelog = "0.12.1"
log = "0.4.20"
shell-words = "1.1.0"
mime_guess = "2.0.4"
sha2 = "0.10.8"
[dependencies.reqwest]
version = "0.11.14"

@ -39,6 +39,7 @@ Download it from [GitHub Releases](https://github.com/sigoden/aichat/releases),
- Support chat and command modes
- Use [Roles](#roles)
- Powerful [Chat REPL](#chat-repl)
- Support vision
- Context-aware conversation/session
- Syntax highlighting markdown and 200 other languages
- Stream output with hand-typing effect
@ -147,9 +148,9 @@ The Chat REPL supports:
.session Start a context-aware chat session
.info session Show session info
.exit session End the current session
.file Attach files to the message and then submit it
.set Modify the configuration parameters
.copy Copy the last reply to the clipboard
.read Read files into the message and submit
.exit Exit the REPL
Type ::: to begin multi-line editing, type ::: to end it.
@ -255,6 +256,17 @@ The prompt on the right side is about the current usage of tokens and the propor
compared to the maximum number of tokens allowed by the model.
### `.file` - attach files to the message
```
Usage: .file <file>... [-- text...]
.file message.txt
.file config.yaml -- convert to toml
.file a.jpg b.jpg -- Whats in these images?
.file https://ibb.co/a.png https://ibb.co/b.png -- what is the difference?
```
### `.set` - modify the configuration temporarily
```
@ -277,6 +289,7 @@ Options:
-m, --model <MODEL> Choose a LLM model
-r, --role <ROLE> Choose a role
-s, --session [<SESSION>] Create or reuse a session
-f, --file <FILE>... Attach files to the message to be sent
-H, --no-highlight Disable syntax highlighting
-S, --no-stream No stream output
-w, --wrap <WRAP> Specify the text-wrapping mode (no*, auto, <max-width>)
@ -306,6 +319,9 @@ cat config.json | aichat convert to yaml # Read stdin
cat config.json | aichat -r convert:yaml # Read stdin with a role
cat config.json | aichat -s i18n # Read stdin with a session
aichat --file a.png b.png -- diff images # Attach files
aichat --file screenshot.png -r ocr # Attach files with a role
aichat --list-models # List all available models
aichat --list-roles # List all available roles
aichat --list-sessions # List all available models

@ -12,6 +12,9 @@ pub struct Cli {
/// Create or reuse a session
#[clap(short = 's', long)]
pub session: Option<Option<String>>,
/// Attach files to the message to be sent.
#[clap(short = 'f', long, num_args = 1.., value_name = "FILE")]
pub file: Option<Vec<String>>,
/// Disable syntax highlighting
#[clap(short = 'H', long)]
pub no_highlight: bool,

@ -1,7 +1,7 @@
use super::{openai::OpenAIConfig, ClientConfig, Message};
use crate::{
config::GlobalConfig,
config::{GlobalConfig, Input},
render::ReplyHandler,
utils::{
init_tokio_runtime, prompt_input_integer, prompt_input_string, tokenize, AbortSignal,
@ -50,7 +50,7 @@ macro_rules! register_client {
}
impl $client {
pub const NAME: &str = $name;
pub const NAME: &'static str = $name;
pub fn init(global_config: &$crate::config::GlobalConfig) -> Option<Box<dyn Client>> {
let model = global_config.read().model.clone();
@ -186,22 +186,22 @@ pub trait Client {
Ok(client)
}
fn send_message(&self, content: &str) -> Result<String> {
fn send_message(&self, input: Input) -> Result<String> {
init_tokio_runtime()?.block_on(async {
let global_config = self.config().0;
if global_config.read().dry_run {
let content = global_config.read().echo_messages(content);
let content = global_config.read().echo_messages(&input);
return Ok(content);
}
let client = self.build_client()?;
let data = global_config.read().prepare_send_data(content, false)?;
let data = global_config.read().prepare_send_data(&input, false)?;
self.send_message_inner(&client, data)
.await
.with_context(|| "Failed to get answer")
})
}
fn send_message_streaming(&self, content: &str, handler: &mut ReplyHandler) -> Result<()> {
fn send_message_streaming(&self, input: &Input, handler: &mut ReplyHandler) -> Result<()> {
async fn watch_abort(abort: AbortSignal) {
loop {
if abort.aborted() {
@ -211,12 +211,13 @@ pub trait Client {
}
}
let abort = handler.get_abort();
init_tokio_runtime()?.block_on(async {
let input = input.clone();
init_tokio_runtime()?.block_on(async move {
tokio::select! {
ret = async {
let global_config = self.config().0;
if global_config.read().dry_run {
let content = global_config.read().echo_messages(content);
let content = global_config.read().echo_messages(&input);
let tokens = tokenize(&content);
for token in tokens {
tokio::time::sleep(Duration::from_millis(10)).await;
@ -225,7 +226,7 @@ pub trait Client {
return Ok(());
}
let client = self.build_client()?;
let data = global_config.read().prepare_send_data(content, true)?;
let data = global_config.read().prepare_send_data(&input, true)?;
self.send_message_streaming_inner(&client, handler, data).await
} => {
handler.done()?;

@ -1,4 +1,4 @@
use super::{ErnieClient, Client, ExtraConfig, PromptType, SendData, Model};
use super::{ErnieClient, Client, ExtraConfig, PromptType, SendData, Model, MessageContent};
use crate::{
config::GlobalConfig,
@ -198,8 +198,10 @@ fn build_body(data: SendData, _model: String) -> Value {
if messages[0].role.is_system() {
let system_message = messages.remove(0);
if let Some(message) = messages.get_mut(0) {
message.content = format!("{}\n\n{}", system_message.content, message.content)
if let (Some(message), MessageContent::Text(system_text)) = (messages.get_mut(0), system_message.content) {
if let MessageContent::Text(text) = message.content.clone() {
message.content = MessageContent::Text(format!("{}\n\n{}", system_text, text))
}
}
}

@ -1,16 +1,18 @@
use crate::config::Input;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Message {
pub role: MessageRole,
pub content: String,
pub content: MessageContent,
}
impl Message {
pub fn new(content: &str) -> Self {
pub fn new(input: &Input) -> Self {
Self {
role: MessageRole::User,
content: content.to_string(),
content: input.to_message_content(),
}
}
}
@ -38,6 +40,65 @@ impl MessageRole {
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum MessageContent {
Text(String),
Array(Vec<MessageContentPart>),
}
impl MessageContent {
pub fn render_input(&self, resolve_url_fn: impl Fn(&str) -> String) -> String {
match self {
MessageContent::Text(text) => text.to_string(),
MessageContent::Array(list) => {
let (mut concated_text, mut files) = (String::new(), vec![]);
for item in list {
match item {
MessageContentPart::Text { text } => {
concated_text = format!("{concated_text} {text}")
}
MessageContentPart::ImageUrl { image_url } => {
files.push(resolve_url_fn(&image_url.url))
}
}
}
if !concated_text.is_empty() {
concated_text = format!(" -- {concated_text}")
}
format!(".file {}{}", files.join(" "), concated_text)
}
}
}
pub fn merge_prompt(&mut self, replace_fn: impl Fn(&str) -> String) {
match self {
MessageContent::Text(text) => *text = replace_fn(text),
MessageContent::Array(list) => {
if list.is_empty() {
list.push(MessageContentPart::Text {
text: replace_fn(""),
})
} else if let Some(MessageContentPart::Text { text }) = list.get_mut(0) {
*text = replace_fn(text)
}
}
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessageContentPart {
Text { text: String },
ImageUrl { image_url: ImageUrl },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageUrl {
pub url: String,
}
#[cfg(test)]
mod tests {
use super::*;
@ -45,7 +106,7 @@ mod tests {
#[test]
fn test_serde() {
assert_eq!(
serde_json::to_string(&Message::new("Hello World")).unwrap(),
serde_json::to_string(&Message::new(&Input::from_str("Hello World"))).unwrap(),
"{\"role\":\"user\",\"content\":\"Hello World\"}"
);
}

@ -1,4 +1,4 @@
use super::message::Message;
use super::message::{Message, MessageContent};
use crate::utils::count_tokens;
@ -79,7 +79,15 @@ impl Model {
}
pub fn messages_tokens(&self, messages: &[Message]) -> usize {
messages.iter().map(|v| count_tokens(&v.content)).sum()
messages
.iter()
.map(|v| {
match &v.content {
MessageContent::Text(text) => count_tokens(text),
MessageContent::Array(_) => 0, // TODO
}
})
.sum()
}
pub fn total_tokens(&self, messages: &[Message]) -> usize {

@ -19,13 +19,14 @@ use std::env;
const API_BASE: &str = "https://api.openai.com/v1";
const MODELS: [(&str, usize); 6] = [
const MODELS: [(&str, usize); 7] = [
("gpt-3.5-turbo", 4096),
("gpt-3.5-turbo-16k", 16385),
("gpt-3.5-turbo-1106", 16385),
("gpt-4-1106-preview", 128000),
("gpt-4-vision-preview", 128000),
("gpt-4", 8192),
("gpt-4-32k", 32768),
("gpt-4-1106-preview", 128000),
];
pub const OPENAI_TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
@ -145,6 +146,12 @@ pub fn openai_build_body(data: SendData, model: String) -> Value {
"model": model,
"messages": messages,
});
// The default max_tokens of gpt-4-vision-preview is only 16, we need to make it larger
if model == "gpt-4-vision-preview" {
body["max_tokens"] = json!(4096);
}
if let Some(v) = temperature {
body["temperature"] = v.into();
}

@ -1,4 +1,4 @@
use super::{PaLMClient, Client, ExtraConfig, Model, PromptType, SendData, TokensCountFactors, send_message_as_streaming};
use super::{PaLMClient, Client, ExtraConfig, Model, PromptType, SendData, TokensCountFactors, send_message_as_streaming, MessageContent};
use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind};
@ -115,8 +115,10 @@ fn build_body(data: SendData, _model: String) -> Value {
if messages[0].role.is_system() {
let system_message = messages.remove(0);
if let Some(message) = messages.get_mut(0) {
message.content = format!("{}\n\n{}", system_message.content, message.content)
if let (Some(message), MessageContent::Text(system_text)) = (messages.get_mut(0), system_message.content) {
if let MessageContent::Text(text) = message.content.clone() {
message.content = MessageContent::Text(format!("{}\n\n{}", system_text, text))
}
}
}

@ -0,0 +1,162 @@
use crate::client::{ImageUrl, MessageContent, MessageContentPart};
use crate::utils::sha256sum;
use anyhow::{bail, Context, Result};
use base64::{self, engine::general_purpose::STANDARD, Engine};
use mime_guess::from_path;
use std::{
collections::HashMap,
fs::{self, File},
io::Read,
path::{Path, PathBuf},
};
const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"];
#[derive(Debug, Clone)]
pub struct Input {
text: String,
medias: Vec<String>,
data_urls: HashMap<String, String>,
}
impl Input {
pub fn from_str(text: &str) -> Self {
Self {
text: text.to_string(),
medias: Default::default(),
data_urls: Default::default(),
}
}
pub fn new(text: &str, files: Vec<String>) -> Result<Self> {
let mut texts = vec![text.to_string()];
let mut medias = vec![];
let mut data_urls = HashMap::new();
for file_item in files.into_iter() {
match resolve_path(&file_item) {
Some(file_path) => {
let file_path = fs::canonicalize(file_path)
.with_context(|| format!("Unable to use file '{file_item}"))?;
if is_image_ext(&file_path) {
let data_url = read_media_to_data_url(&file_path)?;
data_urls.insert(sha256sum(&data_url), file_path.display().to_string());
medias.push(data_url)
} else {
let mut text = String::new();
let mut file = File::open(&file_path)
.with_context(|| format!("Unable to open file '{file_item}'"))?;
file.read_to_string(&mut text)
.with_context(|| format!("Unable to read file '{file_item}'"))?;
texts.push(text);
}
}
None => {
if is_image_ext(Path::new(&file_item)) {
medias.push(file_item)
} else {
bail!("Unable to use file '{file_item}");
}
}
}
}
Ok(Self {
text: texts.join("\n"),
medias,
data_urls,
})
}
pub fn data_urls(&self) -> HashMap<String, String> {
self.data_urls.clone()
}
pub fn render(&self) -> String {
if self.medias.is_empty() {
return self.text.clone();
}
let text = if self.text.is_empty() {
self.text.to_string()
} else {
format!(" -- {}", self.text)
};
let files: Vec<String> = self
.medias
.iter()
.cloned()
.map(|url| resolve_data_url(&self.data_urls, url))
.collect();
format!(".file {}{}", files.join(" "), text)
}
pub fn to_message_content(&self) -> MessageContent {
if self.medias.is_empty() {
MessageContent::Text(self.text.clone())
} else {
let mut list: Vec<MessageContentPart> = self
.medias
.iter()
.cloned()
.map(|url| MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
})
.collect();
if !self.text.is_empty() {
list.insert(
0,
MessageContentPart::Text {
text: self.text.clone(),
},
);
}
MessageContent::Array(list)
}
}
}
pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
if data_url.starts_with("data:") {
let hash = sha256sum(&data_url);
if let Some(path) = data_urls.get(&hash) {
return path.to_string();
}
data_url
} else {
data_url
}
}
fn resolve_path(file: &str) -> Option<PathBuf> {
if ["https://", "http://", "data:"]
.iter()
.any(|v| file.starts_with(v))
{
return None;
}
let path = if let (Some(file), Some(home)) = (file.strip_prefix('~'), dirs::home_dir()) {
home.join(file)
} else {
std::env::current_dir().ok()?.join(file)
};
Some(path)
}
fn is_image_ext(path: &Path) -> bool {
path.extension()
.map(|v| IMAGE_EXTS.iter().any(|ext| *ext == v.to_string_lossy()))
.unwrap_or_default()
}
fn read_media_to_data_url<P: AsRef<Path>>(image_path: P) -> Result<String> {
let mime_type = from_path(&image_path).first_or_octet_stream().to_string();
let mut file = File::open(image_path)?;
let mut buffer = Vec::new();
file.read_to_end(&mut buffer)?;
let encoded_image = STANDARD.encode(buffer);
let data_url = format!("data:{};base64,{}", mime_type, encoded_image);
Ok(data_url)
}

@ -1,6 +1,8 @@
mod input;
mod role;
mod session;
pub use self::input::Input;
use self::role::Role;
use self::session::{Session, TEMP_SESSION_NAME};
@ -78,7 +80,7 @@ pub struct Config {
#[serde(skip)]
pub model: Model,
#[serde(skip)]
pub last_message: Option<(String, String)>,
pub last_message: Option<(Input, String)>,
#[serde(skip)]
pub temperature: Option<f64>,
}
@ -200,15 +202,15 @@ impl Config {
Ok(path)
}
pub fn save_message(&mut self, input: &str, output: &str) -> Result<()> {
self.last_message = Some((input.to_string(), output.to_string()));
pub fn save_message(&mut self, input: Input, output: &str) -> Result<()> {
self.last_message = Some((input.clone(), output.to_string()));
if self.dry_run {
return Ok(());
}
if let Some(session) = self.session.as_mut() {
session.add_message(input, output)?;
session.add_message(&input, output)?;
return Ok(());
}
@ -220,13 +222,14 @@ impl Config {
return Ok(());
}
let timestamp = now();
let input_markdown = input.render();
let output = match self.role.as_ref() {
None => {
format!("# CHAT:[{timestamp}]\n{input}\n--------\n{output}\n--------\n\n",)
format!("# CHAT:[{timestamp}]\n{input_markdown}\n--------\n{output}\n--------\n\n",)
}
Some(v) => {
format!(
"# CHAT:[{timestamp}] ({})\n{input}\n--------\n{output}\n--------\n\n",
"# CHAT:[{timestamp}] ({})\n{input_markdown}\n--------\n{output}\n--------\n\n",
v.name,
)
}
@ -292,23 +295,23 @@ impl Config {
Ok(())
}
pub fn echo_messages(&self, content: &str) -> String {
pub fn echo_messages(&self, input: &Input) -> String {
if let Some(session) = self.session.as_ref() {
session.echo_messages(content)
session.echo_messages(input)
} else if let Some(role) = self.role.as_ref() {
role.echo_messages(content)
role.echo_messages(input)
} else {
content.to_string()
input.render()
}
}
pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> {
pub fn build_messages(&self, input: &Input) -> Result<Vec<Message>> {
let messages = if let Some(session) = self.session.as_ref() {
session.build_emssages(content)
session.build_emssages(input)
} else if let Some(role) = self.role.as_ref() {
role.build_messages(content)
role.build_messages(input)
} else {
let message = Message::new(content);
let message = Message::new(input);
vec![message]
};
Ok(messages)
@ -586,7 +589,7 @@ impl Config {
Ok(dir) => dir,
Err(_) => return vec![],
};
match read_dir(&sessions_dir) {
match read_dir(sessions_dir) {
Ok(rd) => {
let mut names = vec![];
for entry in rd.flatten() {
@ -643,8 +646,8 @@ impl Config {
}
}
pub fn prepare_send_data(&self, content: &str, stream: bool) -> Result<SendData> {
let messages = self.build_messages(content)?;
pub fn prepare_send_data(&self, input: &Input, stream: bool) -> Result<SendData> {
let messages = self.build_messages(input)?;
self.model.max_tokens_limit(&messages)?;
Ok(SendData {
messages,
@ -653,7 +656,7 @@ impl Config {
})
}
pub fn maybe_print_send_tokens(&self, input: &str) {
pub fn maybe_print_send_tokens(&self, input: &Input) {
if self.dry_run {
if let Ok(messages) = self.build_messages(input) {
let tokens = self.model.total_tokens(&messages);

@ -1,8 +1,10 @@
use crate::client::{Message, MessageRole};
use crate::client::{Message, MessageContent, MessageRole};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use super::Input;
const INPUT_PLACEHOLDER: &str = "__INPUT__";
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -41,17 +43,20 @@ impl Role {
}
}
pub fn echo_messages(&self, content: &str) -> String {
pub fn echo_messages(&self, input: &Input) -> String {
let input_markdown = input.render();
if self.embedded() {
merge_prompt_content(&self.prompt, content)
self.prompt.replace(INPUT_PLACEHOLDER, &input_markdown)
} else {
format!("{}\n\n{content}", self.prompt)
format!("{}\n\n{}", self.prompt, input.render())
}
}
pub fn build_messages(&self, content: &str) -> Vec<Message> {
pub fn build_messages(&self, input: &Input) -> Vec<Message> {
let mut content = input.to_message_content();
if self.embedded() {
let content = merge_prompt_content(&self.prompt, content);
content.merge_prompt(|v: &str| self.prompt.replace(INPUT_PLACEHOLDER, v));
vec![Message {
role: MessageRole::User,
content,
@ -60,21 +65,17 @@ impl Role {
vec![
Message {
role: MessageRole::System,
content: self.prompt.clone(),
content: MessageContent::Text(self.prompt.clone()),
},
Message {
role: MessageRole::User,
content: content.to_string(),
content,
},
]
}
}
}
fn merge_prompt_content(prompt: &str, content: &str) -> String {
prompt.replace(INPUT_PLACEHOLDER, content)
}
fn complete_prompt_args(prompt: &str, name: &str) -> String {
let mut prompt = prompt.trim().to_string();
for (i, arg) in name.split(':').skip(1).enumerate() {

@ -1,12 +1,14 @@
use super::input::resolve_data_url;
use super::role::Role;
use super::Model;
use super::{Input, Model};
use crate::client::{Message, MessageRole};
use crate::client::{Message, MessageContent, MessageRole};
use crate::render::MarkdownRender;
use anyhow::{bail, Context, Result};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::fs::{self, read_to_string};
use std::path::Path;
@ -18,6 +20,7 @@ pub struct Session {
model_id: String,
temperature: Option<f64>,
messages: Vec<Message>,
data_urls: HashMap<String, String>,
#[serde(skip)]
pub name: String,
#[serde(skip)]
@ -37,6 +40,7 @@ impl Session {
model_id: model.id(),
temperature,
messages: vec![],
data_urls: Default::default(),
name: name.to_string(),
path: None,
dirty: false,
@ -121,6 +125,7 @@ impl Session {
if !self.is_empty() {
lines.push("".into());
let resolve_url_fn = |url: &str| resolve_data_url(&self.data_urls, url.to_string());
for message in &self.messages {
match message.role {
@ -128,11 +133,17 @@ impl Session {
continue;
}
MessageRole::Assistant => {
lines.push(render.render(&message.content));
if let MessageContent::Text(text) = &message.content {
lines.push(render.render(text));
}
lines.push("".into());
}
MessageRole::User => {
lines.push(format!("{}{}", self.name, message.content));
lines.push(format!(
"{}{}",
self.name,
message.content.render_input(resolve_url_fn)
));
}
}
}
@ -218,7 +229,7 @@ impl Session {
self.messages.is_empty()
}
pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> {
pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> {
let mut need_add_msg = true;
if self.messages.is_empty() {
if let Some(role) = self.role.as_ref() {
@ -229,35 +240,36 @@ impl Session {
if need_add_msg {
self.messages.push(Message {
role: MessageRole::User,
content: input.to_string(),
content: input.to_message_content(),
});
}
self.data_urls.extend(input.data_urls());
self.messages.push(Message {
role: MessageRole::Assistant,
content: output.to_string(),
content: MessageContent::Text(output.to_string()),
});
self.dirty = true;
Ok(())
}
pub fn echo_messages(&self, content: &str) -> String {
let messages = self.build_emssages(content);
pub fn echo_messages(&self, input: &Input) -> String {
let messages = self.build_emssages(input);
serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into())
}
pub fn build_emssages(&self, content: &str) -> Vec<Message> {
pub fn build_emssages(&self, input: &Input) -> Vec<Message> {
let mut messages = self.messages.clone();
let mut need_add_msg = true;
if messages.is_empty() {
if let Some(role) = self.role.as_ref() {
messages = role.build_messages(content);
messages = role.build_messages(input);
need_add_msg = false;
}
};
if need_add_msg {
messages.push(Message {
role: MessageRole::User,
content: content.into(),
content: input.to_message_content(),
});
}
messages

@ -15,6 +15,7 @@ use crate::config::{Config, GlobalConfig};
use anyhow::Result;
use clap::Parser;
use client::{init_client, list_models};
use config::Input;
use is_terminal::IsTerminal;
use parking_lot::RwLock;
use render::{render_error, render_stream, MarkdownRender};
@ -75,18 +76,22 @@ fn main() -> Result<()> {
return Ok(());
}
config.write().onstart()?;
let no_stream = cli.no_stream;
if let Err(err) = start(&config, text, no_stream) {
if let Err(err) = start(&config, text, cli.file, cli.no_stream) {
let highlight = stderr().is_terminal() && config.read().highlight;
render_error(err, highlight)
}
Ok(())
}
fn start(config: &GlobalConfig, text: Option<String>, no_stream: bool) -> Result<()> {
fn start(
config: &GlobalConfig,
text: Option<String>,
include: Option<Vec<String>>,
no_stream: bool,
) -> Result<()> {
if stdin().is_terminal() {
match text {
Some(text) => start_directive(config, &text, no_stream),
Some(text) => start_directive(config, &text, include, no_stream),
None => start_interactive(config),
}
} else {
@ -95,18 +100,24 @@ fn start(config: &GlobalConfig, text: Option<String>, no_stream: bool) -> Result
if let Some(text) = text {
input = format!("{text}\n{input}");
}
start_directive(config, &input, no_stream)
start_directive(config, &input, include, no_stream)
}
}
fn start_directive(config: &GlobalConfig, input: &str, no_stream: bool) -> Result<()> {
fn start_directive(
config: &GlobalConfig,
text: &str,
include: Option<Vec<String>>,
no_stream: bool,
) -> Result<()> {
if let Some(session) = &config.read().session {
session.guard_save()?;
}
let input = Input::new(text, include.unwrap_or_default())?;
let client = init_client(config)?;
config.read().maybe_print_send_tokens(input);
config.read().maybe_print_send_tokens(&input);
let output = if no_stream {
let output = client.send_message(input)?;
let output = client.send_message(input.clone())?;
if stdout().is_terminal() {
let render_options = config.read().get_render_options()?;
let mut markdown_render = MarkdownRender::init(render_options)?;
@ -117,7 +128,7 @@ fn start_directive(config: &GlobalConfig, input: &str, no_stream: bool) -> Resul
output
} else {
let abort = create_abort_signal();
render_stream(input, client.as_ref(), config, abort)?
render_stream(&input, client.as_ref(), config, abort)?
};
config.write().save_message(input, &output)
}

@ -5,7 +5,7 @@ pub use self::markdown::{MarkdownRender, RenderOptions};
use self::stream::{markdown_stream, raw_stream};
use crate::client::Client;
use crate::config::GlobalConfig;
use crate::config::{GlobalConfig, Input};
use crate::utils::AbortSignal;
use anyhow::{Context, Result};
@ -17,7 +17,7 @@ use std::io::stdout;
use std::thread::spawn;
pub fn render_stream(
input: &str,
input: &Input,
client: &dyn Client,
config: &GlobalConfig,
abort: AbortSignal,

@ -167,7 +167,7 @@ struct Spinner {
}
impl Spinner {
const DATA: [&str; 10] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"];
const DATA: [&'static str; 10] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"];
fn new(message: &str) -> Self {
Spinner {

@ -7,7 +7,7 @@ use self::highlighter::ReplHighlighter;
use self::prompt::ReplPrompt;
use crate::client::init_client;
use crate::config::GlobalConfig;
use crate::config::{GlobalConfig, Input};
use crate::render::{render_error, render_stream};
use crate::utils::{create_abort_signal, set_text, AbortSignal};
@ -20,7 +20,6 @@ use reedline::{
ColumnarMenu, EditMode, Emacs, KeyCode, KeyModifiers, Keybindings, Reedline, ReedlineEvent,
ReedlineMenu, ValidationResult, Validator, Vi,
};
use std::io::Read;
const MENU_NAME: &str = "completion_menu";
@ -34,9 +33,9 @@ const REPL_COMMANDS: [(&str, &str); 13] = [
(".session", "Start a context-aware chat session"),
(".info session", "Show session info"),
(".exit session", "End the current session"),
(".file", "Attach files to the message and then submit it"),
(".set", "Modify the configuration parameters"),
(".copy", "Copy the last reply to the clipboard"),
(".read", "Read files into the message and submit"),
(".exit", "Exit the REPL"),
];
@ -159,7 +158,7 @@ impl Repl {
let old_role =
self.config.read().role.as_ref().map(|v| v.name.to_string());
self.config.write().set_role(name)?;
self.ask(text)?;
self.ask(text, vec![])?;
match old_role {
Some(old_role) => self.config.write().set_role(&old_role)?,
None => self.config.write().clear_role()?,
@ -184,29 +183,19 @@ impl Repl {
self.copy(config.last_reply())
.with_context(|| "Failed to copy the last output")?;
}
".read" => match args {
".read" => {
println!(r#"Deprecated. Use '.read' instead."#);
}
".file" => match args {
Some(args) => {
let (files, text) = match args.split_once(" -- ") {
Some((files, text)) => (files.trim(), text.trim()),
None => (args, ""),
};
let files = shell_words::split(files).with_context(|| "Invalid files")?;
let mut texts = vec![];
if !text.is_empty() {
texts.push(text.to_string());
}
for file_path in files.into_iter() {
let mut text = String::new();
let mut file = std::fs::File::open(&file_path)
.with_context(|| format!("Unable to open file '{file_path}'"))?;
file.read_to_string(&mut text)
.with_context(|| format!("Unable to read file '{file_path}'"))?;
texts.push(text);
}
let content = texts.join("\n");
self.ask(&content)?;
let files = shell_words::split(files).with_context(|| "Invalid args")?;
self.ask(text, files)?;
}
None => println!("Usage: .read <files>...[ -- <text>...]"),
None => println!("Usage: .file <files>...[ -- <text>...]"),
},
".exit" => match args {
Some("role") => {
@ -233,7 +222,7 @@ impl Repl {
_ => unknown_command()?,
},
None => {
self.ask(line)?;
self.ask(line, vec![])?;
}
}
@ -242,13 +231,18 @@ impl Repl {
Ok(false)
}
fn ask(&self, input: &str) -> Result<()> {
if input.is_empty() {
fn ask(&self, text: &str, files: Vec<String>) -> Result<()> {
if text.is_empty() && files.is_empty() {
return Ok(());
}
self.config.read().maybe_print_send_tokens(input);
let input = if files.is_empty() {
Input::from_str(text)
} else {
Input::new(text, files)?
};
self.config.read().maybe_print_send_tokens(&input);
let client = init_client(&self.config)?;
let output = render_stream(input, client.as_ref(), &self.config, self.abort.clone())?;
let output = render_stream(&input, client.as_ref(), &self.config, self.abort.clone())?;
self.config.write().save_message(input, &output)?;
if self.config.read().auto_copy {
let _ = self.copy(&output);

@ -8,6 +8,8 @@ pub use self::clipboard::set_text;
pub use self::prompt_input::*;
pub use self::tiktoken::cl100k_base_singleton;
use sha2::{Digest, Sha256};
pub fn now() -> String {
let now = chrono::Local::now();
now.to_rfc3339_opts(chrono::SecondsFormat::Secs, false)
@ -76,6 +78,13 @@ pub fn init_tokio_runtime() -> anyhow::Result<tokio::runtime::Runtime> {
.with_context(|| "Failed to init tokio")
}
pub fn sha256sum(input: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(input);
let result = hasher.finalize();
format!("{:x}", result)
}
#[cfg(test)]
mod tests {
use super::*;

Loading…
Cancel
Save