From 68882ecd4dced38e92b6ef581a30e45358fc61e0 Mon Sep 17 00:00:00 2001 From: sigoden Date: Mon, 29 Apr 2024 08:43:23 +0800 Subject: [PATCH] refactor: abstract event stream handling (#458) --- src/client/claude.rs | 56 +++++++++------------------------------- src/client/cohere.rs | 11 +++----- src/client/common.rs | 48 +++++++++++++++++++++++++++++++++++ src/client/ernie.rs | 56 ++++++++-------------------------------- src/client/openai.rs | 59 ++++++++++--------------------------------- src/client/qianwen.rs | 47 +++++++++++----------------------- 6 files changed, 102 insertions(+), 175 deletions(-) diff --git a/src/client/claude.rs b/src/client/claude.rs index 2b770dc..72ed405 100644 --- a/src/client/claude.rs +++ b/src/client/claude.rs @@ -1,14 +1,13 @@ use super::{ - catch_error, extract_system_message, ClaudeClient, CompletionDetails, ExtraConfig, ImageUrl, - MessageContent, MessageContentPart, Model, ModelConfig, PromptType, SendData, SseHandler, + catch_error, extract_system_message, sse_stream, ClaudeClient, CompletionDetails, ExtraConfig, + ImageUrl, MessageContent, MessageContentPart, Model, ModelConfig, PromptType, SendData, + SseHandler, }; use crate::utils::PromptKind; use anyhow::{anyhow, bail, Result}; -use futures_util::StreamExt; use reqwest::{Client as ReqwestClient, RequestBuilder}; -use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt}; use serde::Deserialize; use serde_json::{json, Value}; @@ -68,50 +67,19 @@ pub async fn claude_send_message_streaming( builder: RequestBuilder, handler: &mut SseHandler, ) -> Result<()> { - let mut es = builder.eventsource()?; - while let Some(event) = es.next().await { - match event { - Ok(Event::Open) => {} - Ok(Event::Message(message)) => { - let data: Value = serde_json::from_str(&message.data)?; - if let Some(typ) = data["type"].as_str() { - if typ == "content_block_delta" { - if let Some(text) = data["delta"]["text"].as_str() { - handler.text(text)?; - } - } - } - } - Err(err) => { - match err { - EventSourceError::StreamEnded => {} - EventSourceError::InvalidStatusCode(status, res) => { - let text = res.text().await?; - let data: Value = match text.parse() { - Ok(data) => data, - Err(_) => { - bail!( - "Invalid response data: {text} (status: {})", - status.as_u16() - ); - } - }; - catch_error(&data, status.as_u16())?; - } - EventSourceError::InvalidContentType(_, res) => { - let text = res.text().await?; - bail!("The API server should return data as 'text/event-stream', but it isn't. Check the client config. {text}"); - } - _ => { - bail!("{}", err); - } + let handle = |data: &str| -> Result { + let data: Value = serde_json::from_str(data)?; + if let Some(typ) = data["type"].as_str() { + if typ == "content_block_delta" { + if let Some(text) = data["delta"]["text"].as_str() { + handler.text(text)?; } - es.close(); } } - } + Ok(false) + }; - Ok(()) + sse_stream(builder, handle).await } pub fn claude_build_body(data: SendData, model: &Model) -> Result { diff --git a/src/client/cohere.rs b/src/client/cohere.rs index de2a11c..8586b9d 100644 --- a/src/client/cohere.rs +++ b/src/client/cohere.rs @@ -28,7 +28,7 @@ impl CohereClient { [("api_key", "API Key:", false, PromptKind::String)]; fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { - let api_key = self.get_api_key().ok(); + let api_key = self.get_api_key()?; let body = build_body(data, &self.model)?; @@ -36,10 +36,7 @@ impl CohereClient { debug!("Cohere Request: {url} {body}"); - let mut builder = client.post(url).json(&body); - if let Some(api_key) = api_key { - builder = builder.bearer_auth(api_key); - } + let builder = client.post(url).bearer_auth(api_key).json(&body); Ok(builder) } @@ -55,7 +52,7 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDeta catch_error(&data, status.as_u16())?; } - cohere_extract_completion(&data) + extract_completion(&data) } async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> { @@ -156,7 +153,7 @@ fn build_body(data: SendData, model: &Model) -> Result { Ok(body) } -fn cohere_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> { +fn extract_completion(data: &Value) -> Result<(String, CompletionDetails)> { let text = data["text"] .as_str() .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; diff --git a/src/client/common.rs b/src/client/common.rs index 64e32ff..e35e956 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -11,6 +11,7 @@ use async_trait::async_trait; use futures_util::{Stream, StreamExt}; use lazy_static::lazy_static; use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy, RequestBuilder}; +use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt}; use serde::Deserialize; use serde_json::{json, Value}; use std::{env, future::Future, time::Duration}; @@ -531,6 +532,53 @@ pub fn maybe_catch_error(data: &Value) -> Result<()> { Ok(()) } +pub async fn sse_stream(builder: RequestBuilder, mut handle: F) -> Result<()> +where + F: FnMut(&str) -> Result, +{ + let mut es = builder.eventsource()?; + while let Some(event) = es.next().await { + match event { + Ok(Event::Open) => {} + Ok(Event::Message(message)) => { + if handle(&message.data)? { + break; + } + } + Err(err) => { + match err { + EventSourceError::StreamEnded => {} + EventSourceError::InvalidStatusCode(status, res) => { + let text = res.text().await?; + let data: Value = match text.parse() { + Ok(data) => data, + Err(_) => { + bail!( + "Invalid response data: {text} (status: {})", + status.as_u16() + ); + } + }; + catch_error(&data, status.as_u16())?; + } + EventSourceError::InvalidContentType(header_value, res) => { + let text = res.text().await?; + bail!( + "Invalid response event-stream. content-type: {}, data: {text}", + header_value.to_str().unwrap_or_default() + ); + } + _ => { + bail!("{}", err); + } + } + es.close(); + } + } + } + Ok(()) +} + pub async fn json_stream(mut stream: S, mut handle: F) -> Result<()> where S: Stream> + Unpin, diff --git a/src/client/ernie.rs b/src/client/ernie.rs index cbfdf22..4038e0d 100644 --- a/src/client/ernie.rs +++ b/src/client/ernie.rs @@ -1,16 +1,14 @@ use super::{ - maybe_catch_error, patch_system_message, Client, CompletionDetails, ErnieClient, ExtraConfig, - Model, ModelConfig, PromptType, SendData, SseHandler, + maybe_catch_error, patch_system_message, sse_stream, Client, CompletionDetails, ErnieClient, + ExtraConfig, Model, ModelConfig, PromptType, SendData, SseHandler, }; use crate::utils::PromptKind; -use anyhow::{anyhow, bail, Context, Result}; +use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use chrono::Utc; -use futures_util::StreamExt; use reqwest::{Client as ReqwestClient, RequestBuilder}; -use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt}; use serde::Deserialize; use serde_json::{json, Value}; use std::env; @@ -108,49 +106,15 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDeta } async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> { - let mut es = builder.eventsource()?; - while let Some(event) = es.next().await { - match event { - Ok(Event::Open) => {} - Ok(Event::Message(message)) => { - let data: Value = serde_json::from_str(&message.data)?; - if let Some(text) = data["result"].as_str() { - handler.text(text)?; - } - } - Err(err) => { - match err { - EventSourceError::InvalidContentType(header_value, res) => { - let content_type = header_value - .to_str() - .map_err(|_| anyhow!("Invalid response header"))?; - if content_type.contains("application/json") { - let data: Value = res.json().await?; - maybe_catch_error(&data)?; - bail!("Invalid response data: {data}"); - } else { - let text = res.text().await?; - if let Some(text) = text.strip_prefix("data: ") { - let data: Value = serde_json::from_str(text)?; - if let Some(text) = data["result"].as_str() { - handler.text(text)?; - } - } else { - bail!("Invalid response data: {text}") - } - } - } - EventSourceError::StreamEnded => {} - _ => { - bail!("{}", err); - } - } - es.close(); - } + let handle = |data: &str| -> Result { + let data: Value = serde_json::from_str(data)?; + if let Some(text) = data["result"].as_str() { + handler.text(text)?; } - } + Ok(false) + }; - Ok(()) + sse_stream(builder, handle).await } fn build_body(data: SendData, model: &Model) -> Value { diff --git a/src/client/openai.rs b/src/client/openai.rs index 4c4eae2..ad833c7 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -1,14 +1,12 @@ use super::{ - catch_error, CompletionDetails, ExtraConfig, Model, ModelConfig, OpenAIClient, PromptType, - SendData, SseHandler, + catch_error, sse_stream, CompletionDetails, ExtraConfig, Model, ModelConfig, OpenAIClient, + PromptType, SendData, SseHandler, }; use crate::utils::PromptKind; -use anyhow::{anyhow, bail, Result}; -use futures_util::StreamExt; +use anyhow::{anyhow, Result}; use reqwest::{Client as ReqwestClient, RequestBuilder}; -use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt}; use serde::Deserialize; use serde_json::{json, Value}; @@ -67,49 +65,18 @@ pub async fn openai_send_message_streaming( builder: RequestBuilder, handler: &mut SseHandler, ) -> Result<()> { - let mut es = builder.eventsource()?; - while let Some(event) = es.next().await { - match event { - Ok(Event::Open) => {} - Ok(Event::Message(message)) => { - if message.data == "[DONE]" { - break; - } - let data: Value = serde_json::from_str(&message.data)?; - if let Some(text) = data["choices"][0]["delta"]["content"].as_str() { - handler.text(text)?; - } - } - Err(err) => { - match err { - EventSourceError::InvalidStatusCode(status, res) => { - let text = res.text().await?; - let data: Value = match text.parse() { - Ok(data) => data, - Err(_) => { - bail!( - "Invalid response data: {text} (status: {})", - status.as_u16() - ); - } - }; - catch_error(&data, status.as_u16())?; - } - EventSourceError::StreamEnded => {} - EventSourceError::InvalidContentType(_, res) => { - let text = res.text().await?; - bail!("The API server should return data as 'text/event-stream', but it isn't. Check the client config. {text}"); - } - _ => { - bail!("{}", err); - } - } - es.close(); - } + let handle = |data: &str| -> Result { + if data == "[DONE]" { + return Ok(true); } - } + let data: Value = serde_json::from_str(data)?; + if let Some(text) = data["choices"][0]["delta"]["content"].as_str() { + handler.text(text)?; + } + Ok(false) + }; - Ok(()) + sse_stream(builder, handle).await } pub fn openai_build_body(data: SendData, model: &Model) -> Value { diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs index 1f23385..c16aeaa 100644 --- a/src/client/qianwen.rs +++ b/src/client/qianwen.rs @@ -1,6 +1,6 @@ use super::{ - maybe_catch_error, message::*, Client, CompletionDetails, ExtraConfig, Model, ModelConfig, - PromptType, QianwenClient, SendData, SseHandler, + maybe_catch_error, message::*, sse_stream, Client, CompletionDetails, ExtraConfig, Model, + ModelConfig, PromptType, QianwenClient, SendData, SseHandler, }; use crate::utils::{sha256sum, PromptKind}; @@ -8,12 +8,10 @@ use crate::utils::{sha256sum, PromptKind}; use anyhow::{anyhow, bail, Context, Result}; use async_trait::async_trait; use base64::{engine::general_purpose::STANDARD, Engine}; -use futures_util::StreamExt; use reqwest::{ multipart::{Form, Part}, Client as ReqwestClient, RequestBuilder, }; -use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt}; use serde::Deserialize; use serde_json::{json, Value}; use std::borrow::BorrowMut; @@ -109,37 +107,22 @@ async fn send_message_streaming( handler: &mut SseHandler, is_vl: bool, ) -> Result<()> { - let mut es = builder.eventsource()?; - - while let Some(event) = es.next().await { - match event { - Ok(Event::Open) => {} - Ok(Event::Message(message)) => { - let data: Value = serde_json::from_str(&message.data)?; - maybe_catch_error(&data)?; - if is_vl { - if let Some(text) = - data["output"]["choices"][0]["message"]["content"][0]["text"].as_str() - { - handler.text(text)?; - } - } else if let Some(text) = data["output"]["text"].as_str() { - handler.text(text)?; - } - } - Err(err) => { - match err { - EventSourceError::StreamEnded => {} - _ => { - bail!("{}", err); - } - } - es.close(); + let handle = |data: &str| -> Result { + let data: Value = serde_json::from_str(data)?; + maybe_catch_error(&data)?; + if is_vl { + if let Some(text) = + data["output"]["choices"][0]["message"]["content"][0]["text"].as_str() + { + handler.text(text)?; } + } else if let Some(text) = data["output"]["text"].as_str() { + handler.text(text)?; } - } + Ok(false) + }; - Ok(()) + sse_stream(builder, handle).await } fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool)> {