[remote] try to enforce protocol state

This commit is contained in:
Timothy Stack 2021-05-09 21:44:31 -07:00
parent 681f771bb7
commit 1849609e07

View File

@ -219,11 +219,24 @@ void set_client_path_state_error(struct client_path_state *cps, const char *op)
delete_client_path_list(&cps->cps_children);
}
static int readall(int sock, void *buf, size_t len)
typedef enum {
RS_ERROR,
RS_PACKET_TYPE,
RS_PAYLOAD_TYPE,
RS_PAYLOAD,
RS_PAYLOAD_LENGTH,
RS_PAYLOAD_CONTENT,
} recv_state_t;
static recv_state_t readall(recv_state_t state, int sock, void *buf, size_t len)
{
char *cbuf = (char *) buf;
off_t offset = 0;
if (state == RS_ERROR) {
return RS_ERROR;
}
while (len > 0) {
ssize_t rc = read(sock, &cbuf[offset], len);
@ -233,12 +246,12 @@ static int readall(int sock, void *buf, size_t len)
case EINTR:
break;
default:
return -1;
return RS_ERROR;
}
}
else if (rc == 0) {
errno = EIO;
return -1;
return RS_ERROR;
}
else {
len -= rc;
@ -246,20 +259,38 @@ static int readall(int sock, void *buf, size_t len)
}
}
return 0;
switch (state) {
case RS_PACKET_TYPE:
return RS_PAYLOAD_TYPE;
case RS_PAYLOAD_TYPE:
return RS_PAYLOAD;
case RS_PAYLOAD_LENGTH:
return RS_PAYLOAD_CONTENT;
case RS_PAYLOAD_CONTENT:
return RS_PAYLOAD_TYPE;
default:
return RS_ERROR;
}
}
static tailer_packet_payload_type_t read_payload_type(int sock)
static tailer_packet_payload_type_t read_payload_type(recv_state_t *state, int sock)
{
tailer_packet_payload_type_t retval = TPPT_DONE;
readall(sock, &retval, sizeof(retval));
assert(*state == RS_PAYLOAD_TYPE);
*state = readall(*state, sock, &retval, sizeof(retval));
if (*state != RS_ERROR && retval == TPPT_DONE) {
*state = RS_PACKET_TYPE;
}
return retval;
}
static char *readstr(int sock)
static char *readstr(recv_state_t *state, int sock)
{
tailer_packet_payload_type_t payload_type = read_payload_type(sock);
assert(*state == RS_PAYLOAD_TYPE);
tailer_packet_payload_type_t payload_type = read_payload_type(state, sock);
if (payload_type != TPPT_STRING) {
fprintf(stderr, "error: expected string, got: %d\n", payload_type);
@ -268,7 +299,9 @@ static char *readstr(int sock)
int32_t length;
if (readall(sock, &length, sizeof(length)) == -1) {
*state = RS_PAYLOAD_LENGTH;
*state = readall(*state, sock, &length, sizeof(length));
if (*state == RS_ERROR) {
fprintf(stderr, "error: unable to read string length\n");
return NULL;
}
@ -278,7 +311,8 @@ static char *readstr(int sock)
return NULL;
}
if (readall(sock, retval, length) == -1) {
*state = readall(*state, sock, retval, length);
if (*state == RS_ERROR) {
fprintf(stderr, "error: unable to read string of length: %d\n", length);
free(retval);
return NULL;
@ -288,16 +322,18 @@ static char *readstr(int sock)
return retval;
}
static int readint64(int sock, int64_t *i)
static int readint64(recv_state_t *state, int sock, int64_t *i)
{
tailer_packet_payload_type_t payload_type = read_payload_type(sock);
tailer_packet_payload_type_t payload_type = read_payload_type(state, sock);
if (payload_type != TPPT_INT64) {
fprintf(stderr, "error: expected int64, got: %d\n", payload_type);
return -1;
}
if (readall(sock, i, sizeof(*i)) == -1) {
*state = RS_PAYLOAD_CONTENT;
*state = readall(*state, sock, i, sizeof(*i));
if (*state == -1) {
fprintf(stderr, "error: unable to read int64\n");
return -1;
}
@ -581,6 +617,7 @@ int poll_paths(struct list *path_list)
int main(int argc, char *argv[])
{
int done = 0, timeout = 0;
recv_state_t rstate = RS_PACKET_TYPE;
list_init(&client_path_list);
@ -596,7 +633,9 @@ int main(int argc, char *argv[])
if (ready_count) {
tailer_packet_type_t type;
if (readall(STDIN_FILENO, &type, sizeof(type)) == -1) {
assert(rstate == RS_PACKET_TYPE);
rstate = readall(rstate, STDIN_FILENO, &type, sizeof(type));
if (rstate == RS_ERROR) {
fprintf(stderr, "info: exiting...\n");
done = 1;
} else {
@ -604,11 +643,11 @@ int main(int argc, char *argv[])
case TPT_OPEN_PATH:
case TPT_CLOSE_PATH:
case TPT_LOAD_PREVIEW: {
const char *path = readstr(STDIN_FILENO);
const char *path = readstr(&rstate, STDIN_FILENO);
int64_t preview_id = 0;
if (type == TPT_LOAD_PREVIEW) {
if (readint64(STDIN_FILENO, &preview_id) == -1) {
if (readint64(&rstate, STDIN_FILENO, &preview_id) == -1) {
done = 1;
break;
}
@ -616,7 +655,7 @@ int main(int argc, char *argv[])
if (path == NULL) {
fprintf(stderr, "error: unable to get path to open\n");
done = 1;
} else if (read_payload_type(STDIN_FILENO) != TPPT_DONE) {
} else if (read_payload_type(&rstate, STDIN_FILENO) != TPPT_DONE) {
fprintf(stderr, "error: invalid open packet\n");
done = 1;
} else if (type == TPT_OPEN_PATH) {
@ -773,13 +812,13 @@ int main(int argc, char *argv[])
}
case TPT_ACK_BLOCK:
case TPT_NEED_BLOCK: {
char *path = readstr(STDIN_FILENO);
char *path = readstr(&rstate, STDIN_FILENO);
// fprintf(stderr, "info: block packet path: %s\n", path);
if (path == NULL) {
fprintf(stderr, "error: unable to get block path\n");
done = 1;
} else if (read_payload_type(STDIN_FILENO) != TPPT_DONE) {
} else if (read_payload_type(&rstate, STDIN_FILENO) != TPPT_DONE) {
fprintf(stderr, "error: invalid block packet\n");
done = 1;
} else {