//! Use tiktoken to count tokens.
//!
//! Forked from https://github.com/dust-tt/dust/tree/main/core/src/providers/tiktoken

#![allow(unused)]

use anyhow::{anyhow, Result};
use base64::{engine::general_purpose, Engine as _};
use fancy_regex::Regex;
use lazy_static::lazy_static;
use parking_lot::Mutex;
use rustc_hash::FxHashMap as HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::task;

pub fn cl100k_base() -> Result<CoreBPE> {
    let cl100k_base = include_str!("../../assets/cl100k_base.tiktoken");

    let mut encoder = HashMap::default();
    for line in cl100k_base.lines() {
        let mut parts = line.split(' ');
        let raw = parts.next().unwrap();
        let token = &general_purpose::STANDARD.decode(raw)?;
        let rank: usize = parts.next().unwrap().parse().unwrap();
        encoder.insert(token.clone(), rank);
    }

    let mut special_tokens = HashMap::default();
    special_tokens.insert(String::from("<|endoftext|>"), 100257);
    special_tokens.insert(String::from("<|fim_prefix|>"), 100258);
    special_tokens.insert(String::from("<|fim_middle|>"), 100259);
    special_tokens.insert(String::from("<|fim_suffix|>"), 100260);
    special_tokens.insert(String::from("<|endofprompt|>"), 100276);

    CoreBPE::new(
        encoder,
        special_tokens,
        "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
    )
}

pub fn cl100k_base_singleton() -> Arc<Mutex<CoreBPE>> {
    lazy_static! {
        static ref CL100K_BASE: Arc<Mutex<CoreBPE>> =
            Arc::new(Mutex::new(cl100k_base().unwrap()));
    }
    CL100K_BASE.clone()
}

pub async fn decode_async(bpe: Arc<Mutex<CoreBPE>>, tokens: Vec<usize>) -> Result<String> {
    task::spawn_blocking(move || bpe.lock().decode(tokens)).await?
}

pub async fn encode_async(bpe: Arc<Mutex<CoreBPE>>, text: &str) -> Result<Vec<usize>> {
    let text = text.to_string();
    let r = task::spawn_blocking(move || bpe.lock().encode_with_special_tokens(&text)).await?;
    Ok(r)
}

fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>> {
    let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect();

    // If you have n parts and m merges, this does O(mn) work.
    // We could do something with a heap and do O(m log n) work.

    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
    // currently do, this is equivalent. An easy way to break this would be to decouple
    // merge priority from token index or to prevent specific token merges.
    loop {
        if parts.len() == 1 {
            break;
        }
        let mut min_rank: Option<(usize, usize)> = None;
        for i in 0..parts.len() - 1 {
            let rank = if let Some(r) = ranks.get(&piece[parts[i].start..parts[i + 1].end]) {
                *r
            } else {
                continue;
            };
            if min_rank.is_none() || rank < min_rank.unwrap().0 {
                min_rank = Some((rank, i));
            }
        }
        if let Some((_, i)) = min_rank {
            parts[i] = parts[i].start..parts[i + 1].end;
            parts.remove(i + 1);
        } else {
            break;
        }
    }
    parts
}

pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
    if piece.len() == 1 {
        return vec![ranks[piece]];
    }
    _byte_pair_merge(piece, ranks)
        .iter()
        .map(|p| ranks[&piece[p.start..p.end]])
        .collect()
}

pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
    if piece.len() == 1 {
        return vec![piece];
    }
    _byte_pair_merge(piece, ranks)
        .iter()
        .map(|p| &piece[p.start..p.end])
        .collect()
}
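// A worked example of the merge order (a sketch: the ranks below are invented
// for illustration and do not come from a real vocabulary). Merges always pick
// the lowest-ranked adjacent pair first, so "ab" (rank 0) merges before "cd"
// (rank 1), and only then can the two halves merge into "abcd" (rank 2).
#[cfg(test)]
#[test]
fn byte_pair_merge_rank_order_sketch() {
    let mut ranks = HashMap::default();
    ranks.insert(b"ab".to_vec(), 0);
    ranks.insert(b"cd".to_vec(), 1);
    ranks.insert(b"abcd".to_vec(), 2);

    // [a, b, c, d] -> [ab, c, d] -> [ab, cd] -> [abcd]
    let res = byte_pair_split(b"abcd", &ranks);
    assert_eq!(res, vec![b"abcd"]);
}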
// Various performance notes:
//
// Regex
// =====
// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
// regex features. For instance, using a regex parse-able by the `regex` crate is 3x faster than
// the usual regex we use.
//
// However, given that we're using a regex parse-able by `regex`, there isn't much difference
// between using the `regex` crate and using the `fancy_regex` crate.
//
// Caching
// =======
// The reference tokeniser has an LRU cache over the equivalent of `byte_pair_encode`.
// Originally, we had one too! Without it, we were only vaguely faster than Python.
// I used an RwLock to protect the cache. This didn't seem to hurt single-threaded performance
// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
// multi-threaded performance even when I only had readers (maybe I messed something up?).
// Anyway, I realised that we could get rid of the cache if we treat the set of tokens as a cache!
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
// about interior mutability, memory use, or cloning.
//
// Hashing
// =======
// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be
// made to be hashing of two-tuples of ints, which looks like it may also be a couple percent
// faster.

pub struct CoreBPE {
    encoder: HashMap<Vec<u8>, usize>,
    special_tokens_encoder: HashMap<String, usize>,
    decoder: HashMap<usize, Vec<u8>>,
    special_tokens_decoder: HashMap<usize, Vec<u8>>,
    regex: Regex,
    special_regex: Regex,
    sorted_token_bytes: Vec<Vec<u8>>,
}

impl CoreBPE {
    fn _get_regex(&self) -> &Regex {
        &self.regex
    }

    fn _get_special_regex(&self) -> &Regex {
        &self.special_regex
    }

    fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
        let mut ret = Vec::with_capacity(tokens.len() * 2);
        for token in tokens {
            let token_bytes = self
                .decoder
                .get(token)
                .unwrap_or_else(|| &self.special_tokens_decoder[token]);
            ret.extend(token_bytes);
        }
        ret
    }

    fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
        // This is the core of the encoding logic; the other functions in here
        // just make things complicated :-)
        let regex = self._get_regex();
        let mut ret = vec![];
        for mat in regex.find_iter(text) {
            let piece = mat.unwrap().as_str().as_bytes();
            if let Some(token) = self.encoder.get(piece) {
                ret.push(*token);
                continue;
            }
            ret.extend(&byte_pair_encode(piece, &self.encoder));
        }
        ret
    }
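    // Illustrative trace of `_encode_ordinary_native` (a sketch; the pieces
    // are real regex splits for cl100k_base, but the token ids are invented):
    //
    //   "hello world" --regex--> ["hello", " world"]
    //   "hello"  -> direct hit in `encoder` -> push its id
    //   " world" -> miss                    -> byte_pair_encode(b" world", ...)
    //
    // Every piece either is a vocabulary entry itself or gets broken down by
    // byte-pair merging, so the output covers the input bytes exactly.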
    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
        let special_regex = self._get_special_regex();
        let regex = self._get_regex();
        let mut ret = vec![];

        let mut start = 0;
        let mut last_piece_token_len = 0;
        loop {
            let mut next_special;
            let mut start_find = start;
            loop {
                // Find the next allowed special token, if any
                next_special = special_regex.find_from_pos(text, start_find).unwrap();
                match next_special {
                    Some(m) => {
                        if allowed_special.contains(&text[m.start()..m.end()]) {
                            break;
                        }
                        start_find = m.start() + 1;
                    }
                    None => break,
                }
            }
            let end = next_special.map_or(text.len(), |m| m.start());

            // Okay, here we go, compare this logic to _encode_ordinary_native
            for mat in regex.find_iter(&text[start..end]) {
                let piece = mat.unwrap().as_str().as_bytes();
                if let Some(token) = self.encoder.get(piece) {
                    last_piece_token_len = 1;
                    ret.push(*token);
                    continue;
                }
                let tokens = byte_pair_encode(piece, &self.encoder);
                last_piece_token_len = tokens.len();
                ret.extend(&tokens);
            }

            match next_special {
                // And here we push the special token
                Some(m) => {
                    let piece = m.as_str();
                    let token = self.special_tokens_encoder[piece];
                    ret.push(token);
                    start = m.end();
                    last_piece_token_len = 0;
                }
                None => break,
            }
        }

        // last_piece_token_len is how many tokens came from the last regex split.
        // This is used for determining unstable tokens, since you can't merge
        // across (stable) regex splits.
        (ret, last_piece_token_len)
    }

    fn _increase_last_piece_token_len(
        &self,
        tokens: Vec<usize>,
        mut last_piece_token_len: usize,
    ) -> (Vec<usize>, usize) {
        // Unfortunately, the locations where our regex splits can be unstable.
        // For the purposes of determining unstable tokens, unstable regex splitting
        // is only a problem if a split that was present disappears, since this can
        // lead to merging of tokens otherwise thought to be stable.
        // cl100k_base makes our life hard by including the \s*[\r\n]+
        // pattern. This can e.g. cause "\n" + " " to become "\n \n".
        // Here is a quick and dirty fix:
        {
            let token_is_all_space = |token| {
                self.decoder
                    .get(token)
                    .map(|token_bytes| {
                        token_bytes
                            .iter()
                            .rev()
                            .all(|&b| [b' ', b'\n', b'\t'].contains(&b))
                    })
                    .unwrap_or(false)
            };
            if last_piece_token_len > 0
                && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
            {
                while (last_piece_token_len < tokens.len())
                    && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
                {
                    last_piece_token_len += 1;
                }
            }
        }
        debug_assert!(last_piece_token_len <= tokens.len());

        (tokens, last_piece_token_len)
    }
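    // Sketch of the walk-back above (illustrative, not tied to real token
    // ids): if the token stream ends in [..., "\n", " "] and the final regex
    // split produced the all-whitespace token " ", the loop extends
    // last_piece_token_len to also cover the preceding all-whitespace "\n"
    // token, so the whole trailing whitespace run is treated as unstable
    // together instead of just the final split.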
    fn _encode_unstable_native(
        &self,
        text: &str,
        allowed_special: &HashSet<&str>,
    ) -> (Vec<usize>, HashSet<Vec<usize>>) {
        let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
        if last_piece_token_len == 0 {
            // If last_piece_token_len is zero, the last token was a special token and we have
            // no unstable bytes
            return (tokens, HashSet::new());
        }
        let (mut tokens, last_piece_token_len) =
            self._increase_last_piece_token_len(tokens, last_piece_token_len);

        let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
        tokens.truncate(tokens.len() - last_piece_token_len);

        // TODO: we should try harder to find additional stable tokens.
        // This would reduce the amount of retokenising when determining completions.
        // Refer to the logic in an older version of this file.

        let mut completions = HashSet::new();
        if unstable_bytes.is_empty() {
            return (tokens, completions);
        }

        // This is the easy bit. Just find all single tokens that start with unstable_bytes
        // (including tokens that exactly match unstable_bytes).
        // Separating this from the loop below helps with performance in a common case.
        let mut point = self
            .sorted_token_bytes
            .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
        while point < self.sorted_token_bytes.len()
            && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
        {
            completions.insert(vec![
                self.encoder[self.sorted_token_bytes[point].as_slice()],
            ]);
            point += 1;
        }

        // Now apply even more brute force. At every (other) possible position for the straddling
        // token, concatenate additional bytes from that token (if any) to unstable_bytes,
        // and retokenise the whole thing and see what we get.
        for i in 1..unstable_bytes.len() {
            let prefix = &unstable_bytes[..i];
            let suffix = &unstable_bytes[i..];
            let mut point = self
                .sorted_token_bytes
                .partition_point(|x| x.as_slice() < suffix);
            // TODO: Perf optimisation if suffix starts with " "?
            while point < self.sorted_token_bytes.len()
                && self.sorted_token_bytes[point].starts_with(suffix)
            {
                let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
                let encoded = match std::str::from_utf8(&possibility) {
                    // Morally, this is byte_pair_encode(&possibility, &self.encoder),
                    // but we might have introduced a regex split which would prevent merges
                    // (particularly possible in the presence of unstable regex splits).
                    // So convert to UTF-8 and do regex splitting.
                    // E.g. with cl100k_base "  !" gets split to " " + " !",
                    // but byte_pair_encode("  !") != byte_pair_encode(" ") + byte_pair_encode(" !")
                    Ok(s) => self._encode_ordinary_native(s),

                    // Technically, whether or not this arm is correct depends on whether there
                    // would be a regex split before the UTF-8 truncation point.
                    // Probably niche enough that no one will ever notice (after all, people didn't
                    // notice all the big holes in the previous unstable token implementation).
                    Err(_) => byte_pair_encode(&possibility, &self.encoder),
                    // Something like the following is intriguing but incorrect:
                    // Err(e) => self._encode_ordinary_native(unsafe {
                    //     std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
                    // }),
                };
                let mut seq = Vec::new();
                let mut seq_len = 0;
                for token in encoded {
                    seq.push(token);
                    seq_len += self.decoder[&token].len();
                    if seq_len >= unstable_bytes.len() {
                        break;
                    }
                }
                completions.insert(seq);
                point += 1;
            }
        }

        // This is also not straightforward. While we generally assume that regex splits are
        // stable, unfortunately, they are not. That is, if adding bytes were to make a split
        // appear in unstable_bytes, this could make tokens possible which our logic would
        // otherwise think would be merged.
        // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
        // develop a split, e.g. "\n\n0" splits into "\n" + "\n" + "0", making "\n" a possible
        // token.
        // Here is a quick and dirty fix:
        // (This isn't right if we ever remove \s+(?!\S).)
        if unstable_bytes.len() > 1 {
            let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
            if unstable_bytes.len() - last_decoded.1 > 0
                && last_decoded.0.map_or(false, |c| c.is_whitespace())
            {
                let mut reencoded = byte_pair_encode(
                    &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
                    &self.encoder,
                );
                reencoded.extend(byte_pair_encode(
                    &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
                    &self.encoder,
                ));
                completions.insert(reencoded);
            }
        }

        (tokens, completions)
    }
}
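// A sketch of what the unstable-token machinery produces. The exact
// completion set depends on the vocabulary, so this only checks shape
// (a stable prefix plus a non-empty set of candidate completions), not
// specific token ids.
#[cfg(test)]
#[test]
fn encode_with_unstable_sketch() {
    let bpe = cl100k_base().unwrap();
    let (tokens, completions) = bpe.encode_with_unstable("hello wor", HashSet::new());

    // The stable tokens decode back to a prefix of the input...
    let stable = bpe.decode(tokens).unwrap();
    assert!("hello wor".starts_with(&stable));

    // ...and the trailing " wor" is unstable: it could still grow into,
    // e.g., " world", so there is at least one candidate completion.
    assert!(!completions.is_empty());
}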
impl CoreBPE {
    fn new(
        encoder: HashMap<Vec<u8>, usize>,
        special_tokens_encoder: HashMap<String, usize>,
        pattern: &str,
    ) -> Result<Self> {
        let regex = Regex::new(pattern)?;

        let special_regex = {
            let _parts = special_tokens_encoder
                .keys()
                .map(|s| fancy_regex::escape(s))
                .collect::<Vec<_>>();
            Regex::new(&_parts.join("|"))?
        };

        let decoder: HashMap<usize, Vec<u8>> =
            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();

        assert!(encoder.len() == decoder.len());

        let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
            .iter()
            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
            .collect();

        // Clone because I don't know how to tell Rust I'm not going to change the map
        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
        sorted_token_bytes.sort();

        Ok(CoreBPE {
            encoder,
            special_tokens_encoder,
            decoder,
            special_tokens_decoder,
            regex,
            special_regex,
            sorted_token_bytes,
        })
    }

    // ====================
    // Encoding
    // ====================

    pub fn encode_ordinary(&self, text: &str) -> Vec<usize> {
        self._encode_ordinary_native(text)
    }

    pub fn encode(&self, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> {
        self._encode_native(text, &allowed_special).0
    }

    pub fn encode_with_special_tokens(&self, text: &str) -> Vec<usize> {
        let allowed_special = self
            .special_tokens_encoder
            .keys()
            .map(|s| s.as_str())
            .collect();
        self._encode_native(text, &allowed_special).0
    }

    fn _encode_bytes(&self, bytes: &[u8]) -> Vec<usize> {
        match std::str::from_utf8(bytes) {
            Ok(text) => self._encode_ordinary_native(text),
            Err(e) => {
                let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
                let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new());
                let (mut tokens, last_piece_token_len) =
                    self._increase_last_piece_token_len(tokens, last_piece_token_len);
                if !tokens.is_empty() && last_piece_token_len > 0 {
                    // Lop off the tokens from the last piece and run BPE on the remaining bytes.
                    // Somewhat niche, but this may not be correct if we'd have had a regex
                    // split between the valid UTF-8 and the invalid bytes, which is why this
                    // method is private.
                    let mut unstable_bytes =
                        self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
                    unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);

                    tokens.truncate(tokens.len() - last_piece_token_len);
                    tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder));
                }
                tokens
            }
        }
    }

    #[allow(dead_code)]
    fn encode_with_unstable(
        &self,
        text: &str,
        allowed_special: HashSet<&str>,
    ) -> (Vec<usize>, HashSet<Vec<usize>>) {
        self._encode_unstable_native(text, &allowed_special)
    }

    #[allow(dead_code)]
    fn encode_single_token(&self, piece: &[u8]) -> Result<usize> {
        if let Some(token) = self.encoder.get(piece).copied() {
            return Ok(token);
        }
        if let Ok(piece_str) = std::str::from_utf8(piece) {
            if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() {
                return Ok(token);
            }
        }
        Err(anyhow!("Token not found in the vocabulary: {:?}", piece))
    }

    #[allow(dead_code)]
    fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
        if let Some(token) = self.encoder.get(piece) {
            return vec![*token];
        }
        byte_pair_encode(piece, &self.encoder)
    }

    // ====================
    // Decoding
    // ====================

    pub fn decode_bytes(&self, tokens: Vec<usize>) -> Vec<u8> {
        self._decode_native(&tokens)
    }

    pub fn decode(&self, tokens: Vec<usize>) -> Result<String> {
        match String::from_utf8(self._decode_native(&tokens)) {
            Ok(text) => Ok(text),
            Err(e) => Err(anyhow!("Unable to decode into a valid UTF-8 string: {}", e)),
        }
    }

    pub fn decode_single_token_bytes(&self, token: usize) -> Result<Vec<u8>> {
        if let Some(bytes) = self.decoder.get(&token) {
            return Ok(bytes.clone());
        }
        if let Some(bytes) = self.special_tokens_decoder.get(&token) {
            return Ok(bytes.clone());
        }
        Err(anyhow!("Token not found in the vocabulary: {}", token))
    }

    // ====================
    // Miscellaneous
    // ====================

    #[allow(dead_code)]
    fn token_byte_values(&self) -> Vec<Vec<u8>> {
        self.sorted_token_bytes.clone()
    }
}
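// How the singleton and the async helpers fit together (a sketch; assumes
// this crate enables tokio's `rt` and `macros` features, which the
// `#[tokio::test]` attribute needs):
#[cfg(test)]
#[tokio::test]
async fn async_roundtrip_sketch() {
    // The singleton is built once and shared; encoding and decoding hop onto
    // a blocking thread so they don't stall the async executor.
    let bpe = cl100k_base_singleton();
    let tokens = encode_async(bpe.clone(), "hello world").await.unwrap();
    let text = decode_async(bpe, tokens).await.unwrap();
    assert_eq!(text, "hello world");
}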
#[cfg(test)]
mod tests {
    use super::*;
    use rustc_hash::FxHashMap as HashMap;

    #[test]
    fn very_simple_test() {
        let mut ranks = HashMap::default();
        ranks.insert(b"ab".to_vec(), 1);
        ranks.insert(b"cd".to_vec(), 2);

        let res = byte_pair_split(b"abcd", &ranks);
        assert_eq!(res, vec![b"ab", b"cd"]);
    }

    #[test]
    fn cl100k_base_test() {
        let bpe = cl100k_base().unwrap();
        let tokens = bpe.encode_with_special_tokens("This is a test         with a lot of spaces");
        let decoded = bpe.decode(tokens.clone()).unwrap();
        assert_eq!(decoded, "This is a test         with a lot of spaces");
        assert_eq!(
            tokens,
            vec![2028, 374, 264, 1296, 260, 449, 264, 2763, 315, 12908]
        );
    }
}
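// A sketch of the difference between `encode` (which only matches special
// tokens listed in `allowed_special`) and `encode_with_special_tokens`
// (which allows all of them). Token id 100257 is <|endoftext|> per the
// special-token table in `cl100k_base` above; the second assertion assumes
// reserved special ids never occur in the base vocabulary, which holds for
// cl100k_base.
#[cfg(test)]
mod special_token_examples {
    use super::*;

    #[test]
    fn allowed_special_controls_matching() {
        let bpe = cl100k_base().unwrap();

        // With the special token allowed, it encodes to its reserved id.
        let tokens = bpe.encode_with_special_tokens("<|endoftext|>");
        assert_eq!(tokens, vec![100257]);

        // With no special tokens allowed, the same text is tokenised as
        // ordinary characters and never produces the reserved id.
        let tokens = bpe.encode("<|endoftext|>", HashSet::new());
        assert!(!tokens.contains(&100257));
    }
}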