diff --git a/src/client/claude.rs b/src/client/claude.rs index 5c2a912..283ae99 100644 --- a/src/client/claude.rs +++ b/src/client/claude.rs @@ -78,8 +78,12 @@ impl ClaudeClient { } async fn send_message(builder: RequestBuilder) -> Result { - let data: Value = builder.send().await?.json().await?; - check_error(&data)?; + let res = builder.send().await?; + let status = res.status(); + let data: Value = res.json().await?; + if status != 200 { + catch_error(&data, status.as_u16())?; + } let output = data["content"][0]["text"] .as_str() @@ -95,7 +99,6 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand Ok(Event::Open) => {} Ok(Event::Message(message)) => { let data: Value = serde_json::from_str(&message.data)?; - check_error(&data)?; if let Some(typ) = data["type"].as_str() { if typ == "content_block_delta" { if let Some(text) = data["delta"]["text"].as_str() { @@ -107,16 +110,15 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand Err(err) => { match err { EventSourceError::StreamEnded => {} - EventSourceError::InvalidStatusCode(code, res) => { + EventSourceError::InvalidStatusCode(status, res) => { let text = res.text().await?; let data: Value = match text.parse() { Ok(data) => data, Err(_) => { - bail!("Request failed, {code}, {text}"); + bail!("Invalid respoinse, status: {status}, text: {text}"); } }; - check_error(&data)?; - bail!("Request failed, {code}, {text}"); + catch_error(&data, status.as_u16())?; } EventSourceError::InvalidContentType(_, res) => { let text = res.text().await?; @@ -209,13 +211,12 @@ fn build_body(data: SendData, model: &Model) -> Result { Ok(body) } -fn check_error(data: &Value) -> Result<()> { +fn catch_error(data: &Value, status: u16) -> Result<()> { + debug!("Invalid response, status: {status}, data: {data}"); if let Some(error) = data["error"].as_object() { - if let (Some(typ), Some(message)) = (error["type"].as_str(), error["message"].as_str()) { - bail!("{typ}: {message}"); - } else { - bail!("{}", Value::Object(error.clone())) + if let (Some(type_), Some(message)) = (error["type"].as_str(), error["message"].as_str()) { + bail!("{message} (type: {type_})"); } } - Ok(()) + bail!("Invalid response, status: {status}, data: {data}"); } diff --git a/src/client/cohere.rs b/src/client/cohere.rs index db2ad28..4d713b6 100644 --- a/src/client/cohere.rs +++ b/src/client/cohere.rs @@ -78,7 +78,7 @@ pub(crate) async fn send_message(builder: RequestBuilder) -> Result { let status = res.status(); let data: Value = res.json().await?; if status != 200 { - check_error(&data)?; + catch_error(&data, status.as_u16())?; } let output = extract_text(&data)?; Ok(output.to_string()) @@ -89,9 +89,10 @@ pub(crate) async fn send_message_streaming( handler: &mut ReplyHandler, ) -> Result<()> { let res = builder.send().await?; - if res.status() != 200 { + let status = res.status(); + if status != 200 { let data: Value = res.json().await?; - check_error(&data)?; + catch_error(&data, status.as_u16())?; } else { let handle = |value: &str| -> Result<()> { let value: Value = serde_json::from_str(value)?; @@ -105,24 +106,7 @@ pub(crate) async fn send_message_streaming( Ok(()) } -fn extract_text(data: &Value) -> Result<&str> { - match data["text"].as_str() { - Some(text) => Ok(text), - None => { - bail!("Invalid response data: {data}") - } - } -} - -fn check_error(data: &Value) -> Result<()> { - if let Some(message) = data["message"].as_str() { - bail!("{message}"); - } else { - bail!("Error {}", data); - } -} - -pub(crate) fn build_body(data: SendData, model: &Model) -> Result { +fn build_body(data: SendData, model: &Model) -> Result { let SendData { mut messages, temperature, @@ -195,3 +179,22 @@ pub(crate) fn build_body(data: SendData, model: &Model) -> Result { Ok(body) } + +fn catch_error(data: &Value, status: u16) -> Result<()> { + debug!("Invalid response, status: {status}, data: {data}"); + + if let Some(message) = data["message"].as_str() { + bail!("{message}"); + } else { + bail!("Invalid response, status: {status}, data: {data}"); + } +} + +fn extract_text(data: &Value) -> Result<&str> { + match data["text"].as_str() { + Some(text) => Ok(text), + None => { + bail!("Invalid response data: {data}") + } + } +} diff --git a/src/client/ernie.rs b/src/client/ernie.rs index b0a0087..68d1ace 100644 --- a/src/client/ernie.rs +++ b/src/client/ernie.rs @@ -171,7 +171,7 @@ impl ErnieClient { async fn send_message(builder: RequestBuilder) -> Result { let data: Value = builder.send().await?.json().await?; - check_error(&data)?; + catch_error(&data)?; let output = data["result"] .as_str() @@ -199,7 +199,7 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand .map_err(|_| anyhow!("Invalid response header"))?; if content_type.contains("application/json") { let data: Value = res.json().await?; - check_error(&data)?; + catch_error(&data)?; bail!("Request failed"); } else { let text = res.text().await?; @@ -226,20 +226,6 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand Ok(()) } -fn check_error(data: &Value) -> Result<()> { - if let Some(err_msg) = data["error_msg"].as_str() { - if let Some(code) = data["error_code"].as_number().and_then(|v| v.as_u64()) { - if code == 110 { - *ACCESS_TOKEN.lock().unwrap() = None; - } - bail!("{err_msg}. err_code: {code}"); - } else { - bail!("{err_msg}"); - } - } - Ok(()) -} - fn build_body(data: SendData, model: &Model) -> Value { let SendData { mut messages, @@ -268,6 +254,20 @@ fn build_body(data: SendData, model: &Model) -> Value { body } +fn catch_error(data: &Value) -> Result<()> { + if let (Some(error_code), Some(error_msg)) = + (data["error_code"].as_number(), data["error_msg"].as_str()) + { + debug!("Invalid response: {}", data); + let error_code = error_code.as_i64().unwrap_or_default(); + if error_code == 110 { + *ACCESS_TOKEN.lock().unwrap() = None; + } + bail!("{error_msg} (error_code: {error_code})"); + } + Ok(()) +} + async fn fetch_access_token( client: &reqwest::Client, api_key: &str, diff --git a/src/client/ollama.rs b/src/client/ollama.rs index 403712f..055730a 100644 --- a/src/client/ollama.rs +++ b/src/client/ollama.rs @@ -82,11 +82,10 @@ impl OllamaClient { async fn send_message(builder: RequestBuilder) -> Result { let res = builder.send().await?; let status = res.status(); + let data = res.json().await?; if status != 200 { - let text = res.text().await?; - bail!("{status}, {text}"); + catch_error(&data, status.as_u16())?; } - let data: Value = res.json().await?; let output = data["message"]["content"] .as_str() .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; @@ -97,8 +96,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand let res = builder.send().await?; let status = res.status(); if status != 200 { - let text = res.text().await?; - bail!("{status}, {text}"); + let data = res.json().await?; + catch_error(&data, status.as_u16())?; } else { let mut stream = res.bytes_stream(); while let Some(chunk) = stream.next().await { @@ -189,3 +188,11 @@ fn build_body(data: SendData, model: &Model) -> Result { Ok(body) } + +fn catch_error(data: &Value, status: u16) -> Result<()> { + debug!("Invalid response, status: {status}, data: {data}"); + if let Some(error) = data["error"].as_str() { + bail!("{error}"); + } + bail!("Invalid response, status: {status}, data: {data}"); +} diff --git a/src/client/openai.rs b/src/client/openai.rs index 7925ec6..f35c6ab 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -66,9 +66,11 @@ impl OpenAIClient { } pub async fn openai_send_message(builder: RequestBuilder) -> Result { - let data: Value = builder.send().await?.json().await?; - if let Some(err_msg) = data["error"]["message"].as_str() { - bail!("{err_msg}"); + let res = builder.send().await?; + let status = res.status(); + let data: Value = res.json().await?; + if status != 200 { + catch_error(&data, status.as_u16())?; } let output = data["choices"][0]["message"]["content"] @@ -97,21 +99,15 @@ pub async fn openai_send_message_streaming( } Err(err) => { match err { - EventSourceError::InvalidStatusCode(code, res) => { + EventSourceError::InvalidStatusCode(status, res) => { let text = res.text().await?; let data: Value = match text.parse() { Ok(data) => data, Err(_) => { - bail!("Request failed, {code}, {text}"); + bail!("Invalid respoinse, status: {status}, text: {text}"); } }; - if let Some(err_msg) = data["error"]["message"].as_str() { - bail!("{err_msg}"); - } else if let Some(err_msg) = data["message"].as_str() { - bail!("{err_msg}"); - } else { - bail!("Request failed, {code}, {text}"); - } + catch_error(&data, status.as_u16())?; } EventSourceError::StreamEnded => {} EventSourceError::InvalidContentType(_, res) => { @@ -156,3 +152,15 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value { } body } + +fn catch_error(data: &Value, status: u16) -> Result<()> { + debug!("Invalid response, status: {status}, data: {data}"); + if let Some(error) = data["error"].as_object() { + if let (Some(type_), Some(message)) = (error["type"].as_str(), error["message"].as_str()) { + bail!("{message} (type: {type_})"); + } + } else if let Some(message) = data["message"].as_str() { + bail!("{message}"); + } + bail!("Invalid response, status: {status}, data: {data}"); +} diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs index a3ee988..fb97964 100644 --- a/src/client/qianwen.rs +++ b/src/client/qianwen.rs @@ -111,7 +111,7 @@ impl QianwenClient { async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result { let data: Value = builder.send().await?.json().await?; - check_error(&data)?; + catch_error(&data)?; let output = if is_vl { data["output"]["choices"][0]["message"]["content"][0]["text"].as_str() @@ -137,7 +137,7 @@ async fn send_message_streaming( Ok(Event::Open) => {} Ok(Event::Message(message)) => { let data: Value = serde_json::from_str(&message.data)?; - check_error(&data)?; + catch_error(&data)?; if is_vl { let text = data["output"]["choices"][0]["message"]["content"][0]["text"].as_str(); @@ -165,13 +165,6 @@ async fn send_message_streaming( Ok(()) } -fn check_error(data: &Value) -> Result<()> { - if let (Some(code), Some(message)) = (data["code"].as_str(), data["message"].as_str()) { - bail!("{code}: {message}"); - } - Ok(()) -} - fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool)> { let SendData { messages, @@ -243,6 +236,14 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool Ok((body, has_upload)) } +fn catch_error(data: &Value) -> Result<()> { + if let (Some(code), Some(message)) = (data["code"].as_str(), data["message"].as_str()) { + debug!("Invalid response: {}", data); + bail!("{message} (code: {code})"); + } + Ok(()) +} + /// Patch messsages, upload embedded images to oss async fn patch_messages(model: &str, api_key: &str, messages: &mut Vec) -> Result<()> { for message in messages { diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index 66bb098..d9c1019 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -108,7 +108,7 @@ pub(crate) async fn send_message(builder: RequestBuilder) -> Result { let status = res.status(); let data: Value = res.json().await?; if status != 200 { - check_error(&data)?; + catch_error(&data, status.as_u16())?; } let output = extract_text(&data)?; Ok(output.to_string()) @@ -119,9 +119,10 @@ pub(crate) async fn send_message_streaming( handler: &mut ReplyHandler, ) -> Result<()> { let res = builder.send().await?; - if res.status() != 200 { + let status = res.status(); + if status != 200 { let data: Value = res.json().await?; - check_error(&data)?; + catch_error(&data, status.as_u16())?; } else { let handle = |value: &str| -> Result<()> { let value: Value = serde_json::from_str(value)?; @@ -149,22 +150,6 @@ fn extract_text(data: &Value) -> Result<&str> { } } -fn check_error(data: &Value) -> Result<()> { - if let Some((Some(status), Some(message))) = data[0]["error"].as_object().map(|v| { - ( - v.get("status").and_then(|v| v.as_str()), - v.get("message").and_then(|v| v.as_str()), - ) - }) { - if status == "UNAUTHENTICATED" { - unsafe { ACCESS_TOKEN = (String::new(), 0) } - } - bail!("{status}: {message}") - } else { - bail!("Error {}", data); - } -} - pub(crate) fn build_body( data: SendData, model: &Model, @@ -241,6 +226,24 @@ pub(crate) fn build_body( Ok(body) } +fn catch_error(data: &Value, status: u16) -> Result<()> { + debug!("Invalid response, status: {status}, data: {data}"); + + if let Some((Some(status), Some(message))) = data[0]["error"].as_object().map(|v| { + ( + v.get("status").and_then(|v| v.as_str()), + v.get("message").and_then(|v| v.as_str()), + ) + }) { + if status == "UNAUTHENTICATED" { + unsafe { ACCESS_TOKEN = (String::new(), 0) } + } + bail!("{message} (status: {status})") + } else { + bail!("Invalid response, status: {status}, data: {data}",); + } +} + async fn fetch_access_token( client: &reqwest::Client, file: &Option,