refactor: improve estimate_token_length (#703)

pull/704/head
sigoden 2 months ago committed by GitHub
parent 0264ab80ab
commit 8e5c17b554
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -368,7 +368,7 @@ pub trait Client: Sync + Send {
ret = async {
if self.global_config().read().dry_run {
let content = input.echo_messages();
let tokens = tokenize(&content);
let tokens = split_content(&content);
for token in tokens {
tokio::time::sleep(Duration::from_millis(10)).await;
handler.text(token)?;
@ -648,22 +648,22 @@ pub fn catch_error(data: &Value, status: u16) -> Result<()> {
debug!("Invalid response, status: {status}, data: {data}");
if let Some(error) = data["error"].as_object() {
if let (Some(typ), Some(message)) = (
get_str_field_from_json_map(error, "type"),
get_str_field_from_json_map(error, "message"),
json_str_from_map(error, "type"),
json_str_from_map(error, "message"),
) {
bail!("{message} (type: {typ})");
}
} else if let Some(error) = data["errors"][0].as_object() {
if let (Some(code), Some(message)) = (
get_u64_field_from_json_map(error, "code"),
get_str_field_from_json_map(error, "message"),
error.get("code").and_then(|v| v.as_u64()),
json_str_from_map(error, "message"),
) {
bail!("{message} (status: {code})")
}
} else if let Some(error) = data[0]["error"].as_object() {
if let (Some(status), Some(message)) = (
get_str_field_from_json_map(error, "status"),
get_str_field_from_json_map(error, "message"),
json_str_from_map(error, "status"),
json_str_from_map(error, "message"),
) {
bail!("{message} (status: {status})")
}
@ -678,20 +678,13 @@ pub fn catch_error(data: &Value, status: u16) -> Result<()> {
bail!("Invalid response data: {data} (status: {status})");
}
pub fn get_str_field_from_json_map<'a>(
/// Looks up `field_name` in a JSON object map and returns its value as a
/// borrowed string slice.
///
/// Returns `None` when the key is absent or when the value is not a JSON
/// string.
pub fn json_str_from_map<'a>(
    map: &'a serde_json::Map<String, Value>,
    field_name: &str,
) -> Option<&'a str> {
    let value = map.get(field_name)?;
    value.as_str()
}
/// Looks up `field_name` in a JSON object map and returns its value as a
/// `u64`.
///
/// Returns `None` when the key is absent or when the value is not an
/// unsigned JSON integer.
pub fn get_u64_field_from_json_map(
    map: &serde_json::Map<String, Value>,
    field_name: &str,
) -> Option<u64> {
    let value = map.get(field_name)?;
    value.as_u64()
}
pub fn maybe_catch_error(data: &Value) -> Result<()> {
if let (Some(code), Some(message)) = (data["code"].as_str(), data["message"].as_str()) {
debug!("Invalid response: {}", data);
@ -768,3 +761,12 @@ fn to_json(kind: &PromptKind, value: &str) -> Value {
},
}
}
/// Splits `text` into small display chunks used to simulate streamed output.
///
/// Pure-ASCII text is split after each ASCII whitespace character (the
/// whitespace stays attached to the preceding chunk, so no bytes are lost);
/// any other text is split into Unicode grapheme clusters instead.
fn split_content(text: &str) -> Vec<&str> {
    if !text.is_ascii() {
        return unicode_segmentation::UnicodeSegmentation::graphemes(text, true).collect();
    }
    let after_whitespace = |c: char| c.is_ascii_whitespace();
    text.split_inclusive(after_whitespace).collect()
}

@ -23,6 +23,7 @@ use fancy_regex::Regex;
use is_terminal::IsTerminal;
use lazy_static::lazy_static;
use std::{env, path::PathBuf, process};
use unicode_segmentation::UnicodeSegmentation;
lazy_static! {
pub static ref CODE_BLOCK_RE: Regex = Regex::new(r"(?ms)```\w*(.*)```").unwrap();
@ -42,31 +43,22 @@ pub fn get_env_name(key: &str) -> String {
)
}
/// Splits `text` into coarse display tokens.
///
/// For pure-ASCII input the text is cut after every ASCII whitespace
/// character, keeping the whitespace attached to the preceding token so the
/// concatenation of all tokens reproduces the input exactly. Non-ASCII input
/// is segmented into Unicode grapheme clusters.
pub fn tokenize(text: &str) -> Vec<&str> {
    match text.is_ascii() {
        true => {
            let boundary = |c: char| c.is_ascii_whitespace();
            text.split_inclusive(boundary).collect()
        }
        false => unicode_segmentation::UnicodeSegmentation::graphemes(text, true).collect(),
    }
}
pub fn estimate_token_length(text: &str) -> usize {
let mut token_length: f32 = 0.0;
for char in text.chars() {
if char.is_ascii() {
if char.is_ascii_alphabetic() {
token_length += 0.25;
let words: Vec<&str> = text.unicode_words().collect();
let mut output: f32 = 0.0;
for word in words {
if word.is_ascii() {
output += 1.3;
} else {
let count = word.chars().count();
if count == 1 {
output += 1.0
} else {
token_length += 0.5;
output += (count as f32) * 0.5;
}
} else {
token_length += 1.5;
}
}
token_length.ceil() as usize
output.ceil() as usize
}
pub fn light_theme_from_colorfgbg(colorfgbg: &str) -> Option<bool> {

Loading…
Cancel
Save