diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index 80e58e7..7680f96 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -125,9 +125,7 @@ pub(crate) async fn send_message(builder: RequestBuilder) -> Result { if status != 200 { check_error(&data)?; } - let output = data["candidates"][0]["content"]["parts"][0]["text"] - .as_str() - .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; + let output = extract_text(&data)?; Ok(output.to_string()) } @@ -176,13 +174,7 @@ pub(crate) async fn send_message_streaming( if balances.is_empty() { let value: String = buffer[start..=i].iter().collect(); let value: Value = serde_json::from_str(&value)?; - if let Some(text) = - value["candidates"][0]["content"]["parts"][0]["text"].as_str() - { - handler.text(text)?; - } else { - bail!("Invalid response data: {value}") - } + handler.text(extract_text(&value)?)?; } } ']' => { @@ -197,6 +189,22 @@ pub(crate) async fn send_message_streaming( Ok(()) } +fn extract_text(data: &Value) -> Result<&str> { + match data["candidates"][0]["content"]["parts"][0]["text"].as_str() { + Some(text) => Ok(text), + None => { + if let Some("SAFETY") = data["promptFeedback"]["blockReason"] + .as_str() + .or_else(|| data["candidates"][0]["finishReason"].as_str()) + { + bail!("Blocked by safety settings,consider ajusting `block_threshold` in the client configuraion") + } else { + bail!("Invalid response data: {data}") + } + } + } +} + fn check_error(data: &Value) -> Result<()> { if let Some((Some(status), Some(message))) = data[0]["error"].as_object().map(|v| { (