diff --git a/src/network/core/address.cpp b/src/network/core/address.cpp index fc19439e00..e53566c0bb 100644 --- a/src/network/core/address.cpp +++ b/src/network/core/address.cpp @@ -14,6 +14,8 @@ #include "../../safeguards.h" +static const int DEFAULT_CONNECT_TIMEOUT_SECONDS = 3; ///< Allow connect() three seconds to connect. + /** * Get the hostname; in case it wasn't given the * IPv4 dotted representation is given. @@ -322,23 +324,47 @@ static SOCKET ConnectLoopProc(addrinfo *runp) if (!SetNoDelay(sock)) DEBUG(net, 1, "[%s] setting TCP_NODELAY failed", type); + if (!SetNonBlocking(sock)) DEBUG(net, 0, "[%s] setting non-blocking mode failed", type); + int err = connect(sock, runp->ai_addr, (int)runp->ai_addrlen); -#ifdef __EMSCRIPTEN__ - /* Emscripten is asynchronous, and as such a connect() is still in - * progress by the time the call returns. */ - if (err != 0 && errno != EINPROGRESS) -#else - if (err != 0) -#endif - { - DEBUG(net, 1, "[%s] could not connect %s socket: %s", type, family, NetworkGetLastErrorString()); + if (err != 0 && NetworkGetLastError() != EINPROGRESS) { + DEBUG(net, 1, "[%s] could not connect to %s over %s: %s", type, address, family, NetworkGetLastErrorString()); closesocket(sock); return INVALID_SOCKET; } - /* Connection succeeded */ - if (!SetNonBlocking(sock)) DEBUG(net, 0, "[%s] setting non-blocking mode failed", type); + fd_set write_fd; + struct timeval tv; + + FD_ZERO(&write_fd); + FD_SET(sock, &write_fd); + + /* Wait for connect() to either connect, timeout or fail. */ + tv.tv_usec = 0; + tv.tv_sec = DEFAULT_CONNECT_TIMEOUT_SECONDS; + int n = select(FD_SETSIZE, NULL, &write_fd, NULL, &tv); + if (n < 0) { + DEBUG(net, 1, "[%s] could not connect to %s: %s", type, address, NetworkGetLastErrorString()); + closesocket(sock); + return INVALID_SOCKET; + } + + /* If no fd is selected, the timeout has been reached. */ + if (n == 0) { + DEBUG(net, 1, "[%s] timed out while connecting to %s", type, address); + closesocket(sock); + return INVALID_SOCKET; + } + + /* Retrieve last error, if any, on the socket. */ + err = GetSocketError(sock); + if (err != 0) { + DEBUG(net, 1, "[%s] could not connect to %s: %s", type, address, NetworkGetErrorString(err)); + closesocket(sock); + return INVALID_SOCKET; + } + /* Connection succeeded. */ DEBUG(net, 1, "[%s] connected to %s", type, address); return sock; diff --git a/src/network/core/os_abstraction.h b/src/network/core/os_abstraction.h index 7af3fd163e..9bd0e321f7 100644 --- a/src/network/core/os_abstraction.h +++ b/src/network/core/os_abstraction.h @@ -33,6 +33,8 @@ #define EWOULDBLOCK WSAEWOULDBLOCK #undef ECONNRESET #define ECONNRESET WSAECONNRESET +#undef EINPROGRESS +#define EINPROGRESS WSAEWOULDBLOCK const char *NetworkGetErrorString(int error); @@ -230,6 +232,20 @@ static inline bool SetNoDelay(SOCKET d) #endif } +/** + * Get the error from a socket, if any. + * @param d The socket to get the error from. + * @return The errno on the socket. + */ +static inline int GetSocketError(SOCKET d) +{ + int err; + socklen_t len = sizeof(err); + getsockopt(d, SOL_SOCKET, SO_ERROR, (char *)&err, &len); + + return err; +} + /* Make sure these structures have the size we expect them to be */ static_assert(sizeof(in_addr) == 4); ///< IPv4 addresses should be 4 bytes. static_assert(sizeof(in6_addr) == 16); ///< IPv6 addresses should be 16 bytes.