diff --git a/src/net/codec.rs b/src/net/codec.rs index 71460db..316c824 100644 --- a/src/net/codec.rs +++ b/src/net/codec.rs @@ -1,17 +1,19 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; use derive_more::{Display, Error, From}; +use std::convert::TryInto; use tokio::io; use tokio_util::codec::{Decoder, Encoder}; /// Represents a marker to indicate the beginning of the next message -static MSG_START: &'static [u8] = b";start;"; +static MSG_MARKER: &'static [u8] = b";msg;"; -/// Represents a marker to indicate the end of the next message -static MSG_END: &'static [u8] = b";end;"; +/// Total size in bytes that is used for storing length +static LEN_SIZE: usize = 8; #[inline] -fn packet_size(msg_size: usize) -> usize { - MSG_START.len() + msg_size + MSG_END.len() +fn frame_size(msg_size: usize) -> usize { + // MARKER + u64 (8 bytes) + msg size + MSG_MARKER.len() + LEN_SIZE + msg_size } /// Possible errors that can occur during encoding and decoding @@ -29,11 +31,11 @@ impl<'a> Encoder<&'a [u8]> for DistantCodec { type Error = DistantCodecError; fn encode(&mut self, item: &'a [u8], dst: &mut BytesMut) -> Result<(), Self::Error> { - // Add our full packet to the bytes - dst.reserve(packet_size(item.len())); - dst.put(MSG_START); + // Add our full frame to the bytes + dst.reserve(frame_size(item.len())); + dst.put(MSG_MARKER); + dst.put_u64(item.len() as u64); dst.put(item); - dst.put(MSG_END); Ok(()) } @@ -46,37 +48,34 @@ impl Decoder for DistantCodec { fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { // First, check if we have more data than just our markers, if not we say that it's okay // but that we're waiting - if src.len() <= (MSG_START.len() + MSG_END.len()) { + if src.len() <= (MSG_MARKER.len() + LEN_SIZE) { return Ok(None); } // Second, verify that our first N bytes match our start marker - let marker_start = &src[..MSG_START.len()]; - if marker_start != MSG_START { + let marker_start = &src[..MSG_MARKER.len()]; + if marker_start != MSG_MARKER { return Err(DistantCodecError::CorruptMarker(Bytes::copy_from_slice( marker_start, ))); } - // Third, find end of message marker by scanning the available bytes, and - // consume a full packet of bytes - let mut maybe_frame = None; - for i in (MSG_START.len() + 1)..(src.len() - MSG_END.len()) { - let marker_end = &src[i..(i + MSG_END.len())]; - if marker_end == MSG_END { - maybe_frame = Some(src.split_to(i + MSG_END.len())); - break; - } - } + // Third, retrieve total size of our msg + let msg_len = u64::from_be_bytes( + src[MSG_MARKER.len()..MSG_MARKER.len() + LEN_SIZE] + .try_into() + .unwrap(), + ); // Fourth, return our msg if it's available, stripping it of the start and end markers - if let Some(frame) = maybe_frame { - let data = &frame[MSG_START.len()..(frame.len() - MSG_END.len())]; + let frame_len = frame_size(msg_len as usize); + if src.len() >= frame_len { + let data = src[MSG_MARKER.len() + LEN_SIZE..frame_len].to_vec(); // Advance so frame is no longer kept around - src.advance(frame.len()); + src.advance(frame_len); - Ok(Some(data.to_vec())) + Ok(Some(data)) } else { Ok(None) } diff --git a/src/net/mod.rs b/src/net/mod.rs index 4d03f35..9cf4f61 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -46,7 +46,9 @@ impl Transport { /// Sends some data across the wire pub async fn send(&mut self, data: T) -> Result<(), TransportError> { // Serialize, encrypt, and then (TODO) sign - let data = serde_cbor::ser::to_vec_packed(&data)?; + // NOTE: Cannot used packed implementation for now due to issues with deserialization + // let data = serde_cbor::ser::to_vec_packed(&data)?; + let data = serde_cbor::to_vec(&data)?; let data = aead::seal(&self.key, &data)?; self.inner diff --git a/src/subcommand/execute.rs b/src/subcommand/execute.rs index ddb7ccc..3b36f62 100644 --- a/src/subcommand/execute.rs +++ b/src/subcommand/execute.rs @@ -33,7 +33,5 @@ async fn run_async(cmd: ExecuteSubcommand) -> Result<(), Error> { println!("RESPONSE: {:?}", response); } - println!("DONE"); - Ok(()) }