refactor: extract json stream handling (#398)

pull/399/head
sigoden 2 months ago committed by GitHub
parent 5915bc2f3a
commit a0bd6e1d5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,13 +1,12 @@
use super::{
message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model, PromptType,
SendData, TokensCountFactors,
json_stream, message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model,
PromptType, SendData, TokensCountFactors,
};
use crate::{render::ReplyHandler, utils::PromptKind};
use anyhow::{bail, Result};
use async_trait::async_trait;
use futures_util::StreamExt;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
@ -107,55 +106,14 @@ pub(crate) async fn send_message_streaming(
let data: Value = res.json().await?;
check_error(&data)?;
} else {
let mut buffer = vec![];
let mut cursor = 0;
let mut start = 0;
let mut balances = vec![];
let mut quoting = false;
let mut stream = res.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
let chunk = std::str::from_utf8(&chunk)?;
buffer.extend(chunk.chars());
for i in cursor..buffer.len() {
let ch = buffer[i];
if quoting {
if ch == '"' && buffer[i - 1] != '\\' {
quoting = false;
}
continue;
}
match ch {
'"' => quoting = true,
'{' => {
if balances.is_empty() {
start = i;
}
balances.push(ch);
}
'[' => {
if start != 0 {
balances.push(ch);
}
}
'}' => {
balances.pop();
if balances.is_empty() {
let value: String = buffer[start..=i].iter().collect();
let value: Value = serde_json::from_str(&value)?;
if let Some("text-generation") = value["event_type"].as_str() {
handler.text(extract_text(&value)?)?;
}
}
}
']' => {
balances.pop();
}
_ => {}
}
let handle = |value: &str| -> Result<()> {
let value: Value = serde_json::from_str(value)?;
if let Some("text-generation") = value["event_type"].as_str() {
handler.text(extract_text(&value)?)?;
}
cursor = buffer.len();
}
Ok(())
};
json_stream(res.bytes_stream(), handle).await?;
}
Ok(())
}

@ -11,6 +11,7 @@ use crate::{
use anyhow::{Context, Result};
use async_trait::async_trait;
use futures_util::{Stream, StreamExt};
use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
@ -365,6 +366,68 @@ pub fn patch_system_message(messages: &mut Vec<Message>) {
}
}
pub async fn json_stream<S, F>(mut stream: S, mut handle: F) -> Result<()>
where
S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
F: FnMut(&str) -> Result<()>,
{
let mut buffer = vec![];
let mut cursor = 0;
let mut start = 0;
let mut balances = vec![];
let mut quoting = false;
let mut escape = false;
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
let chunk = std::str::from_utf8(&chunk)?;
buffer.extend(chunk.chars());
for i in cursor..buffer.len() {
let ch = buffer[i];
if quoting {
if ch == '\\' {
escape = !escape;
} else {
if !escape && ch == '"' {
quoting = false;
}
escape = false;
}
continue;
}
match ch {
'"' => {
quoting = true;
escape = false;
}
'{' => {
if balances.is_empty() {
start = i;
}
balances.push(ch);
}
'[' => {
if start != 0 {
balances.push(ch);
}
}
'}' => {
balances.pop();
if balances.is_empty() {
let value: String = buffer[start..=i].iter().collect();
handle(&value)?;
}
}
']' => {
balances.pop();
}
_ => {}
}
}
cursor = buffer.len();
}
Ok(())
}
fn set_config_value(json: &mut Value, path: &str, kind: &PromptKind, value: &str) {
let segs: Vec<&str> = path.split('.').collect();
match segs.as_slice() {

@ -1,6 +1,6 @@
use super::{
message::*, patch_system_message, Client, ExtraConfig, Model, PromptType, SendData,
TokensCountFactors, VertexAIClient,
json_stream, message::*, patch_system_message, Client, ExtraConfig, Model, PromptType,
SendData, TokensCountFactors, VertexAIClient,
};
use crate::{render::ReplyHandler, utils::PromptKind};
@ -8,7 +8,6 @@ use crate::{render::ReplyHandler, utils::PromptKind};
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
use chrono::{Duration, Utc};
use futures_util::StreamExt;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
@ -136,53 +135,12 @@ pub(crate) async fn send_message_streaming(
let data: Value = res.json().await?;
check_error(&data)?;
} else {
let mut buffer = vec![];
let mut cursor = 0;
let mut start = 0;
let mut balances = vec![];
let mut quoting = false;
let mut stream = res.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
let chunk = std::str::from_utf8(&chunk)?;
buffer.extend(chunk.chars());
for i in cursor..buffer.len() {
let ch = buffer[i];
if quoting {
if ch == '"' && buffer[i - 1] != '\\' {
quoting = false;
}
continue;
}
match ch {
'"' => quoting = true,
'{' => {
if balances.is_empty() {
start = i;
}
balances.push(ch);
}
'[' => {
if start != 0 {
balances.push(ch);
}
}
'}' => {
balances.pop();
if balances.is_empty() {
let value: String = buffer[start..=i].iter().collect();
let value: Value = serde_json::from_str(&value)?;
handler.text(extract_text(&value)?)?;
}
}
']' => {
balances.pop();
}
_ => {}
}
}
cursor = buffer.len();
}
let handle = |value: &str| -> Result<()> {
let value: Value = serde_json::from_str(value)?;
handler.text(extract_text(&value)?)?;
Ok(())
};
json_stream(res.bytes_stream(), handle).await?;
}
Ok(())
}

Loading…
Cancel
Save