refactor: gemini/vertexai blocked by safety (#386)

pull/388/head
sigoden 3 months ago committed by GitHub
parent f3551c2efa
commit b91d1c4a93
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -125,9 +125,7 @@ pub(crate) async fn send_message(builder: RequestBuilder) -> Result<String> {
if status != 200 {
check_error(&data)?;
}
let output = data["candidates"][0]["content"]["parts"][0]["text"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let output = extract_text(&data)?;
Ok(output.to_string())
}
@ -176,13 +174,7 @@ pub(crate) async fn send_message_streaming(
if balances.is_empty() {
let value: String = buffer[start..=i].iter().collect();
let value: Value = serde_json::from_str(&value)?;
if let Some(text) =
value["candidates"][0]["content"]["parts"][0]["text"].as_str()
{
handler.text(text)?;
} else {
bail!("Invalid response data: {value}")
}
handler.text(extract_text(&value)?)?;
}
}
']' => {
@ -197,6 +189,22 @@ pub(crate) async fn send_message_streaming(
Ok(())
}
fn extract_text(data: &Value) -> Result<&str> {
match data["candidates"][0]["content"]["parts"][0]["text"].as_str() {
Some(text) => Ok(text),
None => {
if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
.as_str()
.or_else(|| data["candidates"][0]["finishReason"].as_str())
{
bail!("Blocked by safety settingsconsider ajusting `block_threshold` in the client configuraion")
} else {
bail!("Invalid response data: {data}")
}
}
}
}
fn check_error(data: &Value) -> Result<()> {
if let Some((Some(status), Some(message))) = data[0]["error"].as_object().map(|v| {
(

Loading…
Cancel
Save