refactor: optimize counting tokens (#53)

This commit is contained in:
sigoden 2023-03-09 21:18:28 +08:00 committed by GitHub
parent 9767c07eee
commit 553d0fe55b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 42 additions and 58 deletions

View File

@ -1,9 +1,7 @@
use super::message::{Message, MessageRole, MESSAGE_EXTRA_TOKENS}; use super::message::{num_tokens_from_messages, Message, MessageRole};
use super::role::Role; use super::role::Role;
use super::MAX_TOKENS; use super::MAX_TOKENS;
use crate::utils::count_tokens;
use anyhow::Result; use anyhow::Result;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -16,21 +14,17 @@ pub struct Conversation {
impl Conversation { impl Conversation {
pub fn new(role: Option<Role>) -> Self { pub fn new(role: Option<Role>) -> Self {
let tokens = if let Some(role) = role.as_ref() { let mut value = Self {
role.consume_tokens() tokens: 0,
} else {
0
};
Self {
tokens,
role, role,
messages: vec![], messages: vec![],
} };
value.tokens = num_tokens_from_messages(&value.build_emssages(""));
value
} }
pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> { pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> {
let mut need_add_msg = true; let mut need_add_msg = true;
let mut input_tokens = count_tokens(input);
if self.messages.is_empty() { if self.messages.is_empty() {
if let Some(role) = self.role.as_ref() { if let Some(role) = self.role.as_ref() {
self.messages.extend(role.build_emssages(input)); self.messages.extend(role.build_emssages(input));
@ -42,13 +36,12 @@ impl Conversation {
role: MessageRole::User, role: MessageRole::User,
content: input.to_string(), content: input.to_string(),
}); });
input_tokens += MESSAGE_EXTRA_TOKENS;
} }
self.messages.push(Message { self.messages.push(Message {
role: MessageRole::Assistant, role: MessageRole::Assistant,
content: output.to_string(), content: output.to_string(),
}); });
self.tokens += input_tokens + count_tokens(output) + MESSAGE_EXTRA_TOKENS; self.tokens = num_tokens_from_messages(&self.messages);
Ok(()) Ok(())
} }

View File

@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub const MESSAGE_EXTRA_TOKENS: usize = 6; use crate::utils::count_tokens;
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Message { pub struct Message {
@ -25,10 +25,26 @@ pub enum MessageRole {
User, User,
} }
#[test] pub fn num_tokens_from_messages(messages: &[Message]) -> usize {
fn test_serde() { let mut num_tokens = 0;
for message in messages.iter() {
num_tokens += 4;
num_tokens += count_tokens(&message.content);
num_tokens += 1; // role always take 1 token
}
num_tokens += 2;
num_tokens
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serde() {
assert_eq!( assert_eq!(
serde_json::to_string(&Message::new("Hello World")).unwrap(), serde_json::to_string(&Message::new("Hello World")).unwrap(),
"{\"role\":\"user\",\"content\":\"Hello World\"}" "{\"role\":\"user\",\"content\":\"Hello World\"}"
) )
}
} }

View File

@ -2,11 +2,11 @@ mod conversation;
mod message; mod message;
mod role; mod role;
use self::conversation::Conversation; use self::message::Message;
use self::message::{Message, MESSAGE_EXTRA_TOKENS};
use self::role::Role; use self::role::Role;
use self::{conversation::Conversation, message::num_tokens_from_messages};
use crate::utils::{count_tokens, now}; use crate::utils::now;
use anyhow::{anyhow, bail, Context, Result}; use anyhow::{anyhow, bail, Context, Result};
use inquire::{Confirm, Text}; use inquire::{Confirm, Text};
@ -166,8 +166,7 @@ impl Config {
bail!("") bail!("")
} }
match self.find_role(name) { match self.find_role(name) {
Some(mut role) => { Some(role) => {
role.tokens = role.consume_tokens();
let output = let output =
serde_yaml::to_string(&role).unwrap_or("Unable to echo role details".into()); serde_yaml::to_string(&role).unwrap_or("Unable to echo role details".into());
self.role = Some(role); self.role = Some(role);
@ -201,24 +200,19 @@ impl Config {
} }
pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> { pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> {
let content_tokens = count_tokens(content);
let check_tokens = |tokens| {
if tokens >= MAX_TOKENS {
bail!("Exceed max tokens limit")
}
Ok(())
};
let messages = if let Some(conversation) = self.conversation.as_ref() { let messages = if let Some(conversation) = self.conversation.as_ref() {
check_tokens(content_tokens + conversation.tokens)?;
conversation.build_emssages(content) conversation.build_emssages(content)
} else if let Some(role) = self.role.as_ref() { } else if let Some(role) = self.role.as_ref() {
check_tokens(content_tokens + role.tokens)?;
role.build_emssages(content) role.build_emssages(content)
} else { } else {
let message = Message::new(content); let message = Message::new(content);
check_tokens(content_tokens + MESSAGE_EXTRA_TOKENS)?;
vec![message] vec![message]
}; };
let tokens = num_tokens_from_messages(&messages);
if tokens >= MAX_TOKENS {
bail!("Exceed max tokens limit")
}
Ok(messages) Ok(messages)
} }

View File

@ -1,12 +1,9 @@
use super::message::{Message, MessageRole, MESSAGE_EXTRA_TOKENS}; use super::message::{Message, MessageRole};
use crate::utils::count_tokens;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
const TEMP_NAME: &str = ""; const TEMP_NAME: &str = "";
const INPUT_PLACEHOLDER: &str = "__INPUT__"; const INPUT_PLACEHOLDER: &str = "__INPUT__";
const INPUT_PLACEHOLDER_TOKENS: usize = 3;
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Role { pub struct Role {
@ -19,37 +16,21 @@ pub struct Role {
pub prompt: String, pub prompt: String,
/// What sampling temperature to use, between 0 and 2 /// What sampling temperature to use, between 0 and 2
pub temperature: Option<f64>, pub temperature: Option<f64>,
/// Number of tokens
///
/// System prompt consume extra 6 tokens
#[serde(skip_deserializing)]
pub tokens: usize,
} }
impl Role { impl Role {
pub fn new(prompt: &str, temperature: Option<f64>) -> Self { pub fn new(prompt: &str, temperature: Option<f64>) -> Self {
let mut value = Self { Self {
name: TEMP_NAME.into(), name: TEMP_NAME.into(),
prompt: prompt.into(), prompt: prompt.into(),
temperature, temperature,
tokens: 0, }
};
value.tokens = value.consume_tokens();
value
} }
pub fn is_temp(&self) -> bool { pub fn is_temp(&self) -> bool {
self.name == TEMP_NAME self.name == TEMP_NAME
} }
pub fn consume_tokens(&self) -> usize {
if self.embeded() {
count_tokens(&self.prompt) + MESSAGE_EXTRA_TOKENS - INPUT_PLACEHOLDER_TOKENS
} else {
count_tokens(&self.prompt) + 2 * MESSAGE_EXTRA_TOKENS
}
}
pub fn embeded(&self) -> bool { pub fn embeded(&self) -> bool {
self.prompt.contains(INPUT_PLACEHOLDER) self.prompt.contains(INPUT_PLACEHOLDER)
} }