mirror of
https://github.com/sigoden/aichat
synced 2024-11-18 09:28:27 +00:00
refactor: optimize counting tokens (#53)
This commit is contained in:
parent
9767c07eee
commit
553d0fe55b
@ -1,9 +1,7 @@
|
||||
use super::message::{Message, MessageRole, MESSAGE_EXTRA_TOKENS};
|
||||
use super::message::{num_tokens_from_messages, Message, MessageRole};
|
||||
use super::role::Role;
|
||||
use super::MAX_TOKENS;
|
||||
|
||||
use crate::utils::count_tokens;
|
||||
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@ -16,21 +14,17 @@ pub struct Conversation {
|
||||
|
||||
impl Conversation {
|
||||
pub fn new(role: Option<Role>) -> Self {
|
||||
let tokens = if let Some(role) = role.as_ref() {
|
||||
role.consume_tokens()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
Self {
|
||||
tokens,
|
||||
let mut value = Self {
|
||||
tokens: 0,
|
||||
role,
|
||||
messages: vec![],
|
||||
}
|
||||
};
|
||||
value.tokens = num_tokens_from_messages(&value.build_emssages(""));
|
||||
value
|
||||
}
|
||||
|
||||
pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> {
|
||||
let mut need_add_msg = true;
|
||||
let mut input_tokens = count_tokens(input);
|
||||
if self.messages.is_empty() {
|
||||
if let Some(role) = self.role.as_ref() {
|
||||
self.messages.extend(role.build_emssages(input));
|
||||
@ -42,13 +36,12 @@ impl Conversation {
|
||||
role: MessageRole::User,
|
||||
content: input.to_string(),
|
||||
});
|
||||
input_tokens += MESSAGE_EXTRA_TOKENS;
|
||||
}
|
||||
self.messages.push(Message {
|
||||
role: MessageRole::Assistant,
|
||||
content: output.to_string(),
|
||||
});
|
||||
self.tokens += input_tokens + count_tokens(output) + MESSAGE_EXTRA_TOKENS;
|
||||
self.tokens = num_tokens_from_messages(&self.messages);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const MESSAGE_EXTRA_TOKENS: usize = 6;
|
||||
use crate::utils::count_tokens;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Message {
|
||||
@ -25,10 +25,26 @@ pub enum MessageRole {
|
||||
User,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serde() {
|
||||
pub fn num_tokens_from_messages(messages: &[Message]) -> usize {
|
||||
let mut num_tokens = 0;
|
||||
for message in messages.iter() {
|
||||
num_tokens += 4;
|
||||
num_tokens += count_tokens(&message.content);
|
||||
num_tokens += 1; // role always take 1 token
|
||||
}
|
||||
num_tokens += 2;
|
||||
num_tokens
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_serde() {
|
||||
assert_eq!(
|
||||
serde_json::to_string(&Message::new("Hello World")).unwrap(),
|
||||
"{\"role\":\"user\",\"content\":\"Hello World\"}"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -2,11 +2,11 @@ mod conversation;
|
||||
mod message;
|
||||
mod role;
|
||||
|
||||
use self::conversation::Conversation;
|
||||
use self::message::{Message, MESSAGE_EXTRA_TOKENS};
|
||||
use self::message::Message;
|
||||
use self::role::Role;
|
||||
use self::{conversation::Conversation, message::num_tokens_from_messages};
|
||||
|
||||
use crate::utils::{count_tokens, now};
|
||||
use crate::utils::now;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use inquire::{Confirm, Text};
|
||||
@ -166,8 +166,7 @@ impl Config {
|
||||
bail!("")
|
||||
}
|
||||
match self.find_role(name) {
|
||||
Some(mut role) => {
|
||||
role.tokens = role.consume_tokens();
|
||||
Some(role) => {
|
||||
let output =
|
||||
serde_yaml::to_string(&role).unwrap_or("Unable to echo role details".into());
|
||||
self.role = Some(role);
|
||||
@ -201,24 +200,19 @@ impl Config {
|
||||
}
|
||||
|
||||
pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> {
|
||||
let content_tokens = count_tokens(content);
|
||||
let check_tokens = |tokens| {
|
||||
if tokens >= MAX_TOKENS {
|
||||
bail!("Exceed max tokens limit")
|
||||
}
|
||||
Ok(())
|
||||
};
|
||||
let messages = if let Some(conversation) = self.conversation.as_ref() {
|
||||
check_tokens(content_tokens + conversation.tokens)?;
|
||||
conversation.build_emssages(content)
|
||||
} else if let Some(role) = self.role.as_ref() {
|
||||
check_tokens(content_tokens + role.tokens)?;
|
||||
role.build_emssages(content)
|
||||
} else {
|
||||
let message = Message::new(content);
|
||||
check_tokens(content_tokens + MESSAGE_EXTRA_TOKENS)?;
|
||||
vec![message]
|
||||
};
|
||||
let tokens = num_tokens_from_messages(&messages);
|
||||
if tokens >= MAX_TOKENS {
|
||||
bail!("Exceed max tokens limit")
|
||||
}
|
||||
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
|
@ -1,12 +1,9 @@
|
||||
use super::message::{Message, MessageRole, MESSAGE_EXTRA_TOKENS};
|
||||
|
||||
use crate::utils::count_tokens;
|
||||
use super::message::{Message, MessageRole};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const TEMP_NAME: &str = "P";
|
||||
const INPUT_PLACEHOLDER: &str = "__INPUT__";
|
||||
const INPUT_PLACEHOLDER_TOKENS: usize = 3;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Role {
|
||||
@ -19,37 +16,21 @@ pub struct Role {
|
||||
pub prompt: String,
|
||||
/// What sampling temperature to use, between 0 and 2
|
||||
pub temperature: Option<f64>,
|
||||
/// Number of tokens
|
||||
///
|
||||
/// System prompt consume extra 6 tokens
|
||||
#[serde(skip_deserializing)]
|
||||
pub tokens: usize,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn new(prompt: &str, temperature: Option<f64>) -> Self {
|
||||
let mut value = Self {
|
||||
Self {
|
||||
name: TEMP_NAME.into(),
|
||||
prompt: prompt.into(),
|
||||
temperature,
|
||||
tokens: 0,
|
||||
};
|
||||
value.tokens = value.consume_tokens();
|
||||
value
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_temp(&self) -> bool {
|
||||
self.name == TEMP_NAME
|
||||
}
|
||||
|
||||
pub fn consume_tokens(&self) -> usize {
|
||||
if self.embeded() {
|
||||
count_tokens(&self.prompt) + MESSAGE_EXTRA_TOKENS - INPUT_PLACEHOLDER_TOKENS
|
||||
} else {
|
||||
count_tokens(&self.prompt) + 2 * MESSAGE_EXTRA_TOKENS
|
||||
}
|
||||
}
|
||||
|
||||
pub fn embeded(&self) -> bool {
|
||||
self.prompt.contains(INPUT_PLACEHOLDER)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user