Fix inmemory stream getting stuck

pull/96/head
Chip Senkbeil 3 years ago
parent 0f6cf3d537
commit 976544eebd
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -1,5 +1,8 @@
use super::{DataStream, PlainCodec, Transport};
use futures::ready;
use std::{
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
@ -116,51 +119,86 @@ impl AsyncRead for InmemoryStreamReadHalf {
}
// Otherwise, we poll for the next batch to read in
self.rx.poll_recv(cx).map(|x| match x {
match ready!(self.rx.poll_recv(cx)) {
Some(mut x) => {
if x.len() > buf.remaining() {
self.overflow = x.split_off(buf.remaining());
}
buf.put_slice(&x);
Ok(())
Poll::Ready(Ok(()))
}
None => Ok(()),
})
None => Poll::Ready(Ok(())),
}
}
}
/// Write portion of an inmemory channel
#[derive(Debug)]
pub struct InmemoryStreamWriteHalf(mpsc::Sender<Vec<u8>>);
pub struct InmemoryStreamWriteHalf {
tx: Option<mpsc::Sender<Vec<u8>>>,
task: Option<Pin<Box<dyn Future<Output = io::Result<usize>> + Send + Sync + 'static>>>,
}
impl InmemoryStreamWriteHalf {
pub fn new(tx: mpsc::Sender<Vec<u8>>) -> Self {
Self(tx)
Self {
tx: Some(tx),
task: None,
}
}
}
impl fmt::Debug for InmemoryStreamWriteHalf {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("InmemoryStreamWriteHalf")
.field("tx", &self.tx)
.field(
"task",
&if self.tx.is_some() {
"Some(...)"
} else {
"None"
},
)
.finish()
}
}
impl AsyncWrite for InmemoryStreamWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
use futures::FutureExt;
let n = buf.len();
let f = self.0.send(buf.to_vec()).map(|x| match x {
Ok(_) => Ok(n),
Err(_) => Ok(0),
});
tokio::pin!(f);
f.poll_unpin(cx)
loop {
match self.task.as_mut() {
Some(task) => {
let res = ready!(task.as_mut().poll(cx));
self.task.take();
return Poll::Ready(res);
}
None => match self.tx.as_mut() {
Some(tx) => {
let n = buf.len();
let tx_2 = tx.clone();
let data = buf.to_vec();
let task =
Box::pin(async move { tx_2.send(data).await.map(|_| n).or(Ok(0)) });
self.task.replace(task);
}
None => return Poll::Ready(Ok(0)),
},
}
}
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.poll_flush(cx)
fn poll_shutdown(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
self.tx.take();
self.task.take();
Poll::Ready(Ok(()))
}
}

Loading…
Cancel
Save