You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
aichat/src/config/session.rs

482 lines
15 KiB
Rust

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

use super::input::*;
use super::*;
use crate::client::{Message, MessageContent, MessageRole};
use crate::render::MarkdownRender;
use anyhow::{bail, Context, Result};
use inquire::{validator::Validation, Confirm, Text};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::fs::{self, create_dir_all, read_to_string};
use std::path::Path;
/// A chat session: the conversation history together with the model and
/// role settings it runs with.
///
/// Fields without `#[serde(skip)]` are written to (and read from) the session
/// YAML file; the skipped fields are runtime-only state that is never persisted.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct Session {
    /// Id of the chat model, stored in the file under the `model` key.
    #[serde(rename = "model")]
    model_id: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    functions_filter: Option<FunctionsFilter>,
    /// Whether to save on exit; `None` defers the decision to `Session::exit`.
    #[serde(skip_serializing_if = "Option::is_none")]
    save_session: Option<bool>,
    /// Per-session override of the compress threshold passed to `need_compress`.
    #[serde(skip_serializing_if = "Option::is_none")]
    compress_threshold: Option<usize>,
    /// Live conversation. `default` (like the other collections below) lets a
    /// session file without a `messages` key load as an empty session.
    #[serde(default)]
    messages: Vec<Message>,
    /// Lookup table handed to `resolve_data_url` when rendering messages.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    data_urls: HashMap<String, String>,
    /// Messages moved out of `messages` by `Session::compress`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    compressed_messages: Vec<Message>,
    /// Resolved model object for `model_id` (set on load / role change).
    #[serde(skip)]
    model: Model,
    #[serde(skip)]
    role_prompt: String,
    #[serde(skip)]
    role_name: String,
    /// Session name; also used as the file stem when saving from `exit`.
    #[serde(skip)]
    name: String,
    /// Path of the backing file, if loaded from or saved to disk.
    #[serde(skip)]
    path: Option<String>,
    /// True when there are unsaved changes.
    #[serde(skip)]
    dirty: bool,
    /// Runtime flag toggled via `set_compressing`.
    #[serde(skip)]
    compressing: bool,
}
impl Session {
    /// Creates a new, unsaved session called `name`, seeded from the role
    /// returned by `config.extract_role()`.
    ///
    /// The temporary session (`TEMP_SESSION_NAME`) never inherits
    /// `config.save_session`, so `exit` will ask the user instead. The new
    /// session starts clean (`dirty == false`) even though `set_role` marks it
    /// dirty.
    pub fn new(config: &Config, name: &str) -> Self {
        let save_session = if name == TEMP_SESSION_NAME {
            None
        } else {
            config.save_session
        };
        let role = config.extract_role();
        let mut session = Self {
            name: name.to_string(),
            save_session,
            ..Default::default()
        };
        session.set_role(role);
        session.dirty = false;
        session
    }

    /// Loads session `name` from the YAML file at `path`.
    ///
    /// # Errors
    /// Fails if the file cannot be read, is not a valid session document, or
    /// its model id is rejected by `Model::retrieve_chat`.
    ///
    /// When `config.bot` is set, the bot's instructions become the role prompt.
    pub fn load(config: &Config, name: &str, path: &Path) -> Result<Self> {
        let content = read_to_string(path)
            .with_context(|| format!("Failed to load session {} at {}", name, path.display()))?;
        let mut session: Self =
            serde_yaml::from_str(&content).with_context(|| format!("Invalid session {}", name))?;
        session.model = Model::retrieve_chat(config, &session.model_id)?;
        session.name = name.to_string();
        session.path = Some(path.display().to_string());
        if let Some(bot) = &config.bot {
            session
                .role_prompt
                .clone_from(&bot.definition().instructions);
        }
        Ok(session)
    }

    /// True for the reserved temporary session.
    pub fn is_temp(&self) -> bool {
        self.name == TEMP_SESSION_NAME
    }

    /// True when there are neither live nor compressed messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty() && self.compressed_messages.is_empty()
    }

    pub fn name(&self) -> &str {
        &self.name
    }

    /// True when the session has unsaved changes.
    pub fn dirty(&self) -> bool {
        self.dirty
    }

    pub fn compressing(&self) -> bool {
        self.compressing
    }

    pub fn save_session(&self) -> Option<bool> {
        self.save_session
    }

    /// Returns true when the live messages exceed the compress threshold.
    ///
    /// The session's own `compress_threshold` takes precedence over
    /// `current_compress_threshold`. Thresholds below 1000 never trigger
    /// compression.
    pub fn need_compress(&self, current_compress_threshold: usize) -> bool {
        let threshold = self
            .compress_threshold
            .unwrap_or(current_compress_threshold);
        threshold >= 1000 && self.tokens() > threshold
    }

    /// Token count of the live (uncompressed) messages, as computed by the
    /// session's model.
    pub fn tokens(&self) -> usize {
        self.model().total_tokens(&self.messages)
    }

    /// Number of live messages authored by the user.
    pub fn user_messages_len(&self) -> usize {
        self.messages.iter().filter(|v| v.role.is_user()).count()
    }

    /// Serializes the session's settings, token usage and live messages to YAML.
    pub fn export(&self) -> Result<String> {
        let mut data = json!({
            "path": self.path,
            "model": self.model().id(),
        });
        if let Some(temperature) = self.temperature() {
            data["temperature"] = temperature.into();
        }
        if let Some(top_p) = self.top_p() {
            data["top_p"] = top_p.into();
        }
        if let Some(functions_filter) = self.functions_filter() {
            data["functions_filter"] = functions_filter.into();
        }
        if let Some(save_session) = self.save_session() {
            data["save_session"] = save_session.into();
        }
        // Keep the exported settings in line with what `render` reports.
        if let Some(compress_threshold) = self.compress_threshold {
            data["compress_threshold"] = compress_threshold.into();
        }
        let (tokens, percent) = self.tokens_usage();
        data["total_tokens"] = tokens.into();
        if let Some(max_input_tokens) = self.model().max_input_tokens() {
            data["max_input_tokens"] = max_input_tokens.into();
        }
        if percent != 0.0 {
            data["total/max"] = format!("{}%", percent).into();
        }
        data["messages"] = json!(self.messages);
        let output = serde_yaml::to_string(&data)
            .with_context(|| format!("Unable to show info about session '{}'", &self.name))?;
        Ok(output)
    }

    /// Renders the session for display: an aligned settings table followed by
    /// the live messages (system/assistant text through `render`, user input
    /// prefixed with the session name).
    pub fn render(&self, render: &mut MarkdownRender) -> Result<String> {
        let mut items = vec![];
        if let Some(path) = &self.path {
            items.push(("path", path.to_string()));
        }
        items.push(("model", self.model().id()));
        if let Some(temperature) = self.temperature() {
            items.push(("temperature", temperature.to_string()));
        }
        if let Some(top_p) = self.top_p() {
            items.push(("top_p", top_p.to_string()));
        }
        if let Some(functions_filter) = self.functions_filter() {
            items.push(("functions_filter", functions_filter));
        }
        if let Some(save_session) = self.save_session() {
            items.push(("save_session", save_session.to_string()));
        }
        if let Some(compress_threshold) = self.compress_threshold {
            items.push(("compress_threshold", compress_threshold.to_string()));
        }
        if let Some(max_input_tokens) = self.model().max_input_tokens() {
            items.push(("max_input_tokens", max_input_tokens.to_string()));
        }
        let mut lines: Vec<String> = items
            .iter()
            .map(|(name, value)| format!("{name:<20}{value}"))
            .collect();
        if !self.is_empty() {
            lines.push("".into());
            let resolve_url_fn = |url: &str| resolve_data_url(&self.data_urls, url.to_string());
            for message in &self.messages {
                match message.role {
                    MessageRole::System => {
                        lines.push(render.render(&message.content.render_input(resolve_url_fn)));
                    }
                    MessageRole::Assistant => {
                        if let MessageContent::Text(text) = &message.content {
                            lines.push(render.render(text));
                        }
                        lines.push("".into());
                    }
                    MessageRole::User => {
                        // Separate the session-name prefix from the message text;
                        // without it the two run together ("tempHello").
                        lines.push(format!(
                            "{}) {}",
                            self.name,
                            message.content.render_input(resolve_url_fn)
                        ));
                    }
                }
            }
        }
        // Drop the trailing blank line left after a final assistant message.
        if lines.last() == Some(&String::new()) {
            lines.pop();
        }
        let output = lines.join("\n");
        Ok(output)
    }

    /// Returns `(tokens, percent)` where `percent` is the share of the model's
    /// `max_input_tokens` in use, rounded to two decimals, or `0.0` when the
    /// model reports no limit.
    pub fn tokens_usage(&self) -> (usize, f32) {
        let tokens = self.tokens();
        let max_input_tokens = self.model().max_input_tokens().unwrap_or_default();
        let percent = if max_input_tokens == 0 {
            0.0
        } else {
            let percent = tokens as f32 / max_input_tokens as f32 * 100.0;
            (percent * 100.0).round() / 100.0
        };
        (tokens, percent)
    }

    /// Adopts `role`'s model, sampling parameters, functions filter, name and
    /// prompt, and marks the session dirty.
    pub fn set_role(&mut self, role: Role) {
        self.model_id = role.model().id();
        self.temperature = role.temperature();
        self.top_p = role.top_p();
        self.functions_filter = role.functions_filter();
        self.model = role.model().clone();
        self.role_name = role.name().to_string();
        self.role_prompt = role.prompt().to_string();
        self.dirty = true;
    }

    /// Forgets the role name and prompt. Neither field is persisted, so the
    /// session is not marked dirty.
    pub fn clear_role(&mut self) {
        self.role_name.clear();
        self.role_prompt.clear();
    }

    /// Sets the save-on-exit preference; ignored for the temporary session.
    pub fn set_save_session(&mut self, value: Option<bool>) {
        if self.is_temp() {
            return;
        }
        if self.save_session != value {
            self.save_session = value;
            self.dirty = true;
        }
    }

    /// Sets the per-session compress threshold, marking dirty on change.
    pub fn set_compress_threshold(&mut self, value: Option<usize>) {
        if self.compress_threshold != value {
            self.compress_threshold = value;
            self.dirty = true;
        }
    }

    pub fn set_compressing(&mut self, compressing: bool) {
        self.compressing = compressing;
    }

    /// Moves every live message into `compressed_messages` and replaces the
    /// live history with a single system message holding `prompt`
    /// (presumably a summary produced by the caller).
    pub fn compress(&mut self, prompt: String) {
        self.compressed_messages.append(&mut self.messages);
        self.messages.push(Message::new(
            MessageRole::System,
            MessageContent::Text(prompt),
        ));
        self.dirty = true;
    }

    /// Persists the session on exit when it has unsaved changes.
    ///
    /// * `save_session == Some(false)`: never saves.
    /// * `save_session == None`: outside the REPL nothing is saved; in the
    ///   REPL the user is asked, and a temporary session is first given a
    ///   name (trimmed, non-empty, not `TEMP_SESSION_NAME`).
    /// * `save_session == Some(true)`: saves without asking.
    ///
    /// The file is written to `session_dir/<name>.yaml`.
    pub fn exit(&mut self, session_dir: &Path, is_repl: bool) -> Result<()> {
        let save_session = self.save_session();
        if self.dirty && save_session != Some(false) {
            if save_session.is_none() {
                if !is_repl {
                    return Ok(());
                }
                let ans = Confirm::new("Save session?").with_default(false).prompt()?;
                if !ans {
                    return Ok(());
                }
                if self.is_temp() {
                    let name = Text::new("Session name:")
                        .with_validator(|input: &str| {
                            // Validate the trimmed value, since that is what gets stored.
                            let input = input.trim();
                            if input == TEMP_SESSION_NAME {
                                Ok(Validation::Invalid(format!("'{TEMP_SESSION_NAME}' is a reserved word and cannot be used as a session name").into()))
                            } else if input.is_empty() {
                                Ok(Validation::Invalid("This field is required".into()))
                            } else {
                                Ok(Validation::Valid)
                            }
                        })
                        .prompt()?;
                    self.name = name.trim().to_string();
                }
            }
            let session_path = session_dir.join(format!("{}.yaml", self.name()));
            self.save(&session_path, is_repl)?;
        }
        Ok(())
    }

    /// Writes the session as YAML to `session_path`, creating the parent
    /// directory if needed, records the path and clears the dirty flag.
    /// Prints a confirmation when `is_repl` is true.
    pub fn save(&mut self, session_path: &Path, is_repl: bool) -> Result<()> {
        if let Some(sessions_dir) = session_path.parent() {
            if !sessions_dir.exists() {
                create_dir_all(sessions_dir).with_context(|| {
                    format!("Failed to create session_dir '{}'", sessions_dir.display())
                })?;
            }
        }
        self.path = Some(session_path.display().to_string());
        let content = serde_yaml::to_string(&self)
            .with_context(|| format!("Failed to serde session {}", self.name))?;
        fs::write(session_path, content).with_context(|| {
            format!(
                "Failed to write session {} to {}",
                self.name,
                session_path.display()
            )
        })?;
        if is_repl {
            println!("✨ Saved session to '{}'", session_path.display());
        }
        self.dirty = false;
        Ok(())
    }

    /// Errors if the session already contains messages.
    pub fn guard_empty(&self) -> Result<()> {
        if !self.is_empty() {
            bail!("Cannot perform this action in a session with messages")
        }
        Ok(())
    }

    /// Records a completed exchange (`input` -> `output`) in the history.
    ///
    /// * continue: appends `output` to the last message's text;
    /// * regenerate: replaces the last message's text with `output`;
    /// * otherwise: appends the user turn (the role-built opening messages on
    ///   the first turn) and the assistant reply. This mirrors
    ///   `build_messages`, so the stored history matches what the model saw.
    pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> {
        if input.continue_output().is_some() {
            if let Some(message) = self.messages.last_mut() {
                if let MessageContent::Text(text) = &mut message.content {
                    text.push_str(output);
                }
            }
        } else if input.regenerate() {
            if let Some(message) = self.messages.last_mut() {
                if let MessageContent::Text(text) = &mut message.content {
                    *text = output.to_string();
                }
            }
        } else {
            if self.messages.is_empty() {
                // The role's messages already include the user input.
                self.messages.extend(input.role().build_messages(input));
            } else {
                // Keep the same compressed context that `build_messages` sent.
                let tail = Self::compressed_tail(&self.messages, &self.compressed_messages);
                self.messages.extend_from_slice(tail);
                self.messages
                    .push(Message::new(MessageRole::User, input.message_content()));
            }
            self.data_urls.extend(input.data_urls());
            self.messages.push(Message::new(
                MessageRole::Assistant,
                MessageContent::Text(output.to_string()),
            ));
        }
        self.dirty = true;
        Ok(())
    }

    /// Drops all live and compressed messages and stored data URLs.
    pub fn clear_messages(&mut self) {
        self.messages.clear();
        self.compressed_messages.clear();
        self.data_urls.clear();
        self.dirty = true;
    }

    /// YAML dump of the messages `build_messages` would send for `input`.
    pub fn echo_messages(&self, input: &Input) -> String {
        let messages = self.build_messages(input);
        serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into())
    }

    /// Builds the message list to send to the model for `input`.
    ///
    /// * continue: the history unchanged;
    /// * regenerate: the history minus its last message (normally the
    ///   assistant reply being regenerated);
    /// * first turn: the role-built messages for `input`;
    /// * otherwise: the history, plus the compressed tail right after a
    ///   `compress`, plus the new user message.
    pub fn build_messages(&self, input: &Input) -> Vec<Message> {
        let mut messages = self.messages.clone();
        if input.continue_output().is_some() {
            return messages;
        }
        if input.regenerate() {
            messages.pop();
            return messages;
        }
        if messages.is_empty() {
            return input.role().build_messages(input);
        }
        messages.extend_from_slice(Self::compressed_tail(
            &self.messages,
            &self.compressed_messages,
        ));
        messages.push(Message::new(MessageRole::User, input.message_content()));
        messages
    }

    /// Extra context to resend right after a `compress`.
    ///
    /// `compress` leaves exactly one live message (the system prompt). In
    /// that state, this returns the last two compressed messages (normally
    /// the most recent user/assistant exchange). Otherwise it returns an
    /// empty slice. Shared by `build_messages` and `add_message` so both
    /// agree.
    fn compressed_tail<'a>(messages: &[Message], compressed: &'a [Message]) -> &'a [Message] {
        if messages.len() == 1 && compressed.len() >= 2 {
            &compressed[compressed.len() - 2..]
        } else {
            &[]
        }
    }
}
impl RoleLike for Session {
    /// Materializes a standalone `Role` from the session's role name and
    /// prompt, then hands the session to `Role::sync` (presumably copying the
    /// session's model/sampling settings onto the role — see `Role::sync`).
    fn to_role(&self) -> Role {
        let mut role = Role::new(&self.role_name, &self.role_prompt);
        role.sync(self);
        role
    }

    fn model(&self) -> &Model {
        &self.model
    }

    fn model_mut(&mut self) -> &mut Model {
        &mut self.model
    }

    fn temperature(&self) -> Option<f64> {
        self.temperature
    }

    fn top_p(&self) -> Option<f64> {
        self.top_p
    }

    fn functions_filter(&self) -> Option<FunctionsFilter> {
        self.functions_filter.clone()
    }

    /// Switches to `model` (updating the persisted id too); a no-op when the
    /// ids already match, otherwise marks the session dirty.
    fn set_model(&mut self, model: &Model) {
        let unchanged = self.model.id() == model.id();
        if unchanged {
            return;
        }
        self.model_id = model.id();
        self.model = model.clone();
        self.dirty = true;
    }

    /// Updates the temperature, marking dirty only on an actual change.
    fn set_temperature(&mut self, value: Option<f64>) {
        if self.temperature == value {
            return;
        }
        self.temperature = value;
        self.dirty = true;
    }

    /// Updates top_p, marking dirty only on an actual change.
    fn set_top_p(&mut self, value: Option<f64>) {
        if self.top_p == value {
            return;
        }
        self.top_p = value;
        self.dirty = true;
    }

    /// Updates the functions filter, marking dirty only on an actual change.
    fn set_functions_filter(&mut self, value: Option<FunctionsFilter>) {
        if self.functions_filter == value {
            return;
        }
        self.functions_filter = value;
        self.dirty = true;
    }
}