From a2051bad503618f37e941aca3e4a5d53af1b0fbe Mon Sep 17 00:00:00 2001 From: Rubidium Date: Sun, 18 Apr 2021 09:26:06 +0200 Subject: [PATCH] Codechange: move logic whether there is enough space in a packet to write data into the Packet --- src/network/core/packet.cpp | 41 ++++++++++++++++++++++------------ src/network/core/packet.h | 3 ++- src/network/network_admin.cpp | 2 +- src/network/network_client.cpp | 2 +- src/network/network_udp.cpp | 12 +++++----- 5 files changed, 37 insertions(+), 23 deletions(-) diff --git a/src/network/core/packet.cpp b/src/network/core/packet.cpp index 4eb0e929ee..54f5a79e16 100644 --- a/src/network/core/packet.cpp +++ b/src/network/core/packet.cpp @@ -68,6 +68,16 @@ void Packet::PrepareToSend() this->pos = 0; // We start reading from here } +/** + * Is it safe to write to the packet, i.e. didn't we run over the buffer? + * @param bytes_to_write The amount of bytes we want to try to write. + * @return True iff the given amount of bytes can be written to the packet. + */ +bool Packet::CanWriteToPacket(size_t bytes_to_write) +{ + return this->size + bytes_to_write < SEND_MTU; +} + /* * The next couple of functions make sure we can send * uint8, uint16, uint32 and uint64 endian-safe @@ -95,7 +105,7 @@ void Packet::Send_bool(bool data) */ void Packet::Send_uint8(uint8 data) { - assert(this->size < SEND_MTU - sizeof(data)); + assert(this->CanWriteToPacket(sizeof(data))); this->buffer[this->size++] = data; } @@ -105,7 +115,7 @@ void Packet::Send_uint8(uint8 data) */ void Packet::Send_uint16(uint16 data) { - assert(this->size < SEND_MTU - sizeof(data)); + assert(this->CanWriteToPacket(sizeof(data))); this->buffer[this->size++] = GB(data, 0, 8); this->buffer[this->size++] = GB(data, 8, 8); } @@ -116,7 +126,7 @@ void Packet::Send_uint16(uint16 data) */ void Packet::Send_uint32(uint32 data) { - assert(this->size < SEND_MTU - sizeof(data)); + assert(this->CanWriteToPacket(sizeof(data))); this->buffer[this->size++] = GB(data, 0, 8); this->buffer[this->size++] = GB(data, 8, 8); this->buffer[this->size++] = GB(data, 16, 8); @@ -129,7 +139,7 @@ void Packet::Send_uint32(uint32 data) */ void Packet::Send_uint64(uint64 data) { - assert(this->size < SEND_MTU - sizeof(data)); + assert(this->CanWriteToPacket(sizeof(data))); this->buffer[this->size++] = GB(data, 0, 8); this->buffer[this->size++] = GB(data, 8, 8); this->buffer[this->size++] = GB(data, 16, 8); @@ -148,8 +158,8 @@ void Packet::Send_uint64(uint64 data) void Packet::Send_string(const char *data) { assert(data != nullptr); - /* The <= *is* valid due to the fact that we are comparing sizes and not the index. */ - assert(this->size + strlen(data) + 1 <= SEND_MTU); + /* Length of the string + 1 for the '\0' termination. */ + assert(this->CanWriteToPacket(strlen(data) + 1)); while ((this->buffer[this->size++] = *data++) != '\0') {} } @@ -162,18 +172,21 @@ void Packet::Send_string(const char *data) /** - * Is it safe to read from the packet, i.e. didn't we run over the buffer ? - * @param bytes_to_read The amount of bytes we want to try to read. + * Is it safe to read from the packet, i.e. didn't we run over the buffer? + * In case \c close_connection is true, the connection will be closed when one would + * overrun the buffer. When it is false, the connection remains untouched. + * @param bytes_to_read The amount of bytes we want to try to read. + * @param close_connection Whether to close the connection if one cannot read that amount. * @return True if that is safe, otherwise false. */ -bool Packet::CanReadFromPacket(uint bytes_to_read) +bool Packet::CanReadFromPacket(size_t bytes_to_read, bool close_connection) { /* Don't allow reading from a quit client/client who send bad data */ if (this->cs->HasClientQuit()) return false; /* Check if variable is within packet-size */ if (this->pos + bytes_to_read > this->size) { - this->cs->NetworkSocketHandler::CloseConnection(); + if (close_connection) this->cs->NetworkSocketHandler::CloseConnection(); return false; } @@ -235,7 +248,7 @@ uint8 Packet::Recv_uint8() { uint8 n; - if (!this->CanReadFromPacket(sizeof(n))) return 0; + if (!this->CanReadFromPacket(sizeof(n), true)) return 0; n = this->buffer[this->pos++]; return n; @@ -249,7 +262,7 @@ uint16 Packet::Recv_uint16() { uint16 n; - if (!this->CanReadFromPacket(sizeof(n))) return 0; + if (!this->CanReadFromPacket(sizeof(n), true)) return 0; n = (uint16)this->buffer[this->pos++]; n += (uint16)this->buffer[this->pos++] << 8; @@ -264,7 +277,7 @@ uint32 Packet::Recv_uint32() { uint32 n; - if (!this->CanReadFromPacket(sizeof(n))) return 0; + if (!this->CanReadFromPacket(sizeof(n), true)) return 0; n = (uint32)this->buffer[this->pos++]; n += (uint32)this->buffer[this->pos++] << 8; @@ -281,7 +294,7 @@ uint64 Packet::Recv_uint64() { uint64 n; - if (!this->CanReadFromPacket(sizeof(n))) return 0; + if (!this->CanReadFromPacket(sizeof(n), true)) return 0; n = (uint64)this->buffer[this->pos++]; n += (uint64)this->buffer[this->pos++] << 8; diff --git a/src/network/core/packet.h b/src/network/core/packet.h index 6e5c5509ce..901d3f593b 100644 --- a/src/network/core/packet.h +++ b/src/network/core/packet.h @@ -63,6 +63,7 @@ public: /* Sending/writing of packets */ void PrepareToSend(); + bool CanWriteToPacket(size_t bytes_to_write); void Send_bool (bool data); void Send_uint8 (uint8 data); void Send_uint16(uint16 data); @@ -75,7 +76,7 @@ public: bool ParsePacketSize(); void PrepareToRead(); - bool CanReadFromPacket (uint bytes_to_read); + bool CanReadFromPacket(size_t bytes_to_read, bool close_connection = false); bool Recv_bool (); uint8 Recv_uint8 (); uint16 Recv_uint16(); diff --git a/src/network/network_admin.cpp b/src/network/network_admin.cpp index fa97b7e578..057ad59883 100644 --- a/src/network/network_admin.cpp +++ b/src/network/network_admin.cpp @@ -613,7 +613,7 @@ NetworkRecvStatus ServerNetworkAdminSocketHandler::SendCmdNames() /* Should SEND_MTU be exceeded, start a new packet * (magic 5: 1 bool "more data" and one uint16 "command id", one * byte for string '\0' termination and 1 bool "no more data" */ - if (p->size + strlen(cmdname) + 5 >= SEND_MTU) { + if (p->CanWriteToPacket(strlen(cmdname) + 5)) { p->Send_bool(false); this->SendPacket(p); diff --git a/src/network/network_client.cpp b/src/network/network_client.cpp index 72f69f99f7..10b4fd1411 100644 --- a/src/network/network_client.cpp +++ b/src/network/network_client.cpp @@ -933,7 +933,7 @@ NetworkRecvStatus ClientNetworkGameSocketHandler::Receive_SERVER_FRAME(Packet *p } #endif /* Receive the token. */ - if (p->pos != p->size) this->token = p->Recv_uint8(); + if (p->CanReadFromPacket(sizeof(uint8))) this->token = p->Recv_uint8(); DEBUG(net, 5, "Received FRAME %d", _frame_counter_server); diff --git a/src/network/network_udp.cpp b/src/network/network_udp.cpp index 46a21fc87d..aa34515bdd 100644 --- a/src/network/network_udp.cpp +++ b/src/network/network_udp.cpp @@ -220,23 +220,23 @@ void ServerNetworkUDPSocketHandler::Receive_CLIENT_DETAIL_INFO(Packet *p, Networ static const uint MIN_CI_SIZE = 54; uint max_cname_length = NETWORK_COMPANY_NAME_LENGTH; - if (Company::GetNumItems() * (MIN_CI_SIZE + NETWORK_COMPANY_NAME_LENGTH) >= (uint)SEND_MTU - packet.size) { + if (!packet.CanWriteToPacket(Company::GetNumItems() * (MIN_CI_SIZE + NETWORK_COMPANY_NAME_LENGTH))) { /* Assume we can at least put the company information in the packets. */ - assert(Company::GetNumItems() * MIN_CI_SIZE < (uint)SEND_MTU - packet.size); + assert(packet.CanWriteToPacket(Company::GetNumItems() * MIN_CI_SIZE)); /* At this moment the company names might not fit in the * packet. Check whether that is really the case. */ for (;;) { - int free = SEND_MTU - packet.size; + size_t required = 0; for (const Company *company : Company::Iterate()) { char company_name[NETWORK_COMPANY_NAME_LENGTH]; SetDParam(0, company->index); GetString(company_name, STR_COMPANY_NAME, company_name + max_cname_length - 1); - free -= MIN_CI_SIZE; - free -= (int)strlen(company_name); + required += MIN_CI_SIZE; + required += strlen(company_name); } - if (free >= 0) break; + if (packet.CanWriteToPacket(required)) break; /* Try again, with slightly shorter strings. */ assert(max_cname_length > 0);