[remote] try to enforce protocol state

This commit is contained in:
Timothy Stack 2021-05-09 21:44:31 -07:00
parent 681f771bb7
commit 1849609e07

View File

@ -219,11 +219,24 @@ void set_client_path_state_error(struct client_path_state *cps, const char *op)
delete_client_path_list(&cps->cps_children); delete_client_path_list(&cps->cps_children);
} }
static int readall(int sock, void *buf, size_t len) typedef enum {
RS_ERROR,
RS_PACKET_TYPE,
RS_PAYLOAD_TYPE,
RS_PAYLOAD,
RS_PAYLOAD_LENGTH,
RS_PAYLOAD_CONTENT,
} recv_state_t;
static recv_state_t readall(recv_state_t state, int sock, void *buf, size_t len)
{ {
char *cbuf = (char *) buf; char *cbuf = (char *) buf;
off_t offset = 0; off_t offset = 0;
if (state == RS_ERROR) {
return RS_ERROR;
}
while (len > 0) { while (len > 0) {
ssize_t rc = read(sock, &cbuf[offset], len); ssize_t rc = read(sock, &cbuf[offset], len);
@ -233,12 +246,12 @@ static int readall(int sock, void *buf, size_t len)
case EINTR: case EINTR:
break; break;
default: default:
return -1; return RS_ERROR;
} }
} }
else if (rc == 0) { else if (rc == 0) {
errno = EIO; errno = EIO;
return -1; return RS_ERROR;
} }
else { else {
len -= rc; len -= rc;
@ -246,20 +259,38 @@ static int readall(int sock, void *buf, size_t len)
} }
} }
return 0; switch (state) {
case RS_PACKET_TYPE:
return RS_PAYLOAD_TYPE;
case RS_PAYLOAD_TYPE:
return RS_PAYLOAD;
case RS_PAYLOAD_LENGTH:
return RS_PAYLOAD_CONTENT;
case RS_PAYLOAD_CONTENT:
return RS_PAYLOAD_TYPE;
default:
return RS_ERROR;
}
} }
static tailer_packet_payload_type_t read_payload_type(int sock) static tailer_packet_payload_type_t read_payload_type(recv_state_t *state, int sock)
{ {
tailer_packet_payload_type_t retval = TPPT_DONE; tailer_packet_payload_type_t retval = TPPT_DONE;
readall(sock, &retval, sizeof(retval)); assert(*state == RS_PAYLOAD_TYPE);
*state = readall(*state, sock, &retval, sizeof(retval));
if (*state != RS_ERROR && retval == TPPT_DONE) {
*state = RS_PACKET_TYPE;
}
return retval; return retval;
} }
static char *readstr(int sock) static char *readstr(recv_state_t *state, int sock)
{ {
tailer_packet_payload_type_t payload_type = read_payload_type(sock); assert(*state == RS_PAYLOAD_TYPE);
tailer_packet_payload_type_t payload_type = read_payload_type(state, sock);
if (payload_type != TPPT_STRING) { if (payload_type != TPPT_STRING) {
fprintf(stderr, "error: expected string, got: %d\n", payload_type); fprintf(stderr, "error: expected string, got: %d\n", payload_type);
@ -268,7 +299,9 @@ static char *readstr(int sock)
int32_t length; int32_t length;
if (readall(sock, &length, sizeof(length)) == -1) { *state = RS_PAYLOAD_LENGTH;
*state = readall(*state, sock, &length, sizeof(length));
if (*state == RS_ERROR) {
fprintf(stderr, "error: unable to read string length\n"); fprintf(stderr, "error: unable to read string length\n");
return NULL; return NULL;
} }
@ -278,7 +311,8 @@ static char *readstr(int sock)
return NULL; return NULL;
} }
if (readall(sock, retval, length) == -1) { *state = readall(*state, sock, retval, length);
if (*state == RS_ERROR) {
fprintf(stderr, "error: unable to read string of length: %d\n", length); fprintf(stderr, "error: unable to read string of length: %d\n", length);
free(retval); free(retval);
return NULL; return NULL;
@ -288,16 +322,18 @@ static char *readstr(int sock)
return retval; return retval;
} }
static int readint64(int sock, int64_t *i) static int readint64(recv_state_t *state, int sock, int64_t *i)
{ {
tailer_packet_payload_type_t payload_type = read_payload_type(sock); tailer_packet_payload_type_t payload_type = read_payload_type(state, sock);
if (payload_type != TPPT_INT64) { if (payload_type != TPPT_INT64) {
fprintf(stderr, "error: expected int64, got: %d\n", payload_type); fprintf(stderr, "error: expected int64, got: %d\n", payload_type);
return -1; return -1;
} }
if (readall(sock, i, sizeof(*i)) == -1) { *state = RS_PAYLOAD_CONTENT;
*state = readall(*state, sock, i, sizeof(*i));
if (*state == -1) {
fprintf(stderr, "error: unable to read int64\n"); fprintf(stderr, "error: unable to read int64\n");
return -1; return -1;
} }
@ -581,6 +617,7 @@ int poll_paths(struct list *path_list)
int main(int argc, char *argv[]) int main(int argc, char *argv[])
{ {
int done = 0, timeout = 0; int done = 0, timeout = 0;
recv_state_t rstate = RS_PACKET_TYPE;
list_init(&client_path_list); list_init(&client_path_list);
@ -596,7 +633,9 @@ int main(int argc, char *argv[])
if (ready_count) { if (ready_count) {
tailer_packet_type_t type; tailer_packet_type_t type;
if (readall(STDIN_FILENO, &type, sizeof(type)) == -1) { assert(rstate == RS_PACKET_TYPE);
rstate = readall(rstate, STDIN_FILENO, &type, sizeof(type));
if (rstate == RS_ERROR) {
fprintf(stderr, "info: exiting...\n"); fprintf(stderr, "info: exiting...\n");
done = 1; done = 1;
} else { } else {
@ -604,11 +643,11 @@ int main(int argc, char *argv[])
case TPT_OPEN_PATH: case TPT_OPEN_PATH:
case TPT_CLOSE_PATH: case TPT_CLOSE_PATH:
case TPT_LOAD_PREVIEW: { case TPT_LOAD_PREVIEW: {
const char *path = readstr(STDIN_FILENO); const char *path = readstr(&rstate, STDIN_FILENO);
int64_t preview_id = 0; int64_t preview_id = 0;
if (type == TPT_LOAD_PREVIEW) { if (type == TPT_LOAD_PREVIEW) {
if (readint64(STDIN_FILENO, &preview_id) == -1) { if (readint64(&rstate, STDIN_FILENO, &preview_id) == -1) {
done = 1; done = 1;
break; break;
} }
@ -616,7 +655,7 @@ int main(int argc, char *argv[])
if (path == NULL) { if (path == NULL) {
fprintf(stderr, "error: unable to get path to open\n"); fprintf(stderr, "error: unable to get path to open\n");
done = 1; done = 1;
} else if (read_payload_type(STDIN_FILENO) != TPPT_DONE) { } else if (read_payload_type(&rstate, STDIN_FILENO) != TPPT_DONE) {
fprintf(stderr, "error: invalid open packet\n"); fprintf(stderr, "error: invalid open packet\n");
done = 1; done = 1;
} else if (type == TPT_OPEN_PATH) { } else if (type == TPT_OPEN_PATH) {
@ -773,13 +812,13 @@ int main(int argc, char *argv[])
} }
case TPT_ACK_BLOCK: case TPT_ACK_BLOCK:
case TPT_NEED_BLOCK: { case TPT_NEED_BLOCK: {
char *path = readstr(STDIN_FILENO); char *path = readstr(&rstate, STDIN_FILENO);
// fprintf(stderr, "info: block packet path: %s\n", path); // fprintf(stderr, "info: block packet path: %s\n", path);
if (path == NULL) { if (path == NULL) {
fprintf(stderr, "error: unable to get block path\n"); fprintf(stderr, "error: unable to get block path\n");
done = 1; done = 1;
} else if (read_payload_type(STDIN_FILENO) != TPPT_DONE) { } else if (read_payload_type(&rstate, STDIN_FILENO) != TPPT_DONE) {
fprintf(stderr, "error: invalid block packet\n"); fprintf(stderr, "error: invalid block packet\n");
done = 1; done = 1;
} else { } else {