use super::input::resolve_data_url; use super::role::Role; use super::{Input, Model}; use crate::client::{Message, MessageContent, MessageRole}; use crate::render::MarkdownRender; use anyhow::{bail, Context, Result}; use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::HashMap; use std::fs::{self, read_to_string}; use std::path::Path; pub const TEMP_SESSION_NAME: &str = "temp"; #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct Session { #[serde(rename(serialize = "model", deserialize = "model"))] model_id: String, temperature: Option, messages: Vec, #[serde(default)] data_urls: HashMap, #[serde(skip)] pub name: String, #[serde(skip)] pub path: Option, #[serde(skip)] pub dirty: bool, #[serde(skip)] pub role: Option, #[serde(skip)] pub model: Model, } impl Session { pub fn new(name: &str, model: Model, role: Option) -> Self { let temperature = role.as_ref().and_then(|v| v.temperature); Self { model_id: model.id(), temperature, messages: vec![], data_urls: Default::default(), name: name.to_string(), path: None, dirty: false, role, model, } } pub fn load(name: &str, path: &Path) -> Result { let content = read_to_string(path) .with_context(|| format!("Failed to load session {} at {}", name, path.display()))?; let mut session: Self = serde_yaml::from_str(&content).with_context(|| format!("Invalid session {}", name))?; session.name = name.to_string(); session.path = Some(path.display().to_string()); Ok(session) } pub fn name(&self) -> &str { &self.name } pub fn model(&self) -> &str { &self.model_id } pub fn temperature(&self) -> Option { self.temperature } pub fn tokens(&self) -> usize { self.model.total_tokens(&self.messages) } pub fn export(&self) -> Result { self.guard_save()?; let (tokens, percent) = self.tokens_and_percent(); let mut data = json!({ "path": self.path, "model": self.model(), }); if let Some(temperature) = self.temperature() { data["temperature"] = temperature.into(); } data["total_tokens"] = tokens.into(); if let Some(max_tokens) = self.model.max_tokens { data["max_tokens"] = max_tokens.into(); } if percent != 0.0 { data["total/max"] = format!("{}%", percent).into(); } data["messages"] = json!(self.messages); let output = serde_yaml::to_string(&data) .with_context(|| format!("Unable to show info about session {}", &self.name))?; Ok(output) } pub fn render(&self, render: &mut MarkdownRender) -> Result { let mut items = vec![]; if let Some(path) = &self.path { items.push(("path", path.to_string())); } items.push(("model", self.model.id())); if let Some(temperature) = self.temperature() { items.push(("temperature", temperature.to_string())); } if let Some(max_tokens) = self.model.max_tokens { items.push(("max_tokens", max_tokens.to_string())); } let mut lines: Vec = items .iter() .map(|(name, value)| format!("{name:<20}{value}")) .collect(); if !self.is_empty() { lines.push("".into()); let resolve_url_fn = |url: &str| resolve_data_url(&self.data_urls, url.to_string()); for message in &self.messages { match message.role { MessageRole::System => { continue; } MessageRole::Assistant => { if let MessageContent::Text(text) = &message.content { lines.push(render.render(text)); } lines.push("".into()); } MessageRole::User => { lines.push(format!( "{}){}", self.name, message.content.render_input(resolve_url_fn) )); } } } } let output = lines.join("\n"); Ok(output) } pub fn tokens_and_percent(&self) -> (usize, f32) { let tokens = self.tokens(); let max_tokens = self.model.max_tokens.unwrap_or_default(); let percent = if max_tokens == 0 { 0.0 } else { let percent = tokens as f32 / max_tokens as f32 * 100.0; (percent * 100.0).round() / 100.0 }; (tokens, percent) } pub fn update_role(&mut self, role: Option) -> Result<()> { self.guard_empty()?; self.temperature = role.as_ref().and_then(|v| v.temperature); self.role = role; Ok(()) } pub fn set_temperature(&mut self, value: Option) { self.temperature = value; } pub fn set_model(&mut self, model: Model) -> Result<()> { self.model_id = model.id(); self.model = model; Ok(()) } pub fn save(&mut self, session_path: &Path) -> Result<()> { if !self.should_save() { return Ok(()); } self.path = Some(session_path.display().to_string()); let content = serde_yaml::to_string(&self) .with_context(|| format!("Failed to serde session {}", self.name))?; fs::write(session_path, content).with_context(|| { format!( "Failed to write session {} to {}", self.name, session_path.display() ) })?; self.dirty = false; Ok(()) } pub fn should_save(&self) -> bool { !self.is_empty() && self.dirty } pub fn guard_save(&self) -> Result<()> { if self.path.is_none() { bail!("Not found session '{}'", self.name) } Ok(()) } pub fn guard_empty(&self) -> Result<()> { if !self.is_empty() { bail!("Cannot perform this action in a session with messages") } Ok(()) } pub fn is_temp(&self) -> bool { self.name == TEMP_SESSION_NAME } pub fn is_empty(&self) -> bool { self.messages.is_empty() } pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> { let mut need_add_msg = true; if self.messages.is_empty() { if let Some(role) = self.role.as_ref() { self.messages.extend(role.build_messages(input)); need_add_msg = false; } } if need_add_msg { self.messages.push(Message { role: MessageRole::User, content: input.to_message_content(), }); } self.data_urls.extend(input.data_urls()); self.messages.push(Message { role: MessageRole::Assistant, content: MessageContent::Text(output.to_string()), }); self.role = None; self.dirty = true; Ok(()) } pub fn echo_messages(&self, input: &Input) -> String { let messages = self.build_emssages(input); serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into()) } pub fn build_emssages(&self, input: &Input) -> Vec { let mut messages = self.messages.clone(); let mut need_add_msg = true; if messages.is_empty() { if let Some(role) = self.role.as_ref() { messages = role.build_messages(input); need_add_msg = false; } }; if need_add_msg { messages.push(Message { role: MessageRole::User, content: input.to_message_content(), }); } messages } }