From 1849609e0734be414e3e8e59f8d8a58aa61ccaec Mon Sep 17 00:00:00 2001 From: Timothy Stack Date: Sun, 9 May 2021 21:44:31 -0700 Subject: [PATCH] [remote] try to enforce protocol state --- src/tailer/tailer.main.c | 77 ++++++++++++++++++++++++++++++---------- 1 file changed, 58 insertions(+), 19 deletions(-) diff --git a/src/tailer/tailer.main.c b/src/tailer/tailer.main.c index 5230f3ff..64a046fa 100644 --- a/src/tailer/tailer.main.c +++ b/src/tailer/tailer.main.c @@ -219,11 +219,24 @@ void set_client_path_state_error(struct client_path_state *cps, const char *op) delete_client_path_list(&cps->cps_children); } -static int readall(int sock, void *buf, size_t len) +typedef enum { + RS_ERROR, + RS_PACKET_TYPE, + RS_PAYLOAD_TYPE, + RS_PAYLOAD, + RS_PAYLOAD_LENGTH, + RS_PAYLOAD_CONTENT, +} recv_state_t; + +static recv_state_t readall(recv_state_t state, int sock, void *buf, size_t len) { char *cbuf = (char *) buf; off_t offset = 0; + if (state == RS_ERROR) { + return RS_ERROR; + } + while (len > 0) { ssize_t rc = read(sock, &cbuf[offset], len); @@ -233,12 +246,12 @@ static int readall(int sock, void *buf, size_t len) case EINTR: break; default: - return -1; + return RS_ERROR; } } else if (rc == 0) { errno = EIO; - return -1; + return RS_ERROR; } else { len -= rc; @@ -246,20 +259,38 @@ static int readall(int sock, void *buf, size_t len) } } - return 0; + switch (state) { + case RS_PACKET_TYPE: + return RS_PAYLOAD_TYPE; + case RS_PAYLOAD_TYPE: + return RS_PAYLOAD; + case RS_PAYLOAD_LENGTH: + return RS_PAYLOAD_CONTENT; + case RS_PAYLOAD_CONTENT: + return RS_PAYLOAD_TYPE; + default: + return RS_ERROR; + } } -static tailer_packet_payload_type_t read_payload_type(int sock) +static tailer_packet_payload_type_t read_payload_type(recv_state_t *state, int sock) { tailer_packet_payload_type_t retval = TPPT_DONE; - readall(sock, &retval, sizeof(retval)); + assert(*state == RS_PAYLOAD_TYPE); + + *state = readall(*state, sock, &retval, sizeof(retval)); + if (*state != RS_ERROR && retval == TPPT_DONE) { + *state = RS_PACKET_TYPE; + } return retval; } -static char *readstr(int sock) +static char *readstr(recv_state_t *state, int sock) { - tailer_packet_payload_type_t payload_type = read_payload_type(sock); + assert(*state == RS_PAYLOAD_TYPE); + + tailer_packet_payload_type_t payload_type = read_payload_type(state, sock); if (payload_type != TPPT_STRING) { fprintf(stderr, "error: expected string, got: %d\n", payload_type); @@ -268,7 +299,9 @@ static char *readstr(int sock) int32_t length; - if (readall(sock, &length, sizeof(length)) == -1) { + *state = RS_PAYLOAD_LENGTH; + *state = readall(*state, sock, &length, sizeof(length)); + if (*state == RS_ERROR) { fprintf(stderr, "error: unable to read string length\n"); return NULL; } @@ -278,7 +311,8 @@ static char *readstr(int sock) return NULL; } - if (readall(sock, retval, length) == -1) { + *state = readall(*state, sock, retval, length); + if (*state == RS_ERROR) { fprintf(stderr, "error: unable to read string of length: %d\n", length); free(retval); return NULL; @@ -288,16 +322,18 @@ static char *readstr(int sock) return retval; } -static int readint64(int sock, int64_t *i) +static int readint64(recv_state_t *state, int sock, int64_t *i) { - tailer_packet_payload_type_t payload_type = read_payload_type(sock); + tailer_packet_payload_type_t payload_type = read_payload_type(state, sock); if (payload_type != TPPT_INT64) { fprintf(stderr, "error: expected int64, got: %d\n", payload_type); return -1; } - if (readall(sock, i, sizeof(*i)) == -1) { + *state = RS_PAYLOAD_CONTENT; + *state = readall(*state, sock, i, sizeof(*i)); + if (*state == -1) { fprintf(stderr, "error: unable to read int64\n"); return -1; } @@ -581,6 +617,7 @@ int poll_paths(struct list *path_list) int main(int argc, char *argv[]) { int done = 0, timeout = 0; + recv_state_t rstate = RS_PACKET_TYPE; list_init(&client_path_list); @@ -596,7 +633,9 @@ int main(int argc, char *argv[]) if (ready_count) { tailer_packet_type_t type; - if (readall(STDIN_FILENO, &type, sizeof(type)) == -1) { + assert(rstate == RS_PACKET_TYPE); + rstate = readall(rstate, STDIN_FILENO, &type, sizeof(type)); + if (rstate == RS_ERROR) { fprintf(stderr, "info: exiting...\n"); done = 1; } else { @@ -604,11 +643,11 @@ int main(int argc, char *argv[]) case TPT_OPEN_PATH: case TPT_CLOSE_PATH: case TPT_LOAD_PREVIEW: { - const char *path = readstr(STDIN_FILENO); + const char *path = readstr(&rstate, STDIN_FILENO); int64_t preview_id = 0; if (type == TPT_LOAD_PREVIEW) { - if (readint64(STDIN_FILENO, &preview_id) == -1) { + if (readint64(&rstate, STDIN_FILENO, &preview_id) == -1) { done = 1; break; } @@ -616,7 +655,7 @@ int main(int argc, char *argv[]) if (path == NULL) { fprintf(stderr, "error: unable to get path to open\n"); done = 1; - } else if (read_payload_type(STDIN_FILENO) != TPPT_DONE) { + } else if (read_payload_type(&rstate, STDIN_FILENO) != TPPT_DONE) { fprintf(stderr, "error: invalid open packet\n"); done = 1; } else if (type == TPT_OPEN_PATH) { @@ -773,13 +812,13 @@ int main(int argc, char *argv[]) } case TPT_ACK_BLOCK: case TPT_NEED_BLOCK: { - char *path = readstr(STDIN_FILENO); + char *path = readstr(&rstate, STDIN_FILENO); // fprintf(stderr, "info: block packet path: %s\n", path); if (path == NULL) { fprintf(stderr, "error: unable to get block path\n"); done = 1; - } else if (read_payload_type(STDIN_FILENO) != TPPT_DONE) { + } else if (read_payload_type(&rstate, STDIN_FILENO) != TPPT_DONE) { fprintf(stderr, "error: invalid block packet\n"); done = 1; } else {