From c211d1e43a02266226fb8d04856e703a4e6a625d Mon Sep 17 00:00:00 2001 From: sigoden Date: Tue, 8 Oct 2024 10:52:44 +0800 Subject: [PATCH] feat: session persists role name (#914) --- src/config/mod.rs | 14 +++++++++----- src/config/session.rs | 29 ++++++++++++++++++----------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/config/mod.rs b/src/config/mod.rs index 291d477..1c669a7 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -839,10 +839,14 @@ impl Config { pub fn role_info(&self) -> Result { if let Some(role) = &self.role { - Ok(role.export()) - } else { - bail!("No role") + return Ok(role.export()); + } else if let Some(session) = &self.session { + let role = session.to_role(); + if !role.name().is_empty() { + return Ok(role.export()); + } } + bail!("No role"); } pub fn exit_role(&mut self) -> Result<()> { @@ -1089,8 +1093,8 @@ impl Config { pub fn clear_session_messages(&mut self) -> Result<()> { if let Some(session) = self.session.as_mut() { session.clear_messages(); - if let Some(prompt) = self.agent.as_ref().map(|v| v.interpolated_instructions()) { - session.update_role_prompt(&prompt); + if let Some(agent) = self.agent.as_ref() { + session.set_agent(agent); } } else { bail!("No session") diff --git a/src/config/session.rs b/src/config/session.rs index 4bda347..c499abc 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -27,19 +27,20 @@ pub struct Session { #[serde(skip_serializing_if = "Option::is_none")] compress_threshold: Option, - messages: Vec, #[serde(default, skip_serializing_if = "HashMap::is_empty")] data_urls: HashMap, #[serde(default, skip_serializing_if = "Vec::is_empty")] compressed_messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + role_name: Option, + + messages: Vec, #[serde(skip)] model: Model, #[serde(skip)] role_prompt: String, #[serde(skip)] - role_name: String, - #[serde(skip)] name: String, #[serde(skip)] path: Option, @@ -72,10 +73,14 @@ impl Session { session.name = name.to_string(); session.path = Some(path.display().to_string()); + if let Some(role_name) = &session.role_name { + if let Ok(role) = config.retrieve_role(role_name) { + session.role_prompt = role.prompt().to_string(); + } + } + if let Some(agent) = &config.agent { - session - .role_prompt - .clone_from(&agent.definition().instructions); + session.set_agent(agent); } Ok(session) @@ -235,17 +240,18 @@ impl Session { self.top_p = role.top_p(); self.use_tools = role.use_tools(); self.model = role.model().clone(); - self.role_name = role.name().to_string(); + self.role_name = Some(role.name().to_string()); self.role_prompt = role.prompt().to_string(); self.dirty = true; } - pub fn update_role_prompt(&mut self, prompt: &str) { - self.role_prompt = prompt.to_string(); + pub fn set_agent(&mut self, agent: &Agent) { + self.role_prompt + .clone_from(&agent.definition().instructions); } pub fn clear_role(&mut self) { - self.role_name.clear(); + self.role_name = None; self.role_prompt.clear(); } @@ -435,7 +441,8 @@ impl Session { impl RoleLike for Session { fn to_role(&self) -> Role { - let mut role = Role::new(&self.role_name, &self.role_prompt); + let role_name = self.role_name.as_deref().unwrap_or_default(); + let mut role = Role::new(role_name, &self.role_prompt); role.sync(self); role }