mirror of
https://github.com/sigoden/aichat
synced 2024-11-08 13:10:28 +00:00
refactor: use reqwest-eventsource as sse client
This commit is contained in:
parent
7f2210dbca
commit
8d76fc77fb
26
Cargo.lock
generated
26
Cargo.lock
generated
@ -28,7 +28,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "aichat"
|
||||
version = "0.9.0-rc4"
|
||||
version = "0.9.0-rc5"
|
||||
dependencies = [
|
||||
"ansi_colours",
|
||||
"anyhow",
|
||||
@ -44,7 +44,6 @@ dependencies = [
|
||||
"crossterm 0.26.1",
|
||||
"ctrlc",
|
||||
"dirs",
|
||||
"eventsource-stream",
|
||||
"fancy-regex",
|
||||
"futures-util",
|
||||
"inquire",
|
||||
@ -54,6 +53,7 @@ dependencies = [
|
||||
"parking_lot",
|
||||
"reedline",
|
||||
"reqwest",
|
||||
"reqwest-eventsource",
|
||||
"rustc-hash",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@ -681,6 +681,12 @@ version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65"
|
||||
|
||||
[[package]]
|
||||
name = "futures-timer"
|
||||
version = "3.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.28"
|
||||
@ -1419,6 +1425,22 @@ dependencies = [
|
||||
"winreg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reqwest-eventsource"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f529a5ff327743addc322af460761dff5b50e0c826b9e6ac44c3195c50bb2026"
|
||||
dependencies = [
|
||||
"eventsource-stream",
|
||||
"futures-core",
|
||||
"futures-timer",
|
||||
"mime",
|
||||
"nom",
|
||||
"pin-project-lite",
|
||||
"reqwest",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rgb"
|
||||
version = "0.8.37"
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "aichat"
|
||||
version = "0.9.0-rc4"
|
||||
version = "0.9.0-rc5"
|
||||
edition = "2021"
|
||||
authors = ["sigoden <sigoden@gmail.com>"]
|
||||
description = "Use ChatGPT, LocalAI and other LLMs in the terminal."
|
||||
@ -15,7 +15,6 @@ anyhow = "1.0.69"
|
||||
bytes = "1.4.0"
|
||||
clap = { version = "4.1.8", features = ["derive", "string"] }
|
||||
dirs = "5.0.0"
|
||||
eventsource-stream = "0.2.3"
|
||||
futures-util = "0.3.26"
|
||||
inquire = "0.6.2"
|
||||
is-terminal = "0.4.9"
|
||||
@ -40,6 +39,7 @@ arboard = { version = "3.2.0", default-features = false }
|
||||
async-trait = "0.1.74"
|
||||
textwrap = "0.16.0"
|
||||
ansi_colours = "1.2.2"
|
||||
reqwest-eventsource = "0.5.0"
|
||||
|
||||
[dependencies.reqwest]
|
||||
version = "0.11.14"
|
||||
|
@ -5,9 +5,9 @@ use crate::repl::ReplyStreamHandler;
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use async_trait::async_trait;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::env;
|
||||
@ -104,23 +104,35 @@ pub async fn openai_send_message_streaming(
|
||||
builder: RequestBuilder,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
) -> Result<()> {
|
||||
let res = builder.send().await?;
|
||||
if !res.status().is_success() {
|
||||
let data: Value = res.json().await?;
|
||||
if let Some(err_msg) = data["error"]["message"].as_str() {
|
||||
bail!("{err_msg}");
|
||||
}
|
||||
bail!("Request failed");
|
||||
}
|
||||
let mut stream = res.bytes_stream().eventsource();
|
||||
while let Some(part) = stream.next().await {
|
||||
let chunk = part?.data;
|
||||
if chunk == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
let data: Value = serde_json::from_str(&chunk)?;
|
||||
if let Some(text) = data["choices"][0]["delta"]["content"].as_str() {
|
||||
handler.text(text)?;
|
||||
let mut es = builder.eventsource()?;
|
||||
while let Some(event) = es.next().await {
|
||||
match event {
|
||||
Ok(Event::Open) => {}
|
||||
Ok(Event::Message(message)) => {
|
||||
if message.data == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
let data: Value = serde_json::from_str(&message.data)?;
|
||||
if let Some(text) = data["choices"][0]["delta"]["content"].as_str() {
|
||||
handler.text(text)?;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
match err {
|
||||
EventSourceError::InvalidStatusCode(_, res) => {
|
||||
let data: Value = res.json().await?;
|
||||
if let Some(err_msg) = data["error"]["message"].as_str() {
|
||||
bail!("{err_msg}");
|
||||
}
|
||||
bail!("Request failed");
|
||||
}
|
||||
EventSourceError::StreamEnded => {}
|
||||
_ => {
|
||||
bail!("{}", err);
|
||||
}
|
||||
}
|
||||
es.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -40,7 +40,7 @@ pub fn render_stream(
|
||||
};
|
||||
if let Err(err) = run() {
|
||||
let err = format!("{err:?}");
|
||||
print_now!("{}\n\n", err.trim());
|
||||
print_now!("\n{}\n\n", err.trim());
|
||||
}
|
||||
drop(wg);
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user