|
|
|
@ -5,7 +5,8 @@ use crate::utils::count_tokens;
|
|
|
|
|
use anyhow::{bail, Result};
|
|
|
|
|
use serde::{Deserialize, Deserializer};
|
|
|
|
|
|
|
|
|
|
/// Per-model token-counting adjustments, stored as a `(per-message, bias)` pair.
///
/// The first element is multiplied by the number of messages when totalling
/// tokens; the second is added to that total before it is compared against
/// the model's maximum input tokens.
pub type TokensCountFactors = (usize, usize); // (per-messages, bias)
|
|
|
|
|
/// Number of tokens counted for each message on top of the tokens of the
/// message content itself when totalling a conversation.
const PER_MESSAGES_TOKENS: usize = 5;
|
|
|
|
|
/// Fixed number of tokens added to the conversation total before that total
/// is checked against the model's `max_input_tokens`.
const BASIS_TOKENS: usize = 2;
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone)]
|
|
|
|
|
pub struct Model {
|
|
|
|
@ -13,7 +14,6 @@ pub struct Model {
|
|
|
|
|
pub name: String,
|
|
|
|
|
pub max_input_tokens: Option<usize>,
|
|
|
|
|
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
|
|
|
|
|
pub tokens_count_factors: TokensCountFactors,
|
|
|
|
|
pub capabilities: ModelCapabilities,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -30,7 +30,6 @@ impl Model {
|
|
|
|
|
name: name.into(),
|
|
|
|
|
extra_fields: None,
|
|
|
|
|
max_input_tokens: None,
|
|
|
|
|
tokens_count_factors: Default::default(),
|
|
|
|
|
capabilities: ModelCapabilities::Text,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -91,11 +90,6 @@ impl Model {
|
|
|
|
|
self
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Consumes the model and returns it with `tokens_count_factors` replaced by
/// the supplied pair, so the call can be chained builder-style.
pub fn set_tokens_count_factors(self, tokens_count_factors: TokensCountFactors) -> Self {
    // Rebind the consumed value mutably instead of declaring `mut self`.
    let mut updated = self;
    updated.tokens_count_factors = tokens_count_factors;
    updated
}
|
|
|
|
|
|
|
|
|
|
pub fn messages_tokens(&self, messages: &[Message]) -> usize {
|
|
|
|
|
messages
|
|
|
|
|
.iter()
|
|
|
|
@ -114,17 +108,15 @@ impl Model {
|
|
|
|
|
}
|
|
|
|
|
let num_messages = messages.len();
|
|
|
|
|
let message_tokens = self.messages_tokens(messages);
|
|
|
|
|
let (per_messages, _) = self.tokens_count_factors;
|
|
|
|
|
if messages[num_messages - 1].role.is_user() {
|
|
|
|
|
num_messages * per_messages + message_tokens
|
|
|
|
|
num_messages * PER_MESSAGES_TOKENS + message_tokens
|
|
|
|
|
} else {
|
|
|
|
|
(num_messages - 1) * per_messages + message_tokens
|
|
|
|
|
(num_messages - 1) * PER_MESSAGES_TOKENS + message_tokens
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn max_input_tokens_limit(&self, messages: &[Message]) -> Result<()> {
|
|
|
|
|
let (_, bias) = self.tokens_count_factors;
|
|
|
|
|
let total_tokens = self.total_tokens(messages) + bias;
|
|
|
|
|
let total_tokens = self.total_tokens(messages) + BASIS_TOKENS;
|
|
|
|
|
if let Some(max_input_tokens) = self.max_input_tokens {
|
|
|
|
|
if total_tokens >= max_input_tokens {
|
|
|
|
|
bail!("Exceed max input tokens limit")
|
|
|
|
|