diff --git a/src/network/core/http.h b/src/network/core/http.h index e14ef8f006..78b5be87af 100644 --- a/src/network/core/http.h +++ b/src/network/core/http.h @@ -30,6 +30,15 @@ struct HTTPCallback { */ virtual void OnReceiveData(const char *data, size_t length) = 0; + /** + * Check if there is a request to cancel the transfer. + * + * @return true iff the connection is cancelled. + * @note Cancellations are never instant, and can take a bit of time to be processed. + * The object needs to remain valid until the OnFailure() callback is called. + */ + virtual bool IsCancelled() const = 0; + /** Silentium */ virtual ~HTTPCallback() {} }; diff --git a/src/network/core/http_curl.cpp b/src/network/core/http_curl.cpp index 0694afeac7..3781eb2d4f 100644 --- a/src/network/core/http_curl.cpp +++ b/src/network/core/http_curl.cpp @@ -167,8 +167,10 @@ void HttpThread() * do about this. */ curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, +[](void *userdata, curl_off_t dltotal, curl_off_t dlnow, curl_off_t ultotal, curl_off_t ulnow) -> int { - return _http_thread_exit ? 1 : 0; + const HTTPCallback *callback = static_cast(userdata); + return (callback->IsCancelled() || _http_thread_exit) ? 1 : 0; }); + curl_easy_setopt(curl, CURLOPT_XFERINFODATA, request->callback); /* Perform the request. */ CURLcode res = curl_easy_perform(curl); @@ -177,7 +179,7 @@ void HttpThread() Debug(net, 1, "HTTP request succeeded"); request->callback->OnReceiveData(nullptr, 0); } else { - Debug(net, 0, "HTTP request failed: {}", curl_easy_strerror(res)); + Debug(net, (request->callback->IsCancelled() || _http_thread_exit) ? 1 : 0, "HTTP request failed: {}", curl_easy_strerror(res)); request->callback->OnFailure(); } } diff --git a/src/network/core/http_winhttp.cpp b/src/network/core/http_winhttp.cpp index e689ee0050..9fc28d62cb 100644 --- a/src/network/core/http_winhttp.cpp +++ b/src/network/core/http_winhttp.cpp @@ -29,10 +29,10 @@ private: HTTPCallback *callback; ///< Callback to send data back on. const std::string data; ///< Data to send, if any. - HINTERNET connection = nullptr; ///< Current connection object. - HINTERNET request = nullptr; ///< Current request object. - bool finished = false; ///< Whether we are finished with the request. - int depth = 0; ///< Current redirect depth we are in. + HINTERNET connection = nullptr; ///< Current connection object. + HINTERNET request = nullptr; ///< Current request object. + std::atomic finished = false; ///< Whether we are finished with the request. + int depth = 0; ///< Current redirect depth we are in. public: NetworkHTTPRequest(const std::wstring &uri, HTTPCallback *callback, const std::string &data); @@ -88,6 +88,8 @@ static std::string GetLastErrorAsString() */ void NetworkHTTPRequest::WinHttpCallback(DWORD code, void *info, DWORD length) { + if (this->finished) return; + switch (code) { case WINHTTP_CALLBACK_STATUS_RESOLVING_NAME: case WINHTTP_CALLBACK_STATUS_NAME_RESOLVED: @@ -108,8 +110,8 @@ void NetworkHTTPRequest::WinHttpCallback(DWORD code, void *info, DWORD length) /* Make sure we are not in a redirect loop. */ if (this->depth++ > 5) { Debug(net, 0, "HTTP request failed: too many redirects"); - this->callback->OnFailure(); this->finished = true; + this->callback->OnFailure(); return; } break; @@ -130,8 +132,8 @@ void NetworkHTTPRequest::WinHttpCallback(DWORD code, void *info, DWORD length) /* If there is any error, we simply abort the request. */ if (status_code >= 400) { Debug(net, 0, "HTTP request failed: status-code {}", status_code); - this->callback->OnFailure(); this->finished = true; + this->callback->OnFailure(); return; } @@ -171,20 +173,22 @@ void NetworkHTTPRequest::WinHttpCallback(DWORD code, void *info, DWORD length) case WINHTTP_CALLBACK_STATUS_SECURE_FAILURE: case WINHTTP_CALLBACK_STATUS_REQUEST_ERROR: Debug(net, 0, "HTTP request failed: {}", GetLastErrorAsString()); - this->callback->OnFailure(); this->finished = true; + this->callback->OnFailure(); break; default: Debug(net, 0, "HTTP request failed: unexepected callback code 0x{:x}", code); - this->callback->OnFailure(); this->finished = true; + this->callback->OnFailure(); return; } } static void CALLBACK StaticWinHttpCallback(HINTERNET handle, DWORD_PTR context, DWORD code, void *info, DWORD length) { + if (context == 0) return; + NetworkHTTPRequest *request = (NetworkHTTPRequest *)context; request->WinHttpCallback(code, info, length); } @@ -219,8 +223,8 @@ void NetworkHTTPRequest::Connect() this->connection = WinHttpConnect(_winhttp_session, url_components.lpszHostName, url_components.nPort, 0); if (this->connection == nullptr) { Debug(net, 0, "HTTP request failed: {}", GetLastErrorAsString()); - this->callback->OnFailure(); this->finished = true; + this->callback->OnFailure(); return; } @@ -229,8 +233,8 @@ void NetworkHTTPRequest::Connect() WinHttpCloseHandle(this->connection); Debug(net, 0, "HTTP request failed: {}", GetLastErrorAsString()); - this->callback->OnFailure(); this->finished = true; + this->callback->OnFailure(); return; } @@ -249,6 +253,13 @@ void NetworkHTTPRequest::Connect() */ bool NetworkHTTPRequest::Receive() { + if (this->callback->IsCancelled()) { + Debug(net, 1, "HTTP request failed: cancelled by user"); + this->finished = true; + this->callback->OnFailure(); + return true; + } + return this->finished; } diff --git a/src/network/network_content.cpp b/src/network/network_content.cpp index 57f022dd1a..6ac2cbbd10 100644 --- a/src/network/network_content.cpp +++ b/src/network/network_content.cpp @@ -324,6 +324,8 @@ void ClientNetworkContentSocketHandler::DownloadSelectedContent(uint &files, uin /* If there's nothing to download, do nothing. */ if (files == 0) return; + this->isCancelled = false; + if (_settings_client.network.no_http_content_downloads || fallback) { this->DownloadSelectedContentFallback(content); } else { @@ -574,13 +576,14 @@ void ClientNetworkContentSocketHandler::AfterDownload() } } +bool ClientNetworkContentSocketHandler::IsCancelled() const +{ + return this->isCancelled; +} + /* Also called to just clean up the mess. */ void ClientNetworkContentSocketHandler::OnFailure() { - /* If we fail, download the rest via the 'old' system. */ - uint files, bytes; - this->DownloadSelectedContent(files, bytes, true); - this->http_response.clear(); this->http_response.shrink_to_fit(); this->http_response_index = -2; @@ -591,6 +594,13 @@ void ClientNetworkContentSocketHandler::OnFailure() fclose(this->curFile); this->curFile = nullptr; } + + /* If we fail, download the rest via the 'old' system. */ + if (!this->isCancelled) { + uint files, bytes; + + this->DownloadSelectedContent(files, bytes, true); + } } void ClientNetworkContentSocketHandler::OnReceiveData(const char *data, size_t length) @@ -726,7 +736,8 @@ ClientNetworkContentSocketHandler::ClientNetworkContentSocketHandler() : http_response_index(-2), curFile(nullptr), curInfo(nullptr), - isConnecting(false) + isConnecting(false), + isCancelled(false) { this->lastActivity = std::chrono::steady_clock::now(); } @@ -772,7 +783,10 @@ public: void ClientNetworkContentSocketHandler::Connect() { if (this->sock != INVALID_SOCKET || this->isConnecting) return; + + this->isCancelled = false; this->isConnecting = true; + new NetworkContentConnecter(NetworkContentServerConnectionString()); } @@ -781,6 +795,7 @@ void ClientNetworkContentSocketHandler::Connect() */ NetworkRecvStatus ClientNetworkContentSocketHandler::CloseConnection(bool error) { + this->isCancelled = true; NetworkContentSocketHandler::CloseConnection(); if (this->sock == INVALID_SOCKET) return NETWORK_RECV_STATUS_OKAY; diff --git a/src/network/network_content.h b/src/network/network_content.h index 21d324cab7..8a2877f904 100644 --- a/src/network/network_content.h +++ b/src/network/network_content.h @@ -76,6 +76,7 @@ protected: FILE *curFile; ///< Currently downloaded file ContentInfo *curInfo; ///< Information about the currently downloaded file bool isConnecting; ///< Whether we're connecting + bool isCancelled; ///< Whether the download has been cancelled std::chrono::steady_clock::time_point lastActivity; ///< The last time there was network activity friend class NetworkContentConnecter; @@ -94,6 +95,7 @@ protected: void OnFailure() override; void OnReceiveData(const char *data, size_t length) override; + bool IsCancelled() const override; bool BeforeDownload(); void AfterDownload();