From fe35cfd9419302f01baf9672493c0b0a4b41d889 Mon Sep 17 00:00:00 2001
From: sigoden
Date: Sat, 13 Jan 2024 19:52:07 +0800
Subject: [PATCH] feat: support model capabilities (#297)

1. Automatically switch to a model that has the necessary capabilities.
2. Throw an error if the client has no model with the necessary capabilities.
---
 Cargo.lock                 |  1 +
 Cargo.toml                 |  1 +
 config.example.yaml        |  7 ++--
 src/client/azure_openai.rs | 11 ++----
 src/client/common.rs       | 69 +++++++++++++++++++++++++++++---------
 src/client/ernie.rs        | 32 ++++++------------
 src/client/gemini.rs       | 17 +++++-----
 src/client/localai.rs      | 11 ++----
 src/client/model.rs        | 51 ++++++++++++++++++++++++++++
 src/client/ollama.rs       | 19 ++++-------
 src/client/openai.rs       | 29 +++++++---------
 src/client/qianwen.rs      | 23 ++++++-------
 src/config/input.rs        | 10 +++++-
 src/main.rs                |  5 +--
 src/repl/mod.rs            |  5 +--
 15 files changed, 181 insertions(+), 110 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 75f07a7..c8d10ea 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -36,6 +36,7 @@ dependencies = [
  "async-trait",
  "base64",
  "bincode",
+ "bitflags 2.4.1",
  "bstr",
  "bytes",
  "chrono",
diff --git a/Cargo.toml b/Cargo.toml
index 7be3654..1626da4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -43,6 +43,7 @@ log = "0.4.20"
 shell-words = "1.1.0"
 mime_guess = "2.0.4"
 sha2 = "0.10.8"
+bitflags = "2.4.1"
 
 [dependencies.reqwest]
 version = "0.11.14"
diff --git a/config.example.yaml b/config.example.yaml
index d360a7b..ce12cce 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -36,8 +36,11 @@ clients:
     api_key: xxx
     chat_endpoint: /chat/completions # Optional field
     models:
-      - name: gpt4all-j
+      - name: mistral
         max_tokens: 8192
+      - name: llava
+        max_tokens: 8192
+        capabilities: text,vision # Optional field, possible values: text, vision
 
   # See https://github.com/jmorganca/ollama
   - type: ollama
@@ -45,7 +48,7 @@ clients:
     api_key: Basic xxx # Set authorization header
     chat_endpoint: /chat # Optional field
     models:
-      - name: gpt4all-j
+      - name: mistral
         max_tokens: 8192
 
   # See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs
index 4700ad7..5c9f3ef 100644
--- a/src/client/azure_openai.rs
+++ b/src/client/azure_openai.rs
@@ -1,5 +1,5 @@
 use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
-use super::{AzureOpenAIClient, ExtraConfig, PromptType, SendData, Model};
+use super::{AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptType, SendData};
 
 use crate::utils::PromptKind;
 
@@ -13,16 +13,10 @@ pub struct AzureOpenAIConfig {
     pub name: Option<String>,
     pub api_base: Option<String>,
     pub api_key: Option<String>,
-    pub models: Vec<AzureOpenAIModel>,
+    pub models: Vec<ModelConfig>,
     pub extra: Option<ExtraConfig>,
 }
 
-#[derive(Debug, Clone, Deserialize)]
-pub struct AzureOpenAIModel {
-    name: String,
-    max_tokens: Option<usize>,
-}
-
 openai_compatible_client!(AzureOpenAIClient);
 
 impl AzureOpenAIClient {
@@ -50,6 +44,7 @@ impl AzureOpenAIClient {
             .map(|v| {
                 Model::new(client_name, &v.name)
                     .set_max_tokens(v.max_tokens)
+                    .set_capabilities(v.capabilities)
                     .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
             })
             .collect()
diff --git a/src/client/common.rs b/src/client/common.rs
index 9ff02ba..c35ba1b 100644
--- a/src/client/common.rs
+++ b/src/client/common.rs
@@ -1,4 +1,4 @@
-use super::{openai::OpenAIConfig, ClientConfig, Message, MessageContent};
+use super::{openai::OpenAIConfig, ClientConfig, Message, MessageContent, Model};
 
 use crate::{
     config::{GlobalConfig, Input},
@@ -78,12 +78,26 @@ macro_rules! register_client {
             )+
 
             pub fn init_client(config: &$crate::config::GlobalConfig) -> anyhow::Result<Box<dyn Client>> {
-                None
-                $(.or_else(|| $client::init(config)))+
-                .ok_or_else(|| {
-                    let model = config.read().model.clone();
-                    anyhow::anyhow!("Unknown client '{}'", &model.client_name)
-                })
+                None
+                    $(.or_else(|| $client::init(config)))+
+                    .ok_or_else(|| {
+                        let model = config.read().model.clone();
+                        anyhow::anyhow!("Unknown client '{}'", &model.client_name)
+                    })
+            }
+
+            pub fn ensure_model_capabilities(client: &mut dyn Client, capabilities: $crate::client::ModelCapabilities) -> anyhow::Result<()> {
+                if !client.model().capabilities.contains(capabilities) {
+                    let models = client.models();
+                    if let Some(model) = models.into_iter().find(|v| v.capabilities.contains(capabilities)) {
+                        client.set_model(model);
+                    } else {
+                        anyhow::bail!(
+                            "The current model lacks the corresponding capability."
+                        );
+                    }
+                }
+                Ok(())
             }
 
             pub fn list_client_types() -> Vec<&'static str> {
@@ -113,19 +127,38 @@ macro_rules! register_client {
     };
 }
 
+#[macro_export]
+macro_rules! client_common_fns {
+    () => {
+        fn config(
+            &self,
+        ) -> (
+            &$crate::config::GlobalConfig,
+            &Option<$crate::client::ExtraConfig>,
+        ) {
+            (&self.global_config, &self.config.extra)
+        }
+
+        fn models(&self) -> Vec<Model> {
+            Self::list_models(&self.config)
+        }
+
+        fn model(&self) -> &Model {
+            &self.model
+        }
+
+        fn set_model(&mut self, model: Model) {
+            self.model = model;
+        }
+    };
+}
+
 #[macro_export]
 macro_rules! openai_compatible_client {
     ($client:ident) => {
         #[async_trait]
         impl $crate::client::Client for $crate::client::$client {
-            fn config(
-                &self,
-            ) -> (
-                &$crate::config::GlobalConfig,
-                &Option<$crate::client::ExtraConfig>,
-            ) {
-                (&self.global_config, &self.config.extra)
-            }
+            client_common_fns!();
 
             async fn send_message_inner(
                 &self,
@@ -170,6 +203,12 @@ macro_rules! config_get_fn {
 pub trait Client {
     fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>);
 
+    fn models(&self) -> Vec<Model>;
+
+    fn model(&self) -> &Model;
+
+    fn set_model(&mut self, model: Model);
+
     fn build_client(&self) -> Result<ReqwestClient> {
         let mut builder = ReqwestClient::builder();
         let options = self.config().1;
diff --git a/src/client/ernie.rs b/src/client/ernie.rs
index d7e3a57..7848cf8 100644
--- a/src/client/ernie.rs
+++ b/src/client/ernie.rs
@@ -1,10 +1,6 @@
-use super::{ErnieClient, Client, ExtraConfig, PromptType, SendData, Model, patch_system_message};
+use super::{patch_system_message, Client, ErnieClient, ExtraConfig, Model, PromptType, SendData};
 
-use crate::{
-    config::GlobalConfig,
-    render::ReplyHandler,
-    utils::PromptKind,
-};
+use crate::{render::ReplyHandler, utils::PromptKind};
 
 use anyhow::{anyhow, bail, Context, Result};
 use async_trait::async_trait;
@@ -37,9 +33,7 @@ pub struct ErnieConfig {
 
 #[async_trait]
 impl Client for ErnieClient {
-    fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>) {
-        (&self.global_config, &self.config.extra)
-    }
+    client_common_fns!();
 
     async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
         self.prepare_access_token().await?;
@@ -127,10 +121,7 @@ async fn send_message(builder: RequestBuilder) -> Result<String> {
     Ok(output.to_string())
 }
 
-async fn send_message_streaming(
-    builder: RequestBuilder,
-    handler: &mut ReplyHandler,
-) -> Result<()> {
+async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> {
     let mut es = builder.eventsource()?;
     while let Some(event) = es.next().await {
         match event {
@@ -216,13 +207,12 @@ fn build_body(data: SendData, _model: String) -> Value {
 async fn fetch_access_token(api_key: &str, secret_key: &str) -> Result<String> {
     let url = format!("{ACCESS_TOKEN_URL}?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}");
     let value: Value = reqwest::get(&url).await?.json().await?;
-    let result = value["access_token"].as_str()
-        .ok_or_else(|| {
-            if let Some(err_msg) = value["error_description"].as_str() {
-                anyhow!("{err_msg}")
-            } else {
-                anyhow!("Invalid response data")
-            }
-        })?;
+    let result = value["access_token"].as_str().ok_or_else(|| {
+        if let Some(err_msg) = value["error_description"].as_str() {
+            anyhow!("{err_msg}")
+        } else {
+            anyhow!("Invalid response data")
+        }
+    })?;
     Ok(result.to_string())
 }
diff --git a/src/client/gemini.rs b/src/client/gemini.rs
index 98fd60d..6ff3f95 100644
--- a/src/client/gemini.rs
+++ b/src/client/gemini.rs
@@ -3,7 +3,7 @@ use super::{
     SendData, TokensCountFactors,
 };
 
-use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind};
+use crate::{render::ReplyHandler, utils::PromptKind};
 
 use anyhow::{anyhow, bail, Result};
 use async_trait::async_trait;
@@ -14,10 +14,10 @@ use serde_json::{json, Value};
 
 const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models/";
 
-const MODELS: [(&str, usize); 3] = [
-    ("gemini-pro", 32768),
-    ("gemini-pro-vision", 16384),
-    ("gemini-ultra", 32768),
+const MODELS: [(&str, usize, &str); 3] = [
+    ("gemini-pro", 32768, "text"),
+    ("gemini-pro-vision", 16384, "vision"),
+    ("gemini-ultra", 32768, "text"),
 ];
 
 const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
@@ -31,9 +31,7 @@ pub struct GeminiConfig {
 
 #[async_trait]
 impl Client for GeminiClient {
-    fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>) {
-        (&self.global_config, &self.config.extra)
-    }
+    client_common_fns!();
 
     async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
         let builder = self.request_builder(client, data)?;
@@ -61,8 +59,9 @@ impl GeminiClient {
         let client_name = Self::name(local_config);
         MODELS
             .into_iter()
-            .map(|(name, max_tokens)| {
+            .map(|(name, max_tokens, capabilities)| {
                 Model::new(client_name, name)
+                    .set_capabilities(capabilities.into())
                     .set_max_tokens(Some(max_tokens))
                     .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
             })
diff --git a/src/client/localai.rs b/src/client/localai.rs
index 9325e0f..3bc0670 100644
--- a/src/client/localai.rs
+++ b/src/client/localai.rs
@@ -1,5 +1,5 @@
 use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
-use super::{ExtraConfig, LocalAIClient, PromptType, SendData, Model};
+use super::{ExtraConfig, LocalAIClient, Model, ModelConfig, PromptType, SendData};
 
 use crate::utils::PromptKind;
 
@@ -14,16 +14,10 @@ pub struct LocalAIConfig {
     pub api_base: String,
     pub api_key: Option<String>,
     pub chat_endpoint: Option<String>,
-    pub models: Vec<LocalAIModel>,
+    pub models: Vec<ModelConfig>,
     pub extra: Option<ExtraConfig>,
 }
 
-#[derive(Debug, Clone, Deserialize)]
-pub struct LocalAIModel {
-    name: String,
-    max_tokens: Option<usize>,
-}
-
 openai_compatible_client!(LocalAIClient);
 
 impl LocalAIClient {
@@ -49,6 +43,7 @@ impl LocalAIClient {
             .iter()
             .map(|v| {
                 Model::new(client_name, &v.name)
+                    .set_capabilities(v.capabilities)
                     .set_max_tokens(v.max_tokens)
                     .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
             })
diff --git a/src/client/model.rs b/src/client/model.rs
index 130489d..88d3294 100644
--- a/src/client/model.rs
+++ b/src/client/model.rs
@@ -3,6 +3,7 @@ use super::message::{Message, MessageContent};
 use crate::utils::count_tokens;
 
 use anyhow::{bail, Result};
+use serde::{Deserialize, Deserializer};
 
 pub type TokensCountFactors = (usize, usize); // (per-messages, bias)
 
@@ -12,6 +13,7 @@ pub struct Model {
     pub name: String,
     pub max_tokens: Option<usize>,
     pub tokens_count_factors: TokensCountFactors,
+    pub capabilities: ModelCapabilities,
 }
 
 impl Default for Model {
@@ -27,6 +29,7 @@ impl Model {
             name: name.into(),
             max_tokens: None,
             tokens_count_factors: Default::default(),
+            capabilities: ModelCapabilities::Text,
         }
     }
 
@@ -65,6 +68,11 @@ impl Model {
         format!("{}:{}", self.client_name, self.name)
     }
 
+    pub fn set_capabilities(mut self, capabilities: ModelCapabilities) -> Self {
+        self.capabilities = capabilities;
+        self
+    }
+
     pub fn set_max_tokens(mut self, max_tokens: Option<usize>) -> Self {
         match max_tokens {
             None | Some(0) => self.max_tokens = None,
@@ -115,3 +123,46 @@ impl Model {
         Ok(())
     }
 }
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ModelConfig {
+    pub name: String,
+    pub max_tokens: Option<usize>,
+    #[serde(deserialize_with = "deserialize_capabilities")]
+    #[serde(default = "default_capabilities")]
+    pub capabilities: ModelCapabilities,
+}
+
+bitflags::bitflags! {
+    #[derive(Debug, Clone, Copy, PartialEq)]
+    pub struct ModelCapabilities: u32 {
+        const Text = 0b00000001;
+        const Vision = 0b00000010;
+    }
+}
+
+impl From<&str> for ModelCapabilities {
+    fn from(value: &str) -> Self {
+        let value = if value.is_empty() { "text" } else { value };
+        let mut output = ModelCapabilities::empty();
+        if value.contains("text") {
+            output |= ModelCapabilities::Text;
+        }
+        if value.contains("vision") {
+            output |= ModelCapabilities::Vision;
+        }
+        output
+    }
+}
+
+fn deserialize_capabilities<'de, D>(deserializer: D) -> Result<ModelCapabilities, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    let value: String = Deserialize::deserialize(deserializer)?;
+    Ok(value.as_str().into())
+}
+
+fn default_capabilities() -> ModelCapabilities {
+    ModelCapabilities::Text
+}
diff --git a/src/client/ollama.rs b/src/client/ollama.rs
index fc6e148..0705f2e 100644
--- a/src/client/ollama.rs
+++ b/src/client/ollama.rs
@@ -1,9 +1,9 @@
 use super::{
-    message::*, patch_system_message, Client, ExtraConfig, Model, OllamaClient, PromptType,
-    SendData, TokensCountFactors,
+    message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig, OllamaClient,
+    PromptType, SendData, TokensCountFactors,
 };
 
-use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind};
+use crate::{render::ReplyHandler, utils::PromptKind};
 
 use anyhow::{anyhow, bail, Result};
 use async_trait::async_trait;
@@ -20,21 +20,13 @@ pub struct OllamaConfig {
     pub api_base: String,
     pub api_key: Option<String>,
     pub chat_endpoint: Option<String>,
-    pub models: Vec<LocalAIModel>,
+    pub models: Vec<ModelConfig>,
     pub extra: Option<ExtraConfig>,
 }
 
-#[derive(Debug, Clone, Deserialize)]
-pub struct LocalAIModel {
-    name: String,
-    max_tokens: Option<usize>,
-}
-
 #[async_trait]
 impl Client for OllamaClient {
-    fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>) {
-        (&self.global_config, &self.config.extra)
-    }
+    client_common_fns!();
 
     async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
         let builder = self.request_builder(client, data)?;
@@ -75,6 +67,7 @@ impl OllamaClient {
             .iter()
             .map(|v| {
                 Model::new(client_name, &v.name)
+                    .set_capabilities(v.capabilities)
                     .set_max_tokens(v.max_tokens)
                     .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
             })
diff --git a/src/client/openai.rs b/src/client/openai.rs
index b57c175..2a480f3 100644
--- a/src/client/openai.rs
+++ b/src/client/openai.rs
@@ -1,12 +1,6 @@
-use super::{
-    ExtraConfig, OpenAIClient, PromptType, SendData,
-    Model, TokensCountFactors,
-};
+use super::{ExtraConfig, Model, OpenAIClient, PromptType, SendData, TokensCountFactors};
 
-use crate::{
-    render::ReplyHandler,
-    utils::PromptKind,
-};
+use crate::{render::ReplyHandler, utils::PromptKind};
 
 use anyhow::{anyhow, bail, Result};
 use async_trait::async_trait;
@@ -19,14 +13,14 @@ use std::env;
 
 const API_BASE: &str = "https://api.openai.com/v1";
 
-const MODELS: [(&str, usize); 7] = [
-    ("gpt-3.5-turbo", 4096),
-    ("gpt-3.5-turbo-16k", 16385),
-    ("gpt-3.5-turbo-1106", 16385),
-    ("gpt-4", 8192),
-    ("gpt-4-32k", 32768),
-    ("gpt-4-1106-preview", 128000),
-    ("gpt-4-vision-preview", 128000),
+const MODELS: [(&str, usize, &str); 7] = [
+    ("gpt-3.5-turbo", 4096, "text"),
+    ("gpt-3.5-turbo-16k", 16385, "text"),
+    ("gpt-3.5-turbo-1106", 16385, "text"),
+    ("gpt-4", 8192, "text"),
+    ("gpt-4-32k", 32768, "text"),
+    ("gpt-4-1106-preview", 128000, "text"),
+    ("gpt-4-vision-preview", 128000, "text,vision"),
 ];
 
 pub const OPENAI_TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
@@ -51,8 +45,9 @@ impl OpenAIClient {
         let client_name = Self::name(local_config);
         MODELS
             .into_iter()
-            .map(|(name, max_tokens)| {
+            .map(|(name, max_tokens, capabilities)| {
                 Model::new(client_name, name)
+                    .set_capabilities(capabilities.into())
                     .set_max_tokens(Some(max_tokens))
                     .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
             })
diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs
index 5b8755f..b9254ff 100644
--- a/src/client/qianwen.rs
+++ b/src/client/qianwen.rs
@@ -1,7 +1,6 @@
 use super::{message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, SendData};
 
 use crate::{
-    config::GlobalConfig,
     render::ReplyHandler,
     utils::{sha256sum, PromptKind},
 };
@@ -25,12 +24,12 @@ const API_URL: &str =
 const API_URL_VL: &str =
     "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation";
 
-const MODELS: [(&str, usize); 5] = [
-    ("qwen-turbo", 8192),
-    ("qwen-plus", 32768),
-    ("qwen-max", 8192),
-    ("qwen-max-longcontext", 30720),
-    ("qwen-vl-plus", 0),
+const MODELS: [(&str, usize, &str); 5] = [
+    ("qwen-turbo", 8192, "text"),
+    ("qwen-plus", 32768, "text"),
+    ("qwen-max", 8192, "text"),
+    ("qwen-max-longcontext", 30720, "text"),
+    ("qwen-vl-plus", 0, "text,vision"),
 ];
 
 #[derive(Debug, Clone, Deserialize, Default)]
@@ -42,9 +41,7 @@ pub struct QianwenConfig {
 
 #[async_trait]
 impl Client for QianwenClient {
-    fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>) {
-        (&self.global_config, &self.config.extra)
-    }
+    client_common_fns!();
 
     async fn send_message_inner(
         &self,
@@ -80,8 +77,10 @@ impl QianwenClient {
         let client_name = Self::name(local_config);
         MODELS
             .into_iter()
-            .map(|(name, max_tokens)| {
-                Model::new(client_name, name).set_max_tokens(Some(max_tokens))
+            .map(|(name, max_tokens, capabilities)| {
+                Model::new(client_name, name)
+                    .set_capabilities(capabilities.into())
+                    .set_max_tokens(Some(max_tokens))
             })
             .collect()
     }
diff --git a/src/config/input.rs b/src/config/input.rs
index fc29c49..c3e0de3 100644
--- a/src/config/input.rs
+++ b/src/config/input.rs
@@ -1,4 +1,4 @@
-use crate::client::{ImageUrl, MessageContent, MessageContentPart};
+use crate::client::{ImageUrl, MessageContent, MessageContentPart, ModelCapabilities};
 use crate::utils::sha256sum;
 
 use anyhow::{bail, Context, Result};
@@ -119,6 +119,14 @@ impl Input {
             MessageContent::Array(list)
         }
     }
+
+    pub fn required_capabilities(&self) -> ModelCapabilities {
+        if !self.medias.is_empty() {
+            ModelCapabilities::Vision
+        } else {
+            ModelCapabilities::Text
+        }
+    }
 }
 
 pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
diff --git a/src/main.rs b/src/main.rs
index 4d199aa..5d8315d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -14,7 +14,7 @@ use crate::config::{Config, GlobalConfig};
 use anyhow::Result;
 use clap::Parser;
-use client::{init_client, list_models};
+use client::{ensure_model_capabilities, init_client, list_models};
 use config::Input;
 use is_terminal::IsTerminal;
 use parking_lot::RwLock;
@@ -114,7 +114,8 @@ fn start_directive(
         session.guard_save()?;
     }
     let input = Input::new(text, include.unwrap_or_default())?;
-    let client = init_client(config)?;
+    let mut client = init_client(config)?;
+    ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
     config.read().maybe_print_send_tokens(&input);
     let output = if no_stream {
         let output = client.send_message(input.clone())?;
diff --git a/src/repl/mod.rs b/src/repl/mod.rs
index 719ef3a..832ffb4 100644
--- a/src/repl/mod.rs
+++ b/src/repl/mod.rs
@@ -6,7 +6,7 @@ use self::completer::ReplCompleter;
 use self::highlighter::ReplHighlighter;
 use self::prompt::ReplPrompt;
 
-use crate::client::init_client;
+use crate::client::{ensure_model_capabilities, init_client};
 use crate::config::{GlobalConfig, Input, State};
 use crate::render::{render_error, render_stream};
 use crate::utils::{create_abort_signal, set_text, AbortSignal};
@@ -268,7 +268,8 @@ impl Repl {
             Input::new(text, files)?
         };
         self.config.read().maybe_print_send_tokens(&input);
-        let client = init_client(&self.config)?;
+        let mut client = init_client(&self.config)?;
+        ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
         let output = render_stream(&input, client.as_ref(), &self.config, self.abort.clone())?;
         self.config.write().save_message(input, &output)?;
         if self.config.read().auto_copy {
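
For reference, the new `capabilities` field round-trips from the YAML shown in
config.example.yaml: serde reads the value as a plain string and the
`From<&str>` impl turns each recognized keyword into a flag. Below is a
minimal, self-contained sketch of that path; it assumes bitflags, serde, and
serde_yaml as dependencies, and trims ModelConfig to the fields this patch
touches.

    use bitflags::bitflags;
    use serde::{Deserialize, Deserializer};

    bitflags! {
        #[derive(Debug, Clone, Copy, PartialEq)]
        pub struct ModelCapabilities: u32 {
            const Text = 0b00000001;
            const Vision = 0b00000010;
        }
    }

    // Mirrors the patch: an empty string defaults to "text", and each
    // recognized keyword in the comma-separated list switches on its flag.
    impl From<&str> for ModelCapabilities {
        fn from(value: &str) -> Self {
            let value = if value.is_empty() { "text" } else { value };
            let mut output = ModelCapabilities::empty();
            if value.contains("text") {
                output |= ModelCapabilities::Text;
            }
            if value.contains("vision") {
                output |= ModelCapabilities::Vision;
            }
            output
        }
    }

    #[derive(Debug, Deserialize)]
    pub struct ModelConfig {
        pub name: String,
        pub max_tokens: Option<usize>,
        // A missing `capabilities` key falls back to Text, so existing
        // model entries in old configs keep deserializing unchanged.
        #[serde(deserialize_with = "deserialize_capabilities")]
        #[serde(default = "default_capabilities")]
        pub capabilities: ModelCapabilities,
    }

    fn deserialize_capabilities<'de, D>(deserializer: D) -> Result<ModelCapabilities, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value: String = Deserialize::deserialize(deserializer)?;
        Ok(value.as_str().into())
    }

    fn default_capabilities() -> ModelCapabilities {
        ModelCapabilities::Text
    }

    fn main() -> Result<(), serde_yaml::Error> {
        let with_vision: ModelConfig =
            serde_yaml::from_str("name: llava\nmax_tokens: 8192\ncapabilities: text,vision")?;
        assert!(with_vision
            .capabilities
            .contains(ModelCapabilities::Text | ModelCapabilities::Vision));

        let text_only: ModelConfig = serde_yaml::from_str("name: mistral\nmax_tokens: 8192")?;
        assert_eq!(text_only.capabilities, ModelCapabilities::Text);
        Ok(())
    }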
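The wiring in src/main.rs and src/repl/mod.rs then reduces to: derive the
needed flags from the input, and let ensure_model_capabilities either keep the
current model, switch to a sibling model that qualifies, or bail. The sketch
below restates that flow outside the register_client macro; the Client trait,
StubClient, and required_capabilities here are simplified stand-ins for the
crate's real types, with the vision check reduced to "has attached media".

    use anyhow::{bail, Result};
    use bitflags::bitflags;

    bitflags! {
        #[derive(Debug, Clone, Copy, PartialEq)]
        pub struct ModelCapabilities: u32 {
            const Text = 0b01;
            const Vision = 0b10;
        }
    }

    #[derive(Debug, Clone)]
    struct Model {
        name: &'static str,
        capabilities: ModelCapabilities,
    }

    // Stand-in for the crate's Client trait: just the three accessors that
    // ensure_model_capabilities relies on.
    trait Client {
        fn model(&self) -> &Model;
        fn models(&self) -> Vec<Model>;
        fn set_model(&mut self, model: Model);
    }

    // Same logic as the macro-generated ensure_model_capabilities: keep the
    // current model if it qualifies, otherwise switch to the first sibling
    // model that does, otherwise error out.
    fn ensure_model_capabilities(client: &mut dyn Client, needed: ModelCapabilities) -> Result<()> {
        if !client.model().capabilities.contains(needed) {
            let models = client.models();
            if let Some(model) = models.into_iter().find(|m| m.capabilities.contains(needed)) {
                client.set_model(model);
            } else {
                bail!("The current model lacks the corresponding capability.");
            }
        }
        Ok(())
    }

    // Stand-in for Input::required_capabilities: attached media implies vision.
    fn required_capabilities(medias: &[String]) -> ModelCapabilities {
        if !medias.is_empty() {
            ModelCapabilities::Vision
        } else {
            ModelCapabilities::Text
        }
    }

    struct StubClient {
        current: Model,
        available: Vec<Model>,
    }

    impl Client for StubClient {
        fn model(&self) -> &Model {
            &self.current
        }
        fn models(&self) -> Vec<Model> {
            self.available.clone()
        }
        fn set_model(&mut self, model: Model) {
            self.current = model;
        }
    }

    fn main() -> Result<()> {
        let text = Model { name: "mistral", capabilities: ModelCapabilities::Text };
        let vision = Model {
            name: "llava",
            capabilities: ModelCapabilities::Text | ModelCapabilities::Vision,
        };
        let mut client = StubClient { current: text.clone(), available: vec![text, vision] };

        // An input with an attached image forces a switch to the vision model.
        ensure_model_capabilities(&mut client, required_capabilities(&["cat.png".to_string()]))?;
        assert_eq!(client.model().name, "llava");

        // A plain text input leaves the (text-capable) model untouched.
        ensure_model_capabilities(&mut client, required_capabilities(&[]))?;
        assert_eq!(client.model().name, "llava");
        Ok(())
    }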