From a0bd6e1d5d68c718a0533037364e3df9ad13da96 Mon Sep 17 00:00:00 2001 From: sigoden Date: Wed, 10 Apr 2024 20:06:57 +0800 Subject: [PATCH] refactor: extract json stream handling (#398) --- src/client/cohere.rs | 60 ++++++---------------------------------- src/client/common.rs | 63 ++++++++++++++++++++++++++++++++++++++++++ src/client/vertexai.rs | 58 ++++++-------------------------------- 3 files changed, 80 insertions(+), 101 deletions(-) diff --git a/src/client/cohere.rs b/src/client/cohere.rs index ca77035..a93dc71 100644 --- a/src/client/cohere.rs +++ b/src/client/cohere.rs @@ -1,13 +1,12 @@ use super::{ - message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model, PromptType, - SendData, TokensCountFactors, + json_stream, message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model, + PromptType, SendData, TokensCountFactors, }; use crate::{render::ReplyHandler, utils::PromptKind}; use anyhow::{bail, Result}; use async_trait::async_trait; -use futures_util::StreamExt; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{json, Value}; @@ -107,55 +106,14 @@ pub(crate) async fn send_message_streaming( let data: Value = res.json().await?; check_error(&data)?; } else { - let mut buffer = vec![]; - let mut cursor = 0; - let mut start = 0; - let mut balances = vec![]; - let mut quoting = false; - let mut stream = res.bytes_stream(); - while let Some(chunk) = stream.next().await { - let chunk = chunk?; - let chunk = std::str::from_utf8(&chunk)?; - buffer.extend(chunk.chars()); - for i in cursor..buffer.len() { - let ch = buffer[i]; - if quoting { - if ch == '"' && buffer[i - 1] != '\\' { - quoting = false; - } - continue; - } - match ch { - '"' => quoting = true, - '{' => { - if balances.is_empty() { - start = i; - } - balances.push(ch); - } - '[' => { - if start != 0 { - balances.push(ch); - } - } - '}' => { - balances.pop(); - if balances.is_empty() { - let value: String = buffer[start..=i].iter().collect(); - let value: Value = serde_json::from_str(&value)?; - if let Some("text-generation") = value["event_type"].as_str() { - handler.text(extract_text(&value)?)?; - } - } - } - ']' => { - balances.pop(); - } - _ => {} - } + let handle = |value: &str| -> Result<()> { + let value: Value = serde_json::from_str(value)?; + if let Some("text-generation") = value["event_type"].as_str() { + handler.text(extract_text(&value)?)?; } - cursor = buffer.len(); - } + Ok(()) + }; + json_stream(res.bytes_stream(), handle).await?; } Ok(()) } diff --git a/src/client/common.rs b/src/client/common.rs index 4acc707..9171e3a 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -11,6 +11,7 @@ use crate::{ use anyhow::{Context, Result}; use async_trait::async_trait; +use futures_util::{Stream, StreamExt}; use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy, RequestBuilder}; use serde::Deserialize; use serde_json::{json, Value}; @@ -365,6 +366,68 @@ pub fn patch_system_message(messages: &mut Vec) { } } +pub async fn json_stream(mut stream: S, mut handle: F) -> Result<()> +where + S: Stream> + Unpin, + F: FnMut(&str) -> Result<()>, +{ + let mut buffer = vec![]; + let mut cursor = 0; + let mut start = 0; + let mut balances = vec![]; + let mut quoting = false; + let mut escape = false; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + let chunk = std::str::from_utf8(&chunk)?; + buffer.extend(chunk.chars()); + for i in cursor..buffer.len() { + let ch = buffer[i]; + if quoting { + if ch == '\\' { + escape = !escape; + } else { + if !escape && ch == '"' { + quoting = false; + } + escape = false; + } + continue; + } + match ch { + '"' => { + quoting = true; + escape = false; + } + '{' => { + if balances.is_empty() { + start = i; + } + balances.push(ch); + } + '[' => { + if start != 0 { + balances.push(ch); + } + } + '}' => { + balances.pop(); + if balances.is_empty() { + let value: String = buffer[start..=i].iter().collect(); + handle(&value)?; + } + } + ']' => { + balances.pop(); + } + _ => {} + } + } + cursor = buffer.len(); + } + Ok(()) +} + fn set_config_value(json: &mut Value, path: &str, kind: &PromptKind, value: &str) { let segs: Vec<&str> = path.split('.').collect(); match segs.as_slice() { diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index 1c1fd46..babbd23 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -1,6 +1,6 @@ use super::{ - message::*, patch_system_message, Client, ExtraConfig, Model, PromptType, SendData, - TokensCountFactors, VertexAIClient, + json_stream, message::*, patch_system_message, Client, ExtraConfig, Model, PromptType, + SendData, TokensCountFactors, VertexAIClient, }; use crate::{render::ReplyHandler, utils::PromptKind}; @@ -8,7 +8,6 @@ use crate::{render::ReplyHandler, utils::PromptKind}; use anyhow::{anyhow, bail, Context, Result}; use async_trait::async_trait; use chrono::{Duration, Utc}; -use futures_util::StreamExt; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{json, Value}; @@ -136,53 +135,12 @@ pub(crate) async fn send_message_streaming( let data: Value = res.json().await?; check_error(&data)?; } else { - let mut buffer = vec![]; - let mut cursor = 0; - let mut start = 0; - let mut balances = vec![]; - let mut quoting = false; - let mut stream = res.bytes_stream(); - while let Some(chunk) = stream.next().await { - let chunk = chunk?; - let chunk = std::str::from_utf8(&chunk)?; - buffer.extend(chunk.chars()); - for i in cursor..buffer.len() { - let ch = buffer[i]; - if quoting { - if ch == '"' && buffer[i - 1] != '\\' { - quoting = false; - } - continue; - } - match ch { - '"' => quoting = true, - '{' => { - if balances.is_empty() { - start = i; - } - balances.push(ch); - } - '[' => { - if start != 0 { - balances.push(ch); - } - } - '}' => { - balances.pop(); - if balances.is_empty() { - let value: String = buffer[start..=i].iter().collect(); - let value: Value = serde_json::from_str(&value)?; - handler.text(extract_text(&value)?)?; - } - } - ']' => { - balances.pop(); - } - _ => {} - } - } - cursor = buffer.len(); - } + let handle = |value: &str| -> Result<()> { + let value: Value = serde_json::from_str(value)?; + handler.text(extract_text(&value)?)?; + Ok(()) + }; + json_stream(res.bytes_stream(), handle).await?; } Ok(()) }