mirror of
https://github.com/sigoden/aichat
synced 2024-11-18 09:28:27 +00:00
refactor: optimize counting tokens (#53)
This commit is contained in:
parent
9767c07eee
commit
553d0fe55b
@ -1,9 +1,7 @@
|
|||||||
use super::message::{Message, MessageRole, MESSAGE_EXTRA_TOKENS};
|
use super::message::{num_tokens_from_messages, Message, MessageRole};
|
||||||
use super::role::Role;
|
use super::role::Role;
|
||||||
use super::MAX_TOKENS;
|
use super::MAX_TOKENS;
|
||||||
|
|
||||||
use crate::utils::count_tokens;
|
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
@ -16,21 +14,17 @@ pub struct Conversation {
|
|||||||
|
|
||||||
impl Conversation {
|
impl Conversation {
|
||||||
pub fn new(role: Option<Role>) -> Self {
|
pub fn new(role: Option<Role>) -> Self {
|
||||||
let tokens = if let Some(role) = role.as_ref() {
|
let mut value = Self {
|
||||||
role.consume_tokens()
|
tokens: 0,
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
Self {
|
|
||||||
tokens,
|
|
||||||
role,
|
role,
|
||||||
messages: vec![],
|
messages: vec![],
|
||||||
}
|
};
|
||||||
|
value.tokens = num_tokens_from_messages(&value.build_emssages(""));
|
||||||
|
value
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> {
|
pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> {
|
||||||
let mut need_add_msg = true;
|
let mut need_add_msg = true;
|
||||||
let mut input_tokens = count_tokens(input);
|
|
||||||
if self.messages.is_empty() {
|
if self.messages.is_empty() {
|
||||||
if let Some(role) = self.role.as_ref() {
|
if let Some(role) = self.role.as_ref() {
|
||||||
self.messages.extend(role.build_emssages(input));
|
self.messages.extend(role.build_emssages(input));
|
||||||
@ -42,13 +36,12 @@ impl Conversation {
|
|||||||
role: MessageRole::User,
|
role: MessageRole::User,
|
||||||
content: input.to_string(),
|
content: input.to_string(),
|
||||||
});
|
});
|
||||||
input_tokens += MESSAGE_EXTRA_TOKENS;
|
|
||||||
}
|
}
|
||||||
self.messages.push(Message {
|
self.messages.push(Message {
|
||||||
role: MessageRole::Assistant,
|
role: MessageRole::Assistant,
|
||||||
content: output.to_string(),
|
content: output.to_string(),
|
||||||
});
|
});
|
||||||
self.tokens += input_tokens + count_tokens(output) + MESSAGE_EXTRA_TOKENS;
|
self.tokens = num_tokens_from_messages(&self.messages);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub const MESSAGE_EXTRA_TOKENS: usize = 6;
|
use crate::utils::count_tokens;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
@ -25,6 +25,21 @@ pub enum MessageRole {
|
|||||||
User,
|
User,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn num_tokens_from_messages(messages: &[Message]) -> usize {
|
||||||
|
let mut num_tokens = 0;
|
||||||
|
for message in messages.iter() {
|
||||||
|
num_tokens += 4;
|
||||||
|
num_tokens += count_tokens(&message.content);
|
||||||
|
num_tokens += 1; // role always take 1 token
|
||||||
|
}
|
||||||
|
num_tokens += 2;
|
||||||
|
num_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_serde() {
|
fn test_serde() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -32,3 +47,4 @@ fn test_serde() {
|
|||||||
"{\"role\":\"user\",\"content\":\"Hello World\"}"
|
"{\"role\":\"user\",\"content\":\"Hello World\"}"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
@ -2,11 +2,11 @@ mod conversation;
|
|||||||
mod message;
|
mod message;
|
||||||
mod role;
|
mod role;
|
||||||
|
|
||||||
use self::conversation::Conversation;
|
use self::message::Message;
|
||||||
use self::message::{Message, MESSAGE_EXTRA_TOKENS};
|
|
||||||
use self::role::Role;
|
use self::role::Role;
|
||||||
|
use self::{conversation::Conversation, message::num_tokens_from_messages};
|
||||||
|
|
||||||
use crate::utils::{count_tokens, now};
|
use crate::utils::now;
|
||||||
|
|
||||||
use anyhow::{anyhow, bail, Context, Result};
|
use anyhow::{anyhow, bail, Context, Result};
|
||||||
use inquire::{Confirm, Text};
|
use inquire::{Confirm, Text};
|
||||||
@ -166,8 +166,7 @@ impl Config {
|
|||||||
bail!("")
|
bail!("")
|
||||||
}
|
}
|
||||||
match self.find_role(name) {
|
match self.find_role(name) {
|
||||||
Some(mut role) => {
|
Some(role) => {
|
||||||
role.tokens = role.consume_tokens();
|
|
||||||
let output =
|
let output =
|
||||||
serde_yaml::to_string(&role).unwrap_or("Unable to echo role details".into());
|
serde_yaml::to_string(&role).unwrap_or("Unable to echo role details".into());
|
||||||
self.role = Some(role);
|
self.role = Some(role);
|
||||||
@ -201,24 +200,19 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> {
|
pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> {
|
||||||
let content_tokens = count_tokens(content);
|
|
||||||
let check_tokens = |tokens| {
|
|
||||||
if tokens >= MAX_TOKENS {
|
|
||||||
bail!("Exceed max tokens limit")
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
};
|
|
||||||
let messages = if let Some(conversation) = self.conversation.as_ref() {
|
let messages = if let Some(conversation) = self.conversation.as_ref() {
|
||||||
check_tokens(content_tokens + conversation.tokens)?;
|
|
||||||
conversation.build_emssages(content)
|
conversation.build_emssages(content)
|
||||||
} else if let Some(role) = self.role.as_ref() {
|
} else if let Some(role) = self.role.as_ref() {
|
||||||
check_tokens(content_tokens + role.tokens)?;
|
|
||||||
role.build_emssages(content)
|
role.build_emssages(content)
|
||||||
} else {
|
} else {
|
||||||
let message = Message::new(content);
|
let message = Message::new(content);
|
||||||
check_tokens(content_tokens + MESSAGE_EXTRA_TOKENS)?;
|
|
||||||
vec![message]
|
vec![message]
|
||||||
};
|
};
|
||||||
|
let tokens = num_tokens_from_messages(&messages);
|
||||||
|
if tokens >= MAX_TOKENS {
|
||||||
|
bail!("Exceed max tokens limit")
|
||||||
|
}
|
||||||
|
|
||||||
Ok(messages)
|
Ok(messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,12 +1,9 @@
|
|||||||
use super::message::{Message, MessageRole, MESSAGE_EXTRA_TOKENS};
|
use super::message::{Message, MessageRole};
|
||||||
|
|
||||||
use crate::utils::count_tokens;
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
const TEMP_NAME: &str = "P";
|
const TEMP_NAME: &str = "P";
|
||||||
const INPUT_PLACEHOLDER: &str = "__INPUT__";
|
const INPUT_PLACEHOLDER: &str = "__INPUT__";
|
||||||
const INPUT_PLACEHOLDER_TOKENS: usize = 3;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct Role {
|
pub struct Role {
|
||||||
@ -19,37 +16,21 @@ pub struct Role {
|
|||||||
pub prompt: String,
|
pub prompt: String,
|
||||||
/// What sampling temperature to use, between 0 and 2
|
/// What sampling temperature to use, between 0 and 2
|
||||||
pub temperature: Option<f64>,
|
pub temperature: Option<f64>,
|
||||||
/// Number of tokens
|
|
||||||
///
|
|
||||||
/// System prompt consume extra 6 tokens
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub tokens: usize,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Role {
|
impl Role {
|
||||||
pub fn new(prompt: &str, temperature: Option<f64>) -> Self {
|
pub fn new(prompt: &str, temperature: Option<f64>) -> Self {
|
||||||
let mut value = Self {
|
Self {
|
||||||
name: TEMP_NAME.into(),
|
name: TEMP_NAME.into(),
|
||||||
prompt: prompt.into(),
|
prompt: prompt.into(),
|
||||||
temperature,
|
temperature,
|
||||||
tokens: 0,
|
}
|
||||||
};
|
|
||||||
value.tokens = value.consume_tokens();
|
|
||||||
value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_temp(&self) -> bool {
|
pub fn is_temp(&self) -> bool {
|
||||||
self.name == TEMP_NAME
|
self.name == TEMP_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn consume_tokens(&self) -> usize {
|
|
||||||
if self.embeded() {
|
|
||||||
count_tokens(&self.prompt) + MESSAGE_EXTRA_TOKENS - INPUT_PLACEHOLDER_TOKENS
|
|
||||||
} else {
|
|
||||||
count_tokens(&self.prompt) + 2 * MESSAGE_EXTRA_TOKENS
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn embeded(&self) -> bool {
|
pub fn embeded(&self) -> bool {
|
||||||
self.prompt.contains(INPUT_PLACEHOLDER)
|
self.prompt.contains(INPUT_PLACEHOLDER)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user