refactor: all clients use openai token counter (#402)

pull/403/head
sigoden 2 months ago committed by GitHub
parent 01ebc87348
commit 3b5843fe2e

@@ -1,4 +1,4 @@
-use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
+use super::openai::openai_build_body;
 use super::{AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptType, SendData};
 use crate::utils::PromptKind;
 
@@ -45,7 +45,6 @@ impl AzureOpenAIClient {
                 Model::new(client_name, &v.name)
                     .set_max_input_tokens(v.max_input_tokens)
                     .set_capabilities(v.capabilities)
-                    .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
@@ -54,7 +53,8 @@ impl AzureOpenAIClient {
         let api_base = self.get_api_base()?;
         let api_key = self.get_api_key()?;
 
-        let body = openai_build_body(data, self.model.name.clone());
+        let mut body = openai_build_body(data, self.model.name.clone());
+        self.model.merge_extra_fields(&mut body);
 
         let url = format!(
             "{}/openai/deployments/{}/chat/completions?api-version=2023-05-15",

@@ -1,7 +1,4 @@
-use super::{
-    patch_system_message, ClaudeClient, Client, ExtraConfig, Model, PromptType, SendData,
-    TokensCountFactors,
-};
+use super::{patch_system_message, ClaudeClient, Client, ExtraConfig, Model, PromptType, SendData};
 use crate::{
     client::{ImageUrl, MessageContent, MessageContentPart},
@@ -26,8 +23,6 @@ const MODELS: [(&str, usize, &str); 3] = [
     ("claude-3-haiku-20240307", 200000, "text,vision"),
 ];
 
-const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
-
 #[derive(Debug, Clone, Deserialize)]
 pub struct ClaudeConfig {
     pub name: Option<String>,
@@ -69,7 +64,6 @@ impl ClaudeClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
             })
             .collect()
     }

@@ -1,6 +1,6 @@
 use super::{
     json_stream, message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model,
-    PromptType, SendData, TokensCountFactors,
+    PromptType, SendData,
 };
 use crate::{render::ReplyHandler, utils::PromptKind};
@@ -19,8 +19,6 @@ const MODELS: [(&str, usize, &str); 2] = [
     ("command-r-plus", 128000, "text"),
 ];
 
-const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
-
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct CohereConfig {
     pub name: Option<String>,
@@ -62,7 +60,6 @@ impl CohereClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
@@ -70,8 +67,7 @@ impl CohereClient {
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();
 
-        let mut body = build_body(data, self.model.name.clone())?;
-        self.model.merge_extra_fields(&mut body);
+        let body = build_body(data, self.model.name.clone())?;
 
         let url = API_URL;

@@ -1,5 +1,5 @@
 use super::vertexai::{build_body, send_message, send_message_streaming};
-use super::{Client, ExtraConfig, GeminiClient, Model, PromptType, SendData, TokensCountFactors};
+use super::{Client, ExtraConfig, GeminiClient, Model, PromptType, SendData};
 use crate::{render::ReplyHandler, utils::PromptKind};
@@ -17,8 +17,6 @@ const MODELS: [(&str, usize, &str); 3] = [
     ("gemini-1.5-pro-latest", 1048576, "text,vision"),
 ];
 
-const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
-
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct GeminiConfig {
     pub name: Option<String>,
@@ -61,7 +59,6 @@ impl GeminiClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
             })
             .collect()
     }

@@ -1,4 +1,4 @@
-use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
+use super::openai::openai_build_body;
 use super::{ExtraConfig, MistralClient, Model, PromptType, SendData};
 use crate::utils::PromptKind;
 
@@ -44,7 +44,6 @@ impl MistralClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
@@ -52,8 +51,7 @@ impl MistralClient {
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();
 
-        let mut body = openai_build_body(data, self.model.name.clone());
-        self.model.merge_extra_fields(&mut body);
+        let body = openai_build_body(data, self.model.name.clone());
 
         let url = API_URL;

@@ -5,7 +5,8 @@ use crate::utils::count_tokens;
 use anyhow::{bail, Result};
 use serde::{Deserialize, Deserializer};
 
-pub type TokensCountFactors = (usize, usize); // (per-messages, bias)
+const PER_MESSAGES_TOKENS: usize = 5;
+const BASIS_TOKENS: usize = 2;
 
 #[derive(Debug, Clone)]
 pub struct Model {
@@ -13,7 +14,6 @@ pub struct Model {
     pub name: String,
     pub max_input_tokens: Option<usize>,
     pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
-    pub tokens_count_factors: TokensCountFactors,
     pub capabilities: ModelCapabilities,
 }
@@ -30,7 +30,6 @@ impl Model {
             name: name.into(),
             extra_fields: None,
             max_input_tokens: None,
-            tokens_count_factors: Default::default(),
             capabilities: ModelCapabilities::Text,
         }
     }
@@ -91,11 +90,6 @@ impl Model {
         self
     }
 
-    pub fn set_tokens_count_factors(mut self, tokens_count_factors: TokensCountFactors) -> Self {
-        self.tokens_count_factors = tokens_count_factors;
-        self
-    }
-
     pub fn messages_tokens(&self, messages: &[Message]) -> usize {
         messages
             .iter()
@@ -114,17 +108,15 @@
         }
         let num_messages = messages.len();
         let message_tokens = self.messages_tokens(messages);
-        let (per_messages, _) = self.tokens_count_factors;
         if messages[num_messages - 1].role.is_user() {
-            num_messages * per_messages + message_tokens
+            num_messages * PER_MESSAGES_TOKENS + message_tokens
         } else {
-            (num_messages - 1) * per_messages + message_tokens
+            (num_messages - 1) * PER_MESSAGES_TOKENS + message_tokens
         }
     }
 
     pub fn max_input_tokens_limit(&self, messages: &[Message]) -> Result<()> {
-        let (_, bias) = self.tokens_count_factors;
-        let total_tokens = self.total_tokens(messages) + bias;
+        let total_tokens = self.total_tokens(messages) + BASIS_TOKENS;
         if let Some(max_input_tokens) = self.max_input_tokens {
             if total_tokens >= max_input_tokens {
                 bail!("Exceed max input tokens limit")
@@ -1,4 +1,4 @@
-use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
+use super::openai::openai_build_body;
 use super::{ExtraConfig, MoonshotClient, Model, PromptType, SendData};
 use crate::utils::PromptKind;
 
@@ -42,7 +42,6 @@ impl MoonshotClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
@@ -50,8 +49,7 @@ impl MoonshotClient {
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();
 
-        let mut body = openai_build_body(data, self.model.name.clone());
-        self.model.merge_extra_fields(&mut body);
+        let body = openai_build_body(data, self.model.name.clone());
 
        let url = API_URL;

@@ -1,6 +1,6 @@
 use super::{
     message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig, OllamaClient,
-    PromptType, SendData, TokensCountFactors,
+    PromptType, SendData,
 };
 use crate::{render::ReplyHandler, utils::PromptKind};
@@ -12,8 +12,6 @@ use reqwest::{Client as ReqwestClient, RequestBuilder};
 use serde::Deserialize;
 use serde_json::{json, Value};
 
-const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
-
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct OllamaConfig {
     pub name: Option<String>,
@@ -70,7 +68,6 @@ impl OllamaClient {
                     .set_capabilities(v.capabilities)
                     .set_max_input_tokens(v.max_input_tokens)
                     .set_extra_fields(v.extra_fields.clone())
-                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
@@ -79,7 +76,6 @@ impl OllamaClient {
         let api_key = self.get_api_key().ok();
         let mut body = build_body(data, self.model.name.clone())?;
         self.model.merge_extra_fields(&mut body);
-
         let chat_endpoint = self.config.chat_endpoint.as_deref().unwrap_or("/api/chat");

@@ -1,4 +1,4 @@
-use super::{ExtraConfig, Model, OpenAIClient, PromptType, SendData, TokensCountFactors};
+use super::{ExtraConfig, Model, OpenAIClient, PromptType, SendData};
 use crate::{render::ReplyHandler, utils::PromptKind};
 
@@ -24,8 +24,6 @@ const MODELS: [(&str, usize, &str); 8] = [
     ("gpt-4-32k", 32768, "text"),
 ];
 
-pub const OPENAI_TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
-
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct OpenAIConfig {
     pub name: Option<String>,
@@ -52,7 +50,6 @@ impl OpenAIClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
             })
             .collect()
     }

@@ -1,4 +1,4 @@
-use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
+use super::openai::openai_build_body;
 use super::{ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptType, SendData};
 use crate::utils::PromptKind;
 
@@ -47,7 +47,6 @@ impl OpenAICompatibleClient {
                     .set_capabilities(v.capabilities)
                     .set_max_input_tokens(v.max_input_tokens)
                     .set_extra_fields(v.extra_fields.clone())
-                    .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
             })
             .collect()
     }

@@ -1,5 +1,5 @@
 use super::{
-    message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, SendData, TokensCountFactors,
+    message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, SendData,
 };
 use crate::{
@@ -37,8 +37,6 @@ const MODELS: [(&str, usize, &str); 6] = [
     ("qwen-vl-max", 0, "text,vision"),
 ];
 
-const TOKENS_COUNT_FACTORS: TokensCountFactors = (4, 14);
-
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct QianwenConfig {
     pub name: Option<String>,
@@ -88,7 +86,6 @@ impl QianwenClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
             })
             .collect()
     }

@@ -1,6 +1,6 @@
 use super::{
     json_stream, message::*, patch_system_message, Client, ExtraConfig, Model, PromptType,
-    SendData, TokensCountFactors, VertexAIClient,
+    SendData, VertexAIClient,
 };
 use crate::{render::ReplyHandler, utils::PromptKind};
@@ -20,8 +20,6 @@ const MODELS: [(&str, usize, &str); 3] = [
     ("gemini-1.5-pro-preview-0409", 1000000, "text,vision"),
 ];
 
-const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
-
 static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation
 
 #[derive(Debug, Clone, Deserialize, Default)]
@@ -69,7 +67,6 @@ impl VertexAIClient {
                 Model::new(client_name, name)
                     .set_capabilities(capabilities.into())
                     .set_max_input_tokens(Some(max_input_tokens))
-                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
             })
             .collect()
     }
