diff --git a/src/client/ollama.rs b/src/client/ollama.rs index 37a3667..a650f40 100644 --- a/src/client/ollama.rs +++ b/src/client/ollama.rs @@ -1,10 +1,9 @@ use super::{ - catch_error, message::*, Client, CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, - OllamaClient, PromptAction, PromptKind, SendData, SseHandler, + catch_error, json_stream, message::*, Client, CompletionOutput, ExtraConfig, Model, ModelData, + ModelPatches, OllamaClient, PromptAction, PromptKind, SendData, SseHandler, }; use anyhow::{anyhow, bail, Result}; -use futures_util::StreamExt; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{json, Value}; @@ -81,14 +80,10 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle let data = res.json().await?; catch_error(&data, status.as_u16())?; } else { - let mut stream = res.bytes_stream(); - while let Some(chunk) = stream.next().await { - let chunk = chunk?; - if chunk.is_empty() { - continue; - } - let data: Value = serde_json::from_slice(&chunk)?; + let handle = |message: &str| -> Result<()> { + let data: Value = serde_json::from_str(message)?; debug!("stream-data: {data}"); + if data["done"].is_boolean() { if let Some(text) = data["message"]["content"].as_str() { handler.text(text)?; @@ -96,8 +91,13 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle } else { bail!("Invalid response data: {data}") } - } + + Ok(()) + }; + + json_stream(res.bytes_stream(), handle).await?; } + Ok(()) }