You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
aichat/src/config/session.rs

352 lines
11 KiB
Rust

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

use super::input::resolve_data_url;
use super::{Config, Input, Model};
use crate::client::{Message, MessageContent, MessageRole};
use crate::render::MarkdownRender;
use anyhow::{bail, Context, Result};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::fs::{self, read_to_string};
use std::path::Path;
/// Session name that `Session::is_temp` compares against to recognise the temporary session.
pub const TEMP_SESSION_NAME: &str = "temp";
/// A chat session: the model settings in effect plus the message history.
///
/// Fields without `#[serde(skip)]` are persisted through `serde`; the skipped
/// fields are runtime-only state and are filled with their `Default` value on
/// deserialization.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct Session {
    /// Identifier of the model, stored under the `model` key.
    #[serde(rename = "model")]
    model_id: String,
    /// Optional `temperature` setting.
    temperature: Option<f64>,
    /// Optional `top_p` setting.
    top_p: Option<f64>,
    /// Optional `save_session` setting.
    #[serde(default)]
    save_session: Option<bool>,
    /// The live message history; this is what `tokens()` counts.
    messages: Vec<Message>,
    /// Lookup table handed to `resolve_data_url` when messages are rendered.
    #[serde(default)]
    data_urls: HashMap<String, String>,
    /// Messages moved out of `messages` by `compress()`.
    #[serde(default)]
    compressed_messages: Vec<Message>,
    /// Per-session threshold that takes precedence in `need_compress()`.
    compress_threshold: Option<usize>,
    /// Session name (not persisted).
    #[serde(skip)]
    pub name: String,
    /// Display string of the file the session was loaded from or saved to (not persisted).
    #[serde(skip)]
    pub path: Option<String>,
    /// Set by the mutating methods and cleared by `save()` (not persisted).
    #[serde(skip)]
    pub dirty: bool,
    /// Runtime flag (not persisted); only initialised to `false` in this file.
    #[serde(skip)]
    pub compressing: bool,
    /// Model handle used for token counting and `max_input_tokens` (not persisted).
    #[serde(skip)]
    pub model: Model,
}
impl Session {
pub fn new(config: &Config, name: &str) -> Self {
Self {
model_id: config.model.id(),
temperature: config.temperature,
top_p: config.top_p,
save_session: config.save_session,
messages: vec![],
compressed_messages: vec![],
compress_threshold: None,
data_urls: Default::default(),
name: name.to_string(),
path: None,
dirty: false,
compressing: false,
model: config.model.clone(),
}
}
pub fn load(name: &str, path: &Path) -> Result<Self> {
let content = read_to_string(path)
.with_context(|| format!("Failed to load session {} at {}", name, path.display()))?;
let mut session: Self =
serde_yaml::from_str(&content).with_context(|| format!("Invalid session {}", name))?;
session.name = name.to_string();
session.path = Some(path.display().to_string());
Ok(session)
}
/// Returns the session name.
pub fn name(&self) -> &str {
    self.name.as_str()
}
/// Returns the stored model identifier (`model_id`), not the `Model` handle.
pub fn model(&self) -> &str {
    self.model_id.as_str()
}
/// Returns the session's `temperature` setting, if one is set.
pub fn temperature(&self) -> Option<f64> {
self.temperature
}
/// Returns the session's `top_p` setting, if one is set.
pub fn top_p(&self) -> Option<f64> {
self.top_p
}
/// Returns the session's `save_session` setting, if one is set.
pub fn save_session(&self) -> Option<bool> {
self.save_session
}
/// Reports whether the live history has grown past the compression threshold.
///
/// The session's own `compress_threshold` wins over
/// `current_compress_threshold` when it is set. A threshold below 1000 never
/// triggers compression, and in that case the tokens are not counted at all.
pub fn need_compress(&self, current_compress_threshold: usize) -> bool {
    let limit = match self.compress_threshold {
        Some(own) => own,
        None => current_compress_threshold,
    };
    if limit < 1000 {
        return false;
    }
    self.tokens() > limit
}
/// Returns what `self.model.total_tokens` reports for the current `messages` list.
///
/// `compressed_messages` are not passed in, so they do not contribute.
pub fn tokens(&self) -> usize {
self.model.total_tokens(&self.messages)
}
/// Counts the messages in the live history whose role is the user role.
pub fn user_messages_len(&self) -> usize {
    let mut count = 0;
    for message in &self.messages {
        if message.role.is_user() {
            count += 1;
        }
    }
    count
}
/// Renders the session as a YAML document.
///
/// The document carries the path, the model id, the optional settings that
/// are set, the token usage and the live messages.
///
/// # Errors
///
/// Fails when the session has no `path`, or when YAML serialization fails.
pub fn export(&self) -> Result<String> {
    if self.path.is_none() {
        bail!("Not found session '{}'", self.name)
    }
    let (total_tokens, usage_percent) = self.tokens_and_percent();

    // Keys are inserted in the order they should appear in the output.
    let mut doc = json!({
        "path": self.path,
        "model": self.model(),
    });
    if let Some(value) = self.temperature() {
        doc["temperature"] = value.into();
    }
    if let Some(value) = self.top_p() {
        doc["top_p"] = value.into();
    }
    if let Some(value) = self.save_session() {
        doc["save_session"] = value.into();
    }
    doc["total_tokens"] = total_tokens.into();
    if let Some(context_window) = self.model.max_input_tokens {
        doc["max_input_tokens"] = context_window.into();
    }
    // A percentage of exactly zero is omitted.
    if usage_percent != 0.0 {
        doc["total/max"] = format!("{}%", usage_percent).into();
    }
    doc["messages"] = json!(self.messages);

    serde_yaml::to_string(&doc)
        .with_context(|| format!("Unable to show info about session {}", &self.name))
}
/// Builds a human-readable description of the session.
///
/// The output starts with one `name value` line per known setting, with the
/// name left-aligned in a 20-column field. If the session is not empty, a
/// blank line and the live messages follow: system and assistant text goes
/// through `render`, user input is prefixed with the session name.
pub fn info(&self, render: &mut MarkdownRender) -> Result<String> {
    let mut lines: Vec<String> = Vec::new();
    {
        let mut push_item =
            |name: &str, value: String| lines.push(format!("{name:<20}{value}"));
        if let Some(path) = &self.path {
            push_item("path", path.to_string());
        }
        push_item("model", self.model.id());
        if let Some(temperature) = self.temperature() {
            push_item("temperature", temperature.to_string());
        }
        if let Some(top_p) = self.top_p() {
            push_item("top_p", top_p.to_string());
        }
        if let Some(save_session) = self.save_session() {
            push_item("save_session", save_session.to_string());
        }
        if let Some(compress_threshold) = self.compress_threshold {
            push_item("compress_threshold", compress_threshold.to_string());
        }
        if let Some(max_input_tokens) = self.model.max_input_tokens {
            push_item("max_input_tokens", max_input_tokens.to_string());
        }
    }

    if !self.is_empty() {
        lines.push(String::new());
        let resolve_url_fn = |url: &str| resolve_data_url(&self.data_urls, url.to_string());
        for message in self.messages.iter() {
            match message.role {
                MessageRole::System => {
                    let rendered = render.render(&message.content.render_input(resolve_url_fn));
                    lines.push(rendered);
                }
                MessageRole::Assistant => {
                    // Only plain-text replies are rendered; the blank separator
                    // line is emitted either way.
                    if let MessageContent::Text(text) = &message.content {
                        lines.push(render.render(text));
                    }
                    lines.push(String::new());
                }
                MessageRole::User => {
                    // NOTE(review): the session name and the input are joined
                    // with no separator between them; confirm this is intended.
                    lines.push(format!(
                        "{}{}",
                        self.name,
                        message.content.render_input(resolve_url_fn)
                    ));
                }
            }
        }
    }

    Ok(lines.join("\n"))
}
/// Returns the live token count and its share of the model's input limit.
///
/// The share is a percentage rounded to two decimals. It is `0.0` when the
/// model has no `max_input_tokens` or the limit is zero.
pub fn tokens_and_percent(&self) -> (usize, f32) {
    let tokens = self.tokens();
    let percent = match self.model.max_input_tokens.unwrap_or_default() {
        0 => 0.0,
        limit => {
            let raw = tokens as f32 / limit as f32 * 100.0;
            (raw * 100.0).round() / 100.0
        }
    };
    (tokens, percent)
}
/// Updates `temperature`; the session is marked dirty only if the value changes.
pub fn set_temperature(&mut self, value: Option<f64>) {
    if self.temperature == value {
        return;
    }
    self.temperature = value;
    self.dirty = true;
}
/// Updates `top_p`; the session is marked dirty only if the value changes.
pub fn set_top_p(&mut self, value: Option<f64>) {
    if self.top_p == value {
        return;
    }
    self.top_p = value;
    self.dirty = true;
}
/// Updates `save_session`; the session is marked dirty only if the value changes.
pub fn set_save_session(&mut self, value: Option<bool>) {
    if self.save_session == value {
        return;
    }
    self.save_session = value;
    self.dirty = true;
}
/// Updates `compress_threshold`; the session is marked dirty only if the value changes.
pub fn set_compress_threshold(&mut self, value: Option<usize>) {
    if self.compress_threshold == value {
        return;
    }
    self.compress_threshold = value;
    self.dirty = true;
}
/// Installs `model` as the session's model.
///
/// The `Model` handle is always replaced. The stored id is updated, and the
/// session marked dirty, only when the id actually differs. This method
/// always returns `Ok(())`.
pub fn set_model(&mut self, model: Model) -> Result<()> {
    let new_id = model.id();
    let id_changed = self.model_id != new_id;
    self.model = model;
    if id_changed {
        self.model_id = new_id;
        self.dirty = true;
    }
    Ok(())
}
/// Archives the live history and restarts it from a single system message.
///
/// Every message in `messages` is moved to the end of `compressed_messages`.
/// `messages` is then left holding one system message whose text is `prompt`,
/// and the session is marked dirty.
pub fn compress(&mut self, prompt: String) {
    let replacement = Message {
        role: MessageRole::System,
        content: MessageContent::Text(prompt),
    };
    self.compressed_messages.append(&mut self.messages);
    self.messages.push(replacement);
    self.dirty = true;
}
/// Serializes the session to YAML and writes it to `session_path`.
///
/// On success the session's `path` is updated to `session_path` and the dirty
/// flag is cleared. On failure neither is touched, so the session never
/// claims a location that was not actually written.
///
/// # Errors
///
/// Fails when serialization or the file write fails.
pub fn save(&mut self, session_path: &Path) -> Result<()> {
    // `path` is `#[serde(skip)]`, so it has no effect on the serialized
    // content and can safely be assigned after the write.
    let content = serde_yaml::to_string(&*self)
        .with_context(|| format!("Failed to serde session {}", self.name))?;
    fs::write(session_path, content).with_context(|| {
        format!(
            "Failed to write session {} to {}",
            self.name,
            session_path.display()
        )
    })?;
    // Record the new location only once the file is on disk.
    self.path = Some(session_path.display().to_string());
    self.dirty = false;
    Ok(())
}
/// Succeeds only when the session holds no messages at all.
///
/// # Errors
///
/// Fails when `is_empty()` is false.
pub fn guard_empty(&self) -> Result<()> {
    if self.is_empty() {
        return Ok(());
    }
    bail!("Cannot perform this action in a session with messages")
}
/// Reports whether this session's name equals `TEMP_SESSION_NAME`.
pub fn is_temp(&self) -> bool {
    self.name.as_str() == TEMP_SESSION_NAME
}
/// Reports whether both the live and the compressed histories are empty.
pub fn is_empty(&self) -> bool {
    let no_live = self.messages.is_empty();
    let no_archived = self.compressed_messages.is_empty();
    no_live && no_archived
}
/// Records one exchange: the user's `input` followed by the assistant's `output`.
///
/// When the live history is empty and `input` carries a role, the role builds
/// the opening messages. Otherwise the input is appended as a plain user
/// message. The input's data URLs are merged into the session and the session
/// is marked dirty. This method always returns `Ok(())`.
pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> {
    // The role is consulted only for the first exchange of the live history.
    let opening_role = if self.messages.is_empty() {
        input.role()
    } else {
        None
    };
    match opening_role {
        Some(role) => self.messages.extend(role.build_messages(input)),
        None => self.messages.push(Message {
            role: MessageRole::User,
            content: input.to_message_content(),
        }),
    }
    self.data_urls.extend(input.data_urls());
    let reply = Message {
        role: MessageRole::Assistant,
        content: MessageContent::Text(output.to_owned()),
    };
    self.messages.push(reply);
    self.dirty = true;
    Ok(())
}
/// Drops all live and compressed messages and all data URLs, then marks the
/// session dirty.
pub fn clear_messages(&mut self) {
    self.dirty = true;
    self.data_urls.clear();
    self.compressed_messages.clear();
    self.messages.clear();
}
/// Returns, as YAML, the message list that `build_emssages` produces for `input`.
///
/// If serialization fails, a fixed fallback text is returned instead.
pub fn echo_messages(&self, input: &Input) -> String {
    match serde_yaml::to_string(&self.build_emssages(input)) {
        Ok(yaml) => yaml,
        Err(_) => "Unable to echo message".into(),
    }
}
/// Assembles the message list for `input` without modifying the session.
///
/// * Empty live history and `input` has a role: the role's messages are
///   returned as-is.
/// * Exactly one live message and at least two compressed messages: the last
///   two compressed messages are appended after it.
/// * In every case except the first, `input` is appended as a user message.
///
/// NOTE(review): the name looks like a misspelling of `build_messages`; it is
/// kept unchanged because callers outside this file may use it.
pub fn build_emssages(&self, input: &Input) -> Vec<Message> {
    let mut messages = self.messages.clone();
    match messages.len() {
        0 => {
            if let Some(role) = input.role() {
                // The role already accounts for the input, so nothing is appended.
                return role.build_messages(input);
            }
        }
        1 => {
            let archived = &self.compressed_messages;
            if archived.len() >= 2 {
                messages.extend_from_slice(&archived[archived.len() - 2..]);
            }
        }
        _ => {}
    }
    messages.push(Message {
        role: MessageRole::User,
        content: input.to_message_content(),
    });
    messages
}
}