|
|
|
@ -1,16 +1,14 @@
|
|
|
|
|
use super::{
|
|
|
|
|
maybe_catch_error, patch_system_message, Client, CompletionDetails, ErnieClient, ExtraConfig,
|
|
|
|
|
Model, ModelConfig, PromptType, SendData, SseHandler,
|
|
|
|
|
maybe_catch_error, patch_system_message, sse_stream, Client, CompletionDetails, ErnieClient,
|
|
|
|
|
ExtraConfig, Model, ModelConfig, PromptType, SendData, SseHandler,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
use crate::utils::PromptKind;
|
|
|
|
|
|
|
|
|
|
use anyhow::{anyhow, bail, Context, Result};
|
|
|
|
|
use anyhow::{anyhow, Context, Result};
|
|
|
|
|
use async_trait::async_trait;
|
|
|
|
|
use chrono::Utc;
|
|
|
|
|
use futures_util::StreamExt;
|
|
|
|
|
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
|
|
|
|
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
|
|
|
|
|
use serde::Deserialize;
|
|
|
|
|
use serde_json::{json, Value};
|
|
|
|
|
use std::env;
|
|
|
|
@ -108,49 +106,15 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDeta
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
|
|
|
|
|
let mut es = builder.eventsource()?;
|
|
|
|
|
while let Some(event) = es.next().await {
|
|
|
|
|
match event {
|
|
|
|
|
Ok(Event::Open) => {}
|
|
|
|
|
Ok(Event::Message(message)) => {
|
|
|
|
|
let data: Value = serde_json::from_str(&message.data)?;
|
|
|
|
|
let handle = |data: &str| -> Result<bool> {
|
|
|
|
|
let data: Value = serde_json::from_str(data)?;
|
|
|
|
|
if let Some(text) = data["result"].as_str() {
|
|
|
|
|
handler.text(text)?;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Err(err) => {
|
|
|
|
|
match err {
|
|
|
|
|
EventSourceError::InvalidContentType(header_value, res) => {
|
|
|
|
|
let content_type = header_value
|
|
|
|
|
.to_str()
|
|
|
|
|
.map_err(|_| anyhow!("Invalid response header"))?;
|
|
|
|
|
if content_type.contains("application/json") {
|
|
|
|
|
let data: Value = res.json().await?;
|
|
|
|
|
maybe_catch_error(&data)?;
|
|
|
|
|
bail!("Invalid response data: {data}");
|
|
|
|
|
} else {
|
|
|
|
|
let text = res.text().await?;
|
|
|
|
|
if let Some(text) = text.strip_prefix("data: ") {
|
|
|
|
|
let data: Value = serde_json::from_str(text)?;
|
|
|
|
|
if let Some(text) = data["result"].as_str() {
|
|
|
|
|
handler.text(text)?;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
bail!("Invalid response data: {text}")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
EventSourceError::StreamEnded => {}
|
|
|
|
|
_ => {
|
|
|
|
|
bail!("{}", err);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
es.close();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Ok(false)
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
sse_stream(builder, handle).await
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn build_body(data: SendData, model: &Model) -> Value {
|
|
|
|
|