diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs
index e553f3f..3d0cf3f 100644
--- a/src/client/azure_openai.rs
+++ b/src/client/azure_openai.rs
@@ -1,4 +1,4 @@
-use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
+use super::openai::openai_build_body;
 use super::{AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptType, SendData};
 
 use crate::utils::PromptKind;
@@ -45,7 +45,6 @@ impl AzureOpenAIClient {
                 Model::new(client_name, &v.name)
                     .set_max_input_tokens(v.max_input_tokens)
                     .set_capabilities(v.capabilities)
-                    .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
@@ -54,7 +53,8 @@ impl AzureOpenAIClient {
         let api_base = self.get_api_base()?;
         let api_key = self.get_api_key()?;
 
-        let body = openai_build_body(data, self.model.name.clone());
+        let mut body = openai_build_body(data, self.model.name.clone());
+        self.model.merge_extra_fields(&mut body);
 
         let url = format!(
             "{}/openai/deployments/{}/chat/completions?api-version=2023-05-15",
diff --git a/src/client/claude.rs b/src/client/claude.rs
index 1eb83d1..0acbf72 100644
--- a/src/client/claude.rs
+++ b/src/client/claude.rs
@@ -1,7 +1,4 @@
-use super::{
-    patch_system_message, ClaudeClient, Client, ExtraConfig, Model, PromptType, SendData,
-    TokensCountFactors,
-};
+use super::{patch_system_message, ClaudeClient, Client, ExtraConfig, Model, PromptType, SendData};
 
 use crate::{
     client::{ImageUrl, MessageContent, MessageContentPart},
@@ -26,8 +23,6 @@ const MODELS: [(&str, usize, &str); 3] = [
     ("claude-3-haiku-20240307", 200000, "text,vision"),
 ];
 
-const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
-
 #[derive(Debug, Clone, Deserialize)]
 pub struct ClaudeConfig {
     pub name: Option<String>,
@@ -69,7 +64,6 @@ impl ClaudeClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
diff --git a/src/client/cohere.rs b/src/client/cohere.rs
index a93dc71..6f2e288 100644
--- a/src/client/cohere.rs
+++ b/src/client/cohere.rs
@@ -1,6 +1,6 @@
 use super::{
     json_stream, message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model,
-    PromptType, SendData, TokensCountFactors,
+    PromptType, SendData,
 };
 
 use crate::{render::ReplyHandler, utils::PromptKind};
@@ -19,8 +19,6 @@ const MODELS: [(&str, usize, &str); 2] = [
     ("command-r-plus", 128000, "text"),
 ];
 
-const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
-
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct CohereConfig {
     pub name: Option<String>,
@@ -62,7 +60,6 @@ impl CohereClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
@@ -70,8 +67,7 @@ impl CohereClient {
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();
 
-        let mut body = build_body(data, self.model.name.clone())?;
-        self.model.merge_extra_fields(&mut body);
+        let body = build_body(data, self.model.name.clone())?;
 
         let url = API_URL;
diff --git a/src/client/gemini.rs b/src/client/gemini.rs
index 19e656a..bcd10e3 100644
--- a/src/client/gemini.rs
+++ b/src/client/gemini.rs
@@ -1,5 +1,5 @@
 use super::vertexai::{build_body, send_message, send_message_streaming};
-use super::{Client, ExtraConfig, GeminiClient, Model, PromptType, SendData, TokensCountFactors};
+use super::{Client, ExtraConfig, GeminiClient, Model, PromptType, SendData};
 
 use crate::{render::ReplyHandler, utils::PromptKind};
@@ -17,8 +17,6 @@ const MODELS: [(&str, usize, &str); 3] = [
     ("gemini-1.5-pro-latest", 1048576, "text,vision"),
 ];
 
-const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
-
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct GeminiConfig {
     pub name: Option<String>,
@@ -61,7 +59,6 @@ impl GeminiClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
diff --git a/src/client/mistral.rs b/src/client/mistral.rs
index 0ebb631..611c25c 100644
--- a/src/client/mistral.rs
+++ b/src/client/mistral.rs
@@ -1,4 +1,4 @@
-use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
+use super::openai::openai_build_body;
 use super::{ExtraConfig, MistralClient, Model, PromptType, SendData};
 
 use crate::utils::PromptKind;
@@ -44,7 +44,6 @@ impl MistralClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
@@ -52,8 +51,7 @@ impl MistralClient {
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();
 
-        let mut body = openai_build_body(data, self.model.name.clone());
-        self.model.merge_extra_fields(&mut body);
+        let body = openai_build_body(data, self.model.name.clone());
 
         let url = API_URL;
diff --git a/src/client/model.rs b/src/client/model.rs
index ce181a5..03877b9 100644
--- a/src/client/model.rs
+++ b/src/client/model.rs
@@ -5,7 +5,8 @@ use crate::utils::count_tokens;
 use anyhow::{bail, Result};
 use serde::{Deserialize, Deserializer};
 
-pub type TokensCountFactors = (usize, usize); // (per-messages, bias)
+const PER_MESSAGES_TOKENS: usize = 5;
+const BASIS_TOKENS: usize = 2;
 
 #[derive(Debug, Clone)]
 pub struct Model {
@@ -13,7 +14,6 @@ pub struct Model {
     pub name: String,
     pub max_input_tokens: Option<usize>,
     pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
-    pub tokens_count_factors: TokensCountFactors,
     pub capabilities: ModelCapabilities,
 }
@@ -30,7 +30,6 @@ impl Model {
             name: name.into(),
             extra_fields: None,
             max_input_tokens: None,
-            tokens_count_factors: Default::default(),
             capabilities: ModelCapabilities::Text,
         }
     }
@@ -91,11 +90,6 @@ impl Model {
         self
     }
 
-    pub fn set_tokens_count_factors(mut self, tokens_count_factors: TokensCountFactors) -> Self {
-        self.tokens_count_factors = tokens_count_factors;
-        self
-    }
-
     pub fn messages_tokens(&self, messages: &[Message]) -> usize {
         messages
             .iter()
@@ -114,17 +108,15 @@ impl Model {
         }
         let num_messages = messages.len();
         let message_tokens = self.messages_tokens(messages);
-        let (per_messages, _) = self.tokens_count_factors;
         if messages[num_messages - 1].role.is_user() {
-            num_messages * per_messages + message_tokens
+            num_messages * PER_MESSAGES_TOKENS + message_tokens
         } else {
-            (num_messages - 1) * per_messages + message_tokens
+            (num_messages - 1) * PER_MESSAGES_TOKENS + message_tokens
         }
     }
 
     pub fn max_input_tokens_limit(&self, messages: &[Message]) -> Result<()> {
-        let (_, bias) = self.tokens_count_factors;
-        let total_tokens = self.total_tokens(messages) + bias;
+        let total_tokens = self.total_tokens(messages) + BASIS_TOKENS;
         if let Some(max_input_tokens) = self.max_input_tokens {
             if total_tokens >= max_input_tokens {
                 bail!("Exceed max input tokens limit")
diff --git a/src/client/moonshot.rs b/src/client/moonshot.rs
index a105e01..18cf0f2 100644
--- a/src/client/moonshot.rs
+++ b/src/client/moonshot.rs
@@ -1,4 +1,4 @@
-use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
+use super::openai::openai_build_body;
 use super::{ExtraConfig, MoonshotClient, Model, PromptType, SendData};
 
 use crate::utils::PromptKind;
@@ -42,7 +42,6 @@ impl MoonshotClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
@@ -50,8 +49,7 @@ impl MoonshotClient {
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();
 
-        let mut body = openai_build_body(data, self.model.name.clone());
-        self.model.merge_extra_fields(&mut body);
+        let body = openai_build_body(data, self.model.name.clone());
 
         let url = API_URL;
diff --git a/src/client/ollama.rs b/src/client/ollama.rs
index 5ecf6c8..2c51f44 100644
--- a/src/client/ollama.rs
+++ b/src/client/ollama.rs
@@ -1,6 +1,6 @@
 use super::{
     message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig, OllamaClient,
-    PromptType, SendData, TokensCountFactors,
+    PromptType, SendData,
 };
 
 use crate::{render::ReplyHandler, utils::PromptKind};
@@ -12,8 +12,6 @@ use reqwest::{Client as ReqwestClient, RequestBuilder};
 use serde::Deserialize;
 use serde_json::{json, Value};
 
-const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
-
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct OllamaConfig {
     pub name: Option<String>,
@@ -70,7 +68,6 @@ impl OllamaClient {
                     .set_capabilities(v.capabilities)
                     .set_max_input_tokens(v.max_input_tokens)
                     .set_extra_fields(v.extra_fields.clone())
-                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
@@ -79,7 +76,6 @@ impl OllamaClient {
         let api_key = self.get_api_key().ok();
 
         let mut body = build_body(data, self.model.name.clone())?;
-        self.model.merge_extra_fields(&mut body);
 
         let chat_endpoint = self.config.chat_endpoint.as_deref().unwrap_or("/api/chat");
diff --git a/src/client/openai.rs b/src/client/openai.rs
index 56cc53c..c996ed0 100644
--- a/src/client/openai.rs
+++ b/src/client/openai.rs
@@ -1,4 +1,4 @@
-use super::{ExtraConfig, Model, OpenAIClient, PromptType, SendData, TokensCountFactors};
+use super::{ExtraConfig, Model, OpenAIClient, PromptType, SendData};
 
 use crate::{render::ReplyHandler, utils::PromptKind};
@@ -24,8 +24,6 @@ const MODELS: [(&str, usize, &str); 8] = [
     ("gpt-4-32k", 32768, "text"),
 ];
 
-pub const OPENAI_TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
-
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct OpenAIConfig {
     pub name: Option<String>,
@@ -52,7 +50,6 @@ impl OpenAIClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
diff --git a/src/client/openai_compatible.rs b/src/client/openai_compatible.rs
index ec36623..595b2cd 100644
--- a/src/client/openai_compatible.rs
+++ b/src/client/openai_compatible.rs
@@ -1,4 +1,4 @@
-use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
+use super::openai::openai_build_body;
 use super::{ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptType, SendData};
 
 use crate::utils::PromptKind;
@@ -47,7 +46,6 @@ impl OpenAICompatibleClient {
                     .set_capabilities(v.capabilities)
                     .set_max_input_tokens(v.max_input_tokens)
                     .set_extra_fields(v.extra_fields.clone())
-                    .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs
index 3e69dae..031abe7 100644
--- a/src/client/qianwen.rs
+++ b/src/client/qianwen.rs
@@ -1,5 +1,5 @@
 use super::{
-    message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, SendData, TokensCountFactors,
+    message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, SendData,
 };
 
 use crate::{
@@ -37,8 +35,6 @@ const MODELS: [(&str, usize, &str); 6] = [
     ("qwen-vl-max", 0, "text,vision"),
 ];
 
-const TOKENS_COUNT_FACTORS: TokensCountFactors = (4, 14);
-
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct QianwenConfig {
     pub name: Option<String>,
@@ -88,7 +86,6 @@ impl QianwenClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs
index babbd23..ad39c16 100644
--- a/src/client/vertexai.rs
+++ b/src/client/vertexai.rs
@@ -1,6 +1,6 @@
 use super::{
     json_stream, message::*, patch_system_message, Client, ExtraConfig, Model, PromptType,
-    SendData, TokensCountFactors, VertexAIClient,
+    SendData, VertexAIClient,
 };
 
 use crate::{render::ReplyHandler, utils::PromptKind};
@@ -20,8 +20,6 @@ const MODELS: [(&str, usize, &str); 3] = [
     ("gemini-1.5-pro-preview-0409", 1000000, "text,vision"),
 ];
 
-const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
-
 static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation
 
 #[derive(Debug, Clone, Deserialize, Default)]
@@ -69,7 +67,6 @@ impl VertexAIClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
             })
            .collect()
    }
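For reference, here is a minimal, self-contained sketch of what the token accounting looks like after this refactor, with the per-message and basis overheads hard-coded as constants instead of per-model `tokens_count_factors`. The `Message`, `Role`, and `count_tokens` definitions below are simplified stand-ins for the crate's own types, so this illustrates the arithmetic in `model.rs` rather than the actual implementation.

```rust
// Sketch of the post-refactor token accounting (hard-coded factors).
// `Message`, `Role`, and `count_tokens` are simplified stand-ins, not the crate's real types.

const PER_MESSAGES_TOKENS: usize = 5;
const BASIS_TOKENS: usize = 2;

enum Role {
    User,
    Assistant,
}

struct Message {
    role: Role,
    content: String,
}

// Stand-in for crate::utils::count_tokens: a rough whitespace split, for illustration only.
fn count_tokens(text: &str) -> usize {
    text.split_whitespace().count()
}

fn total_tokens(messages: &[Message]) -> usize {
    let message_tokens: usize = messages.iter().map(|m| count_tokens(&m.content)).sum();
    // Each message carries a fixed per-message overhead; when the conversation
    // ends with an assistant message, that last message is excluded from the overhead.
    match messages.last() {
        Some(m) if matches!(m.role, Role::User) => {
            messages.len() * PER_MESSAGES_TOKENS + message_tokens
        }
        Some(_) => (messages.len() - 1) * PER_MESSAGES_TOKENS + message_tokens,
        None => 0,
    }
}

fn exceeds_limit(messages: &[Message], max_input_tokens: usize) -> bool {
    // BASIS_TOKENS is the small fixed bias added before comparing against the limit.
    total_tokens(messages) + BASIS_TOKENS >= max_input_tokens
}

fn main() {
    let messages = vec![
        Message { role: Role::User, content: "Explain borrowing in Rust".into() },
        Message { role: Role::Assistant, content: "Borrowing lets you reference data".into() },
        Message { role: Role::User, content: "Give an example".into() },
    ];
    println!("total tokens: {}", total_tokens(&messages));
    println!("over 4096 limit: {}", exceeds_limit(&messages, 4096));
}
```

Since the factors were identical `(5, 2)` for every provider except Qianwen, folding them into two constants removes a per-model setter from every client at the cost of a slightly coarser estimate for Qianwen.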