mirror of https://github.com/chipsenkbeil/distant
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
512 lines
18 KiB
Rust
512 lines
18 KiB
Rust
use std::io;
|
|
use std::sync::{Mutex, MutexGuard};
|
|
|
|
use async_trait::async_trait;
|
|
use tokio::sync::mpsc::error::{TryRecvError, TrySendError};
|
|
use tokio::sync::mpsc::{self};
|
|
|
|
use super::{Interest, Ready, Reconnectable, Transport};
|
|
|
|
/// Represents a [`Transport`] comprised of two inmemory channels
|
|
#[derive(Debug)]
|
|
pub struct InmemoryTransport {
|
|
tx: mpsc::Sender<Vec<u8>>,
|
|
rx: Mutex<mpsc::Receiver<Vec<u8>>>,
|
|
|
|
/// Internal storage used when we get more data from a `try_read` than can be returned
|
|
buf: Mutex<Option<Vec<u8>>>,
|
|
}
|
|
|
|
impl InmemoryTransport {
|
|
/// Creates a new transport where `tx` is used to send data out of the transport during
|
|
/// [`try_write`] and `rx` is used to receive data into the transport during [`try_read`].
|
|
///
|
|
/// [`try_read`]: Transport::try_read
|
|
/// [`try_write`]: Transport::try_write
|
|
pub fn new(tx: mpsc::Sender<Vec<u8>>, rx: mpsc::Receiver<Vec<u8>>) -> Self {
|
|
Self {
|
|
tx,
|
|
rx: Mutex::new(rx),
|
|
buf: Mutex::new(None),
|
|
}
|
|
}
|
|
|
|
/// Returns (incoming_tx, outgoing_rx, transport) where `incoming_tx` is used to send data to
|
|
/// the transport where it will be consumed during [`try_read`] and `outgoing_rx` is used to
|
|
/// receive data from the transport when it is written using [`try_write`].
|
|
///
|
|
/// [`try_read`]: Transport::try_read
|
|
/// [`try_write`]: Transport::try_write
|
|
pub fn make(buffer: usize) -> (mpsc::Sender<Vec<u8>>, mpsc::Receiver<Vec<u8>>, Self) {
|
|
let (incoming_tx, incoming_rx) = mpsc::channel(buffer);
|
|
let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer);
|
|
|
|
(
|
|
incoming_tx,
|
|
outgoing_rx,
|
|
Self::new(outgoing_tx, incoming_rx),
|
|
)
|
|
}
|
|
|
|
/// Returns pair of transports that are connected such that one sends to the other and
|
|
/// vice versa
|
|
pub fn pair(buffer: usize) -> (Self, Self) {
|
|
let (tx, rx, transport) = Self::make(buffer);
|
|
(transport, Self::new(tx, rx))
|
|
}
|
|
|
|
/// Links two independent [`InmemoryTransport`] together by dropping their internal channels
|
|
/// and generating new ones of `buffer` capacity to connect these transports.
|
|
///
|
|
/// ### Note
|
|
///
|
|
/// This will drop any pre-existing data in the internal storage to avoid corruption.
|
|
pub fn link(&mut self, other: &mut InmemoryTransport, buffer: usize) {
|
|
let (incoming_tx, incoming_rx) = mpsc::channel(buffer);
|
|
let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer);
|
|
|
|
self.buf = Mutex::new(None);
|
|
self.tx = outgoing_tx;
|
|
self.rx = Mutex::new(incoming_rx);
|
|
|
|
other.buf = Mutex::new(None);
|
|
other.tx = incoming_tx;
|
|
other.rx = Mutex::new(outgoing_rx);
|
|
}
|
|
|
|
/// Returns true if the read channel is closed, meaning it will no longer receive more data.
|
|
/// This does not factor in data remaining in the internal buffer, meaning that this may return
|
|
/// true while the transport still has data remaining in the internal buffer.
|
|
///
|
|
/// NOTE: Because there is no `is_closed` on the receiver, we have to actually try to
|
|
/// read from the receiver to see if it is disconnected, adding any received data
|
|
/// to our internal buffer if it is not disconnected and has data available
|
|
///
|
|
/// Track https://github.com/tokio-rs/tokio/issues/4638 for future `is_closed` on rx
|
|
fn is_rx_closed(&self) -> bool {
|
|
match self.rx.lock().unwrap().try_recv() {
|
|
Ok(mut data) => {
|
|
let mut buf_lock = self.buf.lock().unwrap();
|
|
|
|
let data = match buf_lock.take() {
|
|
Some(mut existing) => {
|
|
existing.append(&mut data);
|
|
existing
|
|
}
|
|
None => data,
|
|
};
|
|
|
|
*buf_lock = Some(data);
|
|
|
|
false
|
|
}
|
|
Err(TryRecvError::Empty) => false,
|
|
Err(TryRecvError::Disconnected) => true,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Reconnectable for InmemoryTransport {
|
|
/// Once the underlying channels have closed, there is no way for this transport to
|
|
/// re-establish those channels; therefore, reconnecting will fail with
|
|
/// [`ErrorKind::ConnectionRefused`] if either underlying channel has closed.
|
|
///
|
|
/// [`ErrorKind::ConnectionRefused`]: io::ErrorKind::ConnectionRefused
|
|
async fn reconnect(&mut self) -> io::Result<()> {
|
|
if self.tx.is_closed() || self.is_rx_closed() {
|
|
Err(io::Error::from(io::ErrorKind::ConnectionRefused))
|
|
} else {
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Transport for InmemoryTransport {
|
|
fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
|
|
// Lock our internal storage to ensure that nothing else mutates it for the lifetime of
|
|
// this call as we want to make sure that data is read and stored in order
|
|
let mut buf_lock = self.buf.lock().unwrap();
|
|
|
|
// Check if we have data in our internal buffer, and if so feed it into the outgoing buf
|
|
if let Some(data) = buf_lock.take() {
|
|
return Ok(copy_and_store(buf_lock, data, buf));
|
|
}
|
|
|
|
match self.rx.lock().unwrap().try_recv() {
|
|
Ok(data) => Ok(copy_and_store(buf_lock, data, buf)),
|
|
Err(TryRecvError::Empty) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
|
Err(TryRecvError::Disconnected) => Ok(0),
|
|
}
|
|
}
|
|
|
|
fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
|
|
match self.tx.try_send(buf.to_vec()) {
|
|
Ok(()) => Ok(buf.len()),
|
|
Err(TrySendError::Full(_)) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
|
Err(TrySendError::Closed(_)) => Ok(0),
|
|
}
|
|
}
|
|
|
|
async fn ready(&self, interest: Interest) -> io::Result<Ready> {
|
|
let mut status = Ready::EMPTY;
|
|
|
|
if interest.is_readable() {
|
|
// TODO: Replace `self.is_rx_closed()` with `self.rx.is_closed()` once the tokio issue
|
|
// is resolved that adds `is_closed` to the `mpsc::Receiver`
|
|
//
|
|
// See https://github.com/tokio-rs/tokio/issues/4638
|
|
status |= if self.is_rx_closed() && self.buf.lock().unwrap().is_none() {
|
|
Ready::READ_CLOSED
|
|
} else {
|
|
Ready::READABLE
|
|
};
|
|
}
|
|
|
|
if interest.is_writable() {
|
|
status |= if self.tx.is_closed() {
|
|
Ready::WRITE_CLOSED
|
|
} else {
|
|
Ready::WRITABLE
|
|
};
|
|
}
|
|
|
|
Ok(status)
|
|
}
|
|
}
|
|
|
|
/// Moves as many bytes of `data` as fit into `out` and returns how many were copied.
///
/// If `data` is longer than `out`, the bytes that did not fit are written into the storage
/// guarded by `buf_lock`, replacing whatever it held. If `data` fits, the storage is not
/// touched. The guard is taken by value, so the lock is released when this function returns.
fn copy_and_store(
    mut buf_lock: MutexGuard<Option<Vec<u8>>>,
    mut data: Vec<u8>,
    out: &mut [u8],
) -> usize {
    // A message can be larger than the caller's buffer: copy the prefix that fits...
    let copied = data.len().min(out.len());
    out[..copied].copy_from_slice(&data[..copied]);

    // ...and keep the remainder around for the next read
    if data.len() > copied {
        *buf_lock = Some(data.split_off(copied));
    }

    copied
}
|
|
|
|
#[cfg(test)]
mod tests {
    use test_log::test;

    use super::*;
    use crate::common::TransportExt;

    // NOTE: `InmemoryTransport::make(1)` yields (incoming sender, outgoing receiver, transport)
    //       with both channels holding a single message, which is the setup every unit test
    //       below needs. Ends that must stay open are bound to `_`-prefixed names; ends that
    //       must be closed are dropped explicitly.

    #[test]
    fn is_rx_closed_should_properly_reflect_if_internal_rx_channel_is_closed() {
        let (incoming, _outgoing, transport) = InmemoryTransport::make(1);
        let buffered = || transport.buf.lock().unwrap().clone();

        // An open, empty channel is not closed
        assert!(!transport.is_rx_closed());

        // An open channel with data is not closed, and the data moves into the buffer
        incoming.try_send(b"some bytes".to_vec()).unwrap();
        assert!(!transport.is_rx_closed());
        assert_eq!(buffered(), Some(b"some bytes".to_vec()));

        // A closed channel that still has a queued message is not reported as closed yet
        incoming.try_send(b"more".to_vec()).unwrap();
        drop(incoming);
        assert!(!transport.is_rx_closed());
        assert_eq!(buffered(), Some(b"some bytesmore".to_vec()));

        // Once drained, closure is reported and the buffer keeps its contents
        assert!(transport.is_rx_closed());
        assert_eq!(buffered(), Some(b"some bytesmore".to_vec()));
    }

    #[test]
    fn try_read_should_succeed_if_able_to_read_entire_data_through_channel() {
        let (incoming, _outgoing, transport) = InmemoryTransport::make(1);
        incoming.try_send(b"some bytes".to_vec()).unwrap();

        let mut out = [0u8; 10];
        let n = transport.try_read(&mut out).unwrap();
        assert_eq!(n, 10);
        assert_eq!(&out, b"some bytes");
    }

    #[test]
    fn try_read_should_succeed_if_reading_cached_data_from_previous_read() {
        let (incoming, _outgoing, transport) = InmemoryTransport::make(1);
        incoming.try_send(b"some bytes".to_vec()).unwrap();

        let mut first = [0u8; 5];
        assert_eq!(transport.try_read(&mut first).unwrap(), 5);
        assert_eq!(&first, b"some ");

        // The first message has been pulled off the channel, so there is room for another
        incoming.try_send(b"more".to_vec()).unwrap();

        let mut second = [0u8; 2];
        assert_eq!(transport.try_read(&mut second).unwrap(), 2);
        assert_eq!(&second, b"by");

        // Buffered leftovers are served on their own, not merged with the next message
        let mut third = [0u8; 5];
        assert_eq!(transport.try_read(&mut third).unwrap(), 3);
        assert_eq!(&third[..3], b"tes");

        let mut fourth = [0u8; 5];
        assert_eq!(transport.try_read(&mut fourth).unwrap(), 4);
        assert_eq!(&fourth[..4], b"more");
    }

    #[test]
    fn try_read_should_fail_with_would_block_if_channel_is_empty() {
        let (_incoming, _outgoing, transport) = InmemoryTransport::make(1);

        let err = transport.try_read(&mut [0; 5]).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
    }

    #[test]
    fn try_read_should_succeed_with_zero_bytes_read_if_channel_closed() {
        let (incoming, _outgoing, transport) = InmemoryTransport::make(1);
        drop(incoming);

        assert_eq!(transport.try_read(&mut [0; 5]).unwrap(), 0);
    }

    #[test]
    fn try_write_should_succeed_if_able_to_send_data_through_channel() {
        let (_incoming, _outgoing, transport) = InmemoryTransport::make(1);

        let payload = b"some bytes";
        let written = transport.try_write(payload).unwrap();
        assert_eq!(written, payload.len());
    }

    #[test]
    fn try_write_should_fail_with_would_block_if_channel_capacity_has_been_reached() {
        let (_incoming, _outgoing, transport) = InmemoryTransport::make(1);

        // The single slot of the outgoing channel gets used up here
        transport
            .try_write(b"some bytes")
            .expect("Failed to fill channel");

        let err = transport.try_write(b"some bytes").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
    }

    #[test]
    fn try_write_should_succeed_with_zero_bytes_written_if_channel_closed() {
        let (_incoming, outgoing, transport) = InmemoryTransport::make(1);
        drop(outgoing);

        assert_eq!(transport.try_write(b"some bytes").unwrap(), 0);
    }

    #[test(tokio::test)]
    async fn reconnect_should_fail_if_read_channel_closed() {
        let (incoming, _outgoing, mut transport) = InmemoryTransport::make(1);
        drop(incoming);

        let err = transport.reconnect().await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::ConnectionRefused);
    }

    #[test(tokio::test)]
    async fn reconnect_should_fail_if_write_channel_closed() {
        let (_incoming, outgoing, mut transport) = InmemoryTransport::make(1);
        drop(outgoing);

        let err = transport.reconnect().await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::ConnectionRefused);
    }

    #[test(tokio::test)]
    async fn reconnect_should_succeed_if_both_channels_open() {
        let (_incoming, _outgoing, mut transport) = InmemoryTransport::make(1);

        transport.reconnect().await.unwrap();
    }

    #[test(tokio::test)]
    async fn ready_should_report_read_closed_if_channel_closed_and_internal_buf_empty() {
        let (incoming, _outgoing, transport) = InmemoryTransport::make(1);
        drop(incoming);

        let ready = transport.ready(Interest::READABLE).await.unwrap();
        assert!(ready.is_readable());
        assert!(ready.is_read_closed());
    }

    #[test(tokio::test)]
    async fn ready_should_report_readable_if_channel_not_closed() {
        let (_incoming, _outgoing, transport) = InmemoryTransport::make(1);

        let ready = transport.ready(Interest::READABLE).await.unwrap();
        assert!(ready.is_readable());
        assert!(!ready.is_read_closed());
    }

    #[test(tokio::test)]
    async fn ready_should_report_readable_if_internal_buf_not_empty() {
        let (incoming, _outgoing, transport) = InmemoryTransport::make(1);
        drop(incoming);

        // Leftover bytes must keep the transport readable even though the channel is closed
        *transport.buf.lock().unwrap() = Some(vec![1]);

        let ready = transport.ready(Interest::READABLE).await.unwrap();
        assert!(ready.is_readable());
        assert!(!ready.is_read_closed());
    }

    #[test(tokio::test)]
    async fn ready_should_report_writable_if_channel_not_closed() {
        let (_incoming, _outgoing, transport) = InmemoryTransport::make(1);

        let ready = transport.ready(Interest::WRITABLE).await.unwrap();
        assert!(ready.is_writable());
        assert!(!ready.is_write_closed());
    }

    #[test(tokio::test)]
    async fn ready_should_report_write_closed_if_channel_closed() {
        let (_incoming, outgoing, transport) = InmemoryTransport::make(1);
        drop(outgoing);

        let ready = transport.ready(Interest::WRITABLE).await.unwrap();
        assert!(ready.is_writable());
        assert!(ready.is_write_closed());
    }

    #[test(tokio::test)]
    async fn make_should_return_sender_that_sends_data_to_transport() {
        let (tx, _, transport) = InmemoryTransport::make(3);

        for msg in [b"test msg 1", b"test msg 2", b"test msg 3"].iter() {
            tx.send(msg.to_vec()).await.unwrap();
        }

        let mut buf = [0; 256];
        let mut read_next = || {
            let len = transport.try_read(&mut buf).unwrap();
            buf[..len].to_vec()
        };

        // Each read yields exactly one message
        assert_eq!(read_next(), b"test msg 1");
        assert_eq!(read_next(), b"test msg 2");

        // Dropping the last sender still lets already-queued data through before the
        // end-of-stream indicator shows up
        drop(tx);

        assert_eq!(read_next(), b"test msg 3");
        assert!(read_next().is_empty(), "Unexpectedly got more data");
    }

    #[test(tokio::test)]
    async fn make_should_return_receiver_that_receives_data_from_transport() {
        let (_, mut rx, transport) = InmemoryTransport::make(3);

        for msg in [b"test msg 1", b"test msg 2", b"test msg 3"].iter() {
            transport.write_all(*msg).await.unwrap();
        }

        // Each receive yields exactly one written message
        assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec()));
        assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec()));

        // Dropping the transport still lets already-queued data through before the channel
        // reports that nothing more is coming
        drop(transport);

        assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec()));
        assert_eq!(rx.recv().await, None, "Unexpectedly got more data");
    }
}
|