feat: allow binding model to the role (#505)

pull/506/head
sigoden 5 months ago committed by GitHub
parent 5284a18248
commit 79d0bba640
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -70,8 +70,7 @@ macro_rules! register_client {
impl $client { impl $client {
pub const NAME: &'static str = $name; pub const NAME: &'static str = $name;
pub fn init(global_config: &$crate::config::GlobalConfig) -> Option<Box<dyn Client>> { pub fn init(global_config: &$crate::config::GlobalConfig, model: &$crate::client::Model) -> Option<Box<dyn Client>> {
let model = global_config.read().model.clone();
let config = global_config.read().clients.iter().find_map(|client_config| { let config = global_config.read().clients.iter().find_map(|client_config| {
if let ClientConfig::$config(c) = client_config { if let ClientConfig::$config(c) = client_config {
if Self::name(c) == &model.client_name { if Self::name(c) == &model.client_name {
@ -84,7 +83,7 @@ macro_rules! register_client {
Some(Box::new(Self { Some(Box::new(Self {
global_config: global_config.clone(), global_config: global_config.clone(),
config, config,
model, model: model.clone(),
})) }))
} }
@ -109,11 +108,12 @@ macro_rules! register_client {
)+ )+
pub fn init_client(config: &$crate::config::GlobalConfig) -> anyhow::Result<Box<dyn Client>> { pub fn init_client(config: &$crate::config::GlobalConfig, model: Option<$crate::client::Model>) -> anyhow::Result<Box<dyn Client>> {
let model = model.unwrap_or_else(|| config.read().model.clone());
None None
$(.or_else(|| $client::init(config)))+ $(.or_else(|| $client::init(config, &model)))+
.ok_or_else(|| { .ok_or_else(|| {
anyhow::anyhow!("Unknown client '{}'", &config.read().model.client_name) anyhow::anyhow!("Unknown client '{}'", model.client_name)
}) })
} }

@ -1,8 +1,8 @@
use super::{role::Role, session::Session, GlobalConfig}; use super::{role::Role, session::Session, GlobalConfig};
use crate::client::{ use crate::client::{
init_client, Client, ImageUrl, Message, MessageContent, MessageContentPart, ModelCapabilities, init_client, list_models, Client, ImageUrl, Message, MessageContent, MessageContentPart, Model,
SendData, ModelCapabilities, SendData,
}; };
use crate::utils::{base64_encode, sha256}; use crate::utils::{base64_encode, sha256};
@ -111,8 +111,23 @@ impl Input {
self.text = text; self.text = text;
} }
pub fn model(&self) -> Model {
let model = self.config.read().model.clone();
if let Some(model_id) = self.role().and_then(|v| v.model_id.clone()) {
if model.id() != model_id {
if let Some(model) = list_models(&self.config.read())
.into_iter()
.find(|v| v.id() == model_id)
{
return model.clone();
}
}
};
model
}
pub fn create_client(&self) -> Result<Box<dyn Client>> { pub fn create_client(&self) -> Result<Box<dyn Client>> {
init_client(&self.config) init_client(&self.config, Some(self.model()))
} }
pub fn prepare_send_data(&self, stream: bool) -> Result<SendData> { pub fn prepare_send_data(&self, stream: bool) -> Result<SendData> {

@ -54,7 +54,8 @@ const RIGHT_PROMPT: &str = "{color.purple}{?session {?consume_tokens {consume_to
#[serde(default)] #[serde(default)]
pub struct Config { pub struct Config {
#[serde(rename(serialize = "model", deserialize = "model"))] #[serde(rename(serialize = "model", deserialize = "model"))]
pub model_id: Option<String>, #[serde(default)]
pub model_id: String,
pub temperature: Option<f64>, pub temperature: Option<f64>,
pub top_p: Option<f64>, pub top_p: Option<f64>,
pub dry_run: bool, pub dry_run: bool,
@ -91,7 +92,7 @@ pub struct Config {
impl Default for Config { impl Default for Config {
fn default() -> Self { fn default() -> Self {
Self { Self {
model_id: None, model_id: Default::default(),
temperature: None, temperature: None,
top_p: None, top_p: None,
save: false, save: false,
@ -296,12 +297,16 @@ impl Config {
session.set_temperature(role.temperature); session.set_temperature(role.temperature);
session.set_top_p(role.top_p); session.set_top_p(role.top_p);
} }
if let Some(model_id) = &role.model_id {
self.set_model(model_id)?;
}
self.role = Some(role); self.role = Some(role);
Ok(()) Ok(())
} }
pub fn clear_role(&mut self) -> Result<()> { pub fn clear_role(&mut self) -> Result<()> {
self.role = None; self.role = None;
self.restore_model()?;
Ok(()) Ok(())
} }
@ -381,6 +386,8 @@ impl Config {
Some(model) => { Some(model) => {
if let Some(session) = self.session.as_mut() { if let Some(session) = self.session.as_mut() {
session.set_model(&model); session.set_model(&model);
} else if let Some(role) = self.role.as_mut() {
role.set_model(&model);
} }
self.model = model; self.model = model;
Ok(()) Ok(())
@ -388,12 +395,28 @@ impl Config {
} }
} }
pub fn set_model_id(&mut self) {
self.model_id = self.model.id()
}
pub fn restore_model(&mut self) -> Result<()> {
let origin_model_id = self.model_id.clone();
self.set_model(&origin_model_id)
}
pub fn system_info(&self) -> Result<String> { pub fn system_info(&self) -> Result<String> {
let display_path = |path: &Path| path.display().to_string(); let display_path = |path: &Path| path.display().to_string();
let wrap = self let wrap = self
.wrap .wrap
.clone() .clone()
.map_or_else(|| String::from("no"), |v| v.to_string()); .map_or_else(|| String::from("no"), |v| v.to_string());
let (temperature, top_p) = if let Some(session) = &self.session {
(session.temperature(), session.top_p())
} else if let Some(role) = &self.role {
(role.temperature, role.top_p)
} else {
(self.temperature, self.top_p)
};
let items = vec![ let items = vec![
("model", self.model.id()), ("model", self.model.id()),
( (
@ -403,8 +426,8 @@ impl Config {
.map(|v| format!("{v} (current model)")) .map(|v| format!("{v} (current model)"))
.unwrap_or_else(|| "-".into()), .unwrap_or_else(|| "-".into()),
), ),
("temperature", format_option_value(&self.temperature)), ("temperature", format_option_value(&temperature)),
("top_p", format_option_value(&self.top_p)), ("top_p", format_option_value(&top_p)),
("dry_run", self.dry_run.to_string()), ("dry_run", self.dry_run.to_string()),
("save", self.save.to_string()), ("save", self.save.to_string()),
("save_session", format_option_value(&self.save_session)), ("save_session", format_option_value(&self.save_session)),
@ -645,6 +668,7 @@ impl Config {
} }
Self::save_session_to_file(&mut session)?; Self::save_session_to_file(&mut session)?;
} }
self.restore_model()?;
} }
Ok(()) Ok(())
} }
@ -926,18 +950,19 @@ impl Config {
} }
fn setup_model(&mut self) -> Result<()> { fn setup_model(&mut self) -> Result<()> {
let model = match &self.model_id { let model_id = if self.model_id.is_empty() {
Some(v) => v.clone(),
None => {
let models = list_models(self); let models = list_models(self);
if models.is_empty() { if models.is_empty() {
bail!("No available model"); bail!("No available model");
} }
models[0].id() let model_id = models[0].id();
} self.model_id.clone_from(&model_id);
model_id
} else {
self.model_id.clone()
}; };
self.set_model(&model)?; self.set_model(&model_id)?;
Ok(()) Ok(())
} }
@ -1046,6 +1071,10 @@ impl State {
pub fn in_role() -> Vec<Self> { pub fn in_role() -> Vec<Self> {
vec![Self::Role, Self::EmptySessionWithRole] vec![Self::Role, Self::EmptySessionWithRole]
} }
pub fn is_normal(&self) -> bool {
self == &Self::Normal
}
} }
fn create_config_file(config_path: &Path) -> Result<()> { fn create_config_file(config_path: &Path) -> Result<()> {

@ -1,6 +1,6 @@
use super::Input; use super::Input;
use crate::{ use crate::{
client::{Message, MessageContent, MessageRole}, client::{Message, MessageContent, MessageRole, Model},
utils::{detect_os, detect_shell}, utils::{detect_os, detect_shell},
}; };
@ -18,6 +18,8 @@ pub const INPUT_PLACEHOLDER: &str = "__INPUT__";
pub struct Role { pub struct Role {
pub name: String, pub name: String,
pub prompt: String, pub prompt: String,
#[serde(rename(serialize = "model", deserialize = "model"))]
pub model_id: Option<String>,
pub temperature: Option<f64>, pub temperature: Option<f64>,
pub top_p: Option<f64>, pub top_p: Option<f64>,
} }
@ -28,6 +30,7 @@ impl Role {
name: TEMP_ROLE.into(), name: TEMP_ROLE.into(),
prompt: prompt.into(), prompt: prompt.into(),
temperature: None, temperature: None,
model_id: None,
top_p: None, top_p: None,
} }
} }
@ -62,6 +65,7 @@ async function timeout(ms) {
.map(|(name, prompt)| Self { .map(|(name, prompt)| Self {
name: name.into(), name: name.into(),
prompt, prompt,
model_id: None,
temperature: None, temperature: None,
top_p: None, top_p: None,
}) })
@ -78,6 +82,10 @@ async function timeout(ms) {
self.prompt.contains(INPUT_PLACEHOLDER) self.prompt.contains(INPUT_PLACEHOLDER)
} }
pub fn set_model(&mut self, model: &Model) {
self.model_id = Some(model.id());
}
pub fn set_temperature(&mut self, value: Option<f64>) { pub fn set_temperature(&mut self, value: Option<f64>) {
self.temperature = value; self.temperature = value;
} }

@ -97,6 +97,7 @@ async fn main() -> Result<()> {
} }
if let Some(model) = &cli.model { if let Some(model) = &cli.model {
config.write().set_model(model)?; config.write().set_model(model)?;
config.write().set_model_id();
} }
if cli.save_session { if cli.save_session {
config.write().set_save_session(Some(true)); config.write().set_save_session(Some(true));

@ -163,6 +163,9 @@ impl Repl {
".model" => match args { ".model" => match args {
Some(name) => { Some(name) => {
self.config.write().set_model(name)?; self.config.write().set_model(name)?;
if self.config.read().state().is_normal() {
self.config.write().set_model_id();
}
} }
None => println!("Usage: .model <name>"), None => println!("Usage: .model <name>"),
}, },

@ -249,7 +249,7 @@ impl Server {
config.write().set_model(&model_name)?; config.write().set_model(&model_name)?;
} }
let mut client = init_client(&config)?; let mut client = init_client(&config, None)?;
if max_tokens.is_some() { if max_tokens.is_some() {
client.model_mut().set_max_tokens(max_tokens, true); client.model_mut().set_max_tokens(max_tokens, true);
} }

Loading…
Cancel
Save