From 8e5c17b5543e0e79b1b74d9723f808861f70c162 Mon Sep 17 00:00:00 2001 From: sigoden Date: Thu, 11 Jul 2024 12:25:30 +0800 Subject: [PATCH] refactor: improve estimate_token_length (#703) --- src/client/common.rs | 32 +++++++++++++++++--------------- src/utils/mod.rs | 32 ++++++++++++-------------------- 2 files changed, 29 insertions(+), 35 deletions(-) diff --git a/src/client/common.rs b/src/client/common.rs index f25560d..725f715 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -368,7 +368,7 @@ pub trait Client: Sync + Send { ret = async { if self.global_config().read().dry_run { let content = input.echo_messages(); - let tokens = tokenize(&content); + let tokens = split_content(&content); for token in tokens { tokio::time::sleep(Duration::from_millis(10)).await; handler.text(token)?; @@ -648,22 +648,22 @@ pub fn catch_error(data: &Value, status: u16) -> Result<()> { debug!("Invalid response, status: {status}, data: {data}"); if let Some(error) = data["error"].as_object() { if let (Some(typ), Some(message)) = ( - get_str_field_from_json_map(error, "type"), - get_str_field_from_json_map(error, "message"), + json_str_from_map(error, "type"), + json_str_from_map(error, "message"), ) { bail!("{message} (type: {typ})"); } } else if let Some(error) = data["errors"][0].as_object() { if let (Some(code), Some(message)) = ( - get_u64_field_from_json_map(error, "code"), - get_str_field_from_json_map(error, "message"), + error.get("code").and_then(|v| v.as_u64()), + json_str_from_map(error, "message"), ) { bail!("{message} (status: {code})") } } else if let Some(error) = data[0]["error"].as_object() { if let (Some(status), Some(message)) = ( - get_str_field_from_json_map(error, "status"), - get_str_field_from_json_map(error, "message"), + json_str_from_map(error, "status"), + json_str_from_map(error, "message"), ) { bail!("{message} (status: {status})") } @@ -678,20 +678,13 @@ pub fn catch_error(data: &Value, status: u16) -> Result<()> { bail!("Invalid 
response data: {data} (status: {status})"); } -pub fn get_str_field_from_json_map<'a>( +pub fn json_str_from_map<'a>( map: &'a serde_json::Map<String, Value>, field_name: &str, ) -> Option<&'a str> { map.get(field_name).and_then(|v| v.as_str()) } -pub fn get_u64_field_from_json_map( - map: &serde_json::Map<String, Value>, - field_name: &str, -) -> Option<u64> { - map.get(field_name).and_then(|v| v.as_u64()) -} - pub fn maybe_catch_error(data: &Value) -> Result<()> { if let (Some(code), Some(message)) = (data["code"].as_str(), data["message"].as_str()) { debug!("Invalid response: {}", data); @@ -768,3 +761,12 @@ fn to_json(kind: &PromptKind, value: &str) -> Value { }, } } + +fn split_content(text: &str) -> Vec<&str> { + if text.is_ascii() { + text.split_inclusive(|c: char| c.is_ascii_whitespace()) + .collect() + } else { + unicode_segmentation::UnicodeSegmentation::graphemes(text, true).collect() + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 35f41a2..aa63f3b 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -23,6 +23,7 @@ use fancy_regex::Regex; use is_terminal::IsTerminal; use lazy_static::lazy_static; use std::{env, path::PathBuf, process}; +use unicode_segmentation::UnicodeSegmentation; lazy_static!
{ pub static ref CODE_BLOCK_RE: Regex = Regex::new(r"(?ms)```\w*(.*)```").unwrap(); @@ -42,31 +43,22 @@ pub fn get_env_name(key: &str) -> String { ) } -pub fn tokenize(text: &str) -> Vec<&str> { - if text.is_ascii() { - text.split_inclusive(|c: char| c.is_ascii_whitespace()) - .collect() - } else { - unicode_segmentation::UnicodeSegmentation::graphemes(text, true).collect() - } -} - pub fn estimate_token_length(text: &str) -> usize { - let mut token_length: f32 = 0.0; - - for char in text.chars() { - if char.is_ascii() { - if char.is_ascii_alphabetic() { - token_length += 0.25; + let words: Vec<&str> = text.unicode_words().collect(); + let mut output: f32 = 0.0; + for word in words { + if word.is_ascii() { + output += 1.3; + } else { + let count = word.chars().count(); + if count == 1 { + output += 1.0 } else { - token_length += 0.5; + output += (count as f32) * 0.5; } - } else { - token_length += 1.5; } } - - token_length.ceil() as usize + output.ceil() as usize } pub fn light_theme_from_colorfgbg(colorfgbg: &str) -> Option {