From aa0f54fa0708355497bd3ce5f684a80dfd806afb Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Wed, 10 Mar 2021 11:11:42 -0400 Subject: [PATCH] WIP plainquic tunnels --- docs/tcp-over-quic.md | 22 + llarp/CMakeLists.txt | 18 +- llarp/quic/address.cpp | 57 ++ llarp/quic/address.hpp | 126 ++++ llarp/quic/client.cpp | 99 +++ llarp/quic/client.hpp | 31 + llarp/quic/connection.cpp | 1214 +++++++++++++++++++++++++++++++ llarp/quic/connection.hpp | 311 ++++++++ llarp/quic/endpoint.cpp | 526 +++++++++++++ llarp/quic/endpoint.hpp | 241 ++++++ llarp/quic/io_result.hpp | 38 + llarp/quic/log.cpp | 45 ++ llarp/quic/log.hpp | 146 ++++ llarp/quic/null_crypto.cpp | 93 +++ llarp/quic/null_crypto.hpp | 44 ++ llarp/quic/packet.hpp | 15 + llarp/quic/random.hpp | 37 + llarp/quic/server.cpp | 117 +++ llarp/quic/server.hpp | 40 + llarp/quic/stream.cpp | 336 +++++++++ llarp/quic/stream.hpp | 343 +++++++++ llarp/quic/tunnel.cpp | 111 +++ llarp/quic/tunnel.hpp | 54 ++ llarp/quic/tunnel_client.cpp | 139 ++++ llarp/quic/tunnel_server.cpp | 174 +++++ llarp/quic/tunnel_server.hpp | 80 ++ llarp/service/protocol_type.hpp | 1 + 27 files changed, 4456 insertions(+), 2 deletions(-) create mode 100644 llarp/quic/address.cpp create mode 100644 llarp/quic/address.hpp create mode 100644 llarp/quic/client.cpp create mode 100644 llarp/quic/client.hpp create mode 100644 llarp/quic/connection.cpp create mode 100644 llarp/quic/connection.hpp create mode 100644 llarp/quic/endpoint.cpp create mode 100644 llarp/quic/endpoint.hpp create mode 100644 llarp/quic/io_result.hpp create mode 100644 llarp/quic/log.cpp create mode 100644 llarp/quic/log.hpp create mode 100644 llarp/quic/null_crypto.cpp create mode 100644 llarp/quic/null_crypto.hpp create mode 100644 llarp/quic/packet.hpp create mode 100644 llarp/quic/random.hpp create mode 100644 llarp/quic/server.cpp create mode 100644 llarp/quic/server.hpp create mode 100644 llarp/quic/stream.cpp create mode 100644 llarp/quic/stream.hpp create mode 100644 llarp/quic/tunnel.cpp create mode 100644 llarp/quic/tunnel.hpp create mode 100644 llarp/quic/tunnel_client.cpp create mode 100644 llarp/quic/tunnel_server.cpp create mode 100644 llarp/quic/tunnel_server.hpp diff --git a/docs/tcp-over-quic.md b/docs/tcp-over-quic.md index 7412759ce..15fcab82c 100644 --- a/docs/tcp-over-quic.md +++ b/docs/tcp-over-quic.md @@ -214,3 +214,25 @@ not already running), deliver the packets into it, and it tunnels incoming strea connections to the primary lokinet IP (using the IP mapped to the lokinet endpoint as the source address). + +TODO: +- Add quic protocol type to llarp/service/protocol_types.hpp +- Convert stuff in plainquic code to use lokinet structures (e.g. logging, address encapsulation) +- Add handler for QUIC packets to llarp/handlers/tun.cpp that see that protocol type and forward the + packet off to the quic server to handle. +- Get at the uvw event loop from the quic code so that we can put the plainquic stuff onto it rather + than spinning up its own event loop. I was thinking about something like: + `virtual std::shared_ptr get_uvw_loop() { return nullptr; }` in ev.h, and an override that + returns the uvw event loop in the ev_libuv.h subclass (the type erasure through the shared_ptr + means ev.h doesn't have to depend on any uvw.h headers). Then the quic code can just do something + like: + auto uv_loop = std::static_pointer_cast(ev->get_uvw_loop()); + if (not uv_loop) { die("horribly"); } +- convert the crap in the `main` functions copied from plainquic test code to exposed library calls. +- decide whether we start up a quic server and/or client on demand, or just always start it. + + +Outgoing conns: +- Add "supported protocols" item to introset and (for liblokinet) leave off IPv4/v6 flags, but add + quic protocol flag. + diff --git a/llarp/CMakeLists.txt b/llarp/CMakeLists.txt index f2d9493d5..21504ed4c 100644 --- a/llarp/CMakeLists.txt +++ b/llarp/CMakeLists.txt @@ -85,6 +85,20 @@ if(CMAKE_SYSTEM_NAME MATCHES "FreeBSD") target_include_directories(lokinet-platform SYSTEM PUBLIC /usr/local/include) endif() +add_library(lokinet-quic + quic/address.cpp + quic/client.cpp + quic/connection.cpp + quic/endpoint.cpp + quic/null_crypto.cpp + quic/server.cpp + quic/stream.cpp + quic/tunnel.cpp + quic/tunnel_client.cpp + quic/tunnel_server.cpp +) +target_link_libraries(lokinet-quic PRIVATE lokinet-platform ngtcp2) + add_library(liblokinet STATIC config/config.cpp @@ -204,7 +218,7 @@ add_library(liblokinet set_target_properties(liblokinet PROPERTIES OUTPUT_NAME lokinet) -enable_lto(lokinet-util lokinet-platform liblokinet) +enable_lto(lokinet-util lokinet-platform lokinet-quic liblokinet) if(TRACY_ROOT) target_sources(liblokinet PRIVATE ${TRACY_ROOT}/TracyClient.cpp) @@ -222,7 +236,7 @@ if(WITH_HIVE) ) endif() -target_link_libraries(liblokinet PUBLIC cxxopts lokinet-platform lokinet-util lokinet-cryptography sqlite_orm) +target_link_libraries(liblokinet PUBLIC cxxopts lokinet-platform lokinet-util lokinet-cryptography lokinet-quic sqlite_orm) target_link_libraries(liblokinet PRIVATE libunbound) diff --git a/llarp/quic/address.cpp b/llarp/quic/address.cpp new file mode 100644 index 000000000..fbf793244 --- /dev/null +++ b/llarp/quic/address.cpp @@ -0,0 +1,57 @@ +#include "address.hpp" + +extern "C" +{ +#include +} + +#include + +namespace llarp::quic +{ + using namespace std::literals; + + Address::Address(std::array ip, uint16_t port) + { + s.in.sin_family = AF_INET; + std::memcpy(&s.in.sin_addr.s_addr, ip.data(), ip.size()); + s.in.sin_port = htons(port); + a.addrlen = sizeof(s.in); + } + + Address::Address(const sockaddr_any* addr, size_t addrlen) + { + assert(addrlen == sizeof(sockaddr_in)); // FIXME: IPv6 support + std::memmove(&s, addr, addrlen); + a.addrlen = addrlen; + } + Address& + Address::operator=(const Address& addr) + { + std::memmove(&s, &addr.s, sizeof(s)); + a.addrlen = addr.a.addrlen; + return *this; + } + + std::string + Address::to_string() const + { + if (a.addrlen != sizeof(sockaddr_in)) + return "(unknown-addr)"; + char buf[INET_ADDRSTRLEN] = {0}; + inet_ntop(AF_INET, &s.in.sin_addr, buf, INET_ADDRSTRLEN); + return buf + ":"s + std::to_string(ntohs(s.in.sin_port)); + } + + std::ostream& + operator<<(std::ostream& o, const Address& a) + { + return o << a.to_string(); + } + std::ostream& + operator<<(std::ostream& o, const Path& p) + { + return o << p.local << "<-" << p.remote; + } + +} // namespace llarp::quic diff --git a/llarp/quic/address.hpp b/llarp/quic/address.hpp new file mode 100644 index 000000000..1ca8c55b0 --- /dev/null +++ b/llarp/quic/address.hpp @@ -0,0 +1,126 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +extern "C" +{ +#include +#include +} + +// FIXME: replace use of this with a llarp::SockAddr + +namespace llarp::quic +{ + union sockaddr_any + { + sockaddr_storage storage; + sockaddr sa; + sockaddr_in6 in6; + sockaddr_in in; + }; + + class Address + { + sockaddr_any s{}; + ngtcp2_addr a{0, &s.sa, nullptr}; + + public: + Address() = default; + Address(std::array ip, uint16_t port); + Address(const sockaddr_any* addr, size_t addrlen); + Address(const Address& addr) + { + *this = addr; + } + Address& + operator=(const Address& addr); + + // Implicit conversion to sockaddr* and ngtcp2_addr& so that an Address can be passed wherever + // one of those is expected. + operator sockaddr*() + { + return a.addr; + } + operator const sockaddr*() const + { + return a.addr; + } + constexpr socklen_t + sockaddr_size() const + { + return a.addrlen; + } + operator ngtcp2_addr&() + { + return a; + } + operator const ngtcp2_addr&() const + { + return a; + } + + std::string + to_string() const; + }; + + // Wraps an ngtcp2_path (which is basically just and address pair) with remote/local components. + // Implicitly convertable to a ngtcp2_path* so that this can be passed wherever a ngtcp2_path* is + // taken in the ngtcp2 API. + struct Path + { + private: + Address local_{}, remote_{}; + + public: + ngtcp2_path path{ + {local_.sockaddr_size(), local_, nullptr}, {remote_.sockaddr_size(), remote_, nullptr}}; + + // Public accessors are const: + const Address& local = local_; + const Address& remote = remote_; + + Path() = default; + Path(const Address& local, const Address& remote) : local_{local}, remote_{remote} + {} + Path(const Address& local, const sockaddr_any* remote_addr, size_t remote_len) + : local_{local}, remote_{remote_addr, remote_len} + {} + Path(const Path& p) : local_{p.local_}, remote_{p.remote_} + {} + + Path& + operator=(const Path& p) + { + local_ = p.local_; + remote_ = p.remote_; + return *this; + } + + // Equivalent to `&obj.path`, but slightly more convenient for passing into ngtcp2 functions + // taking a ngtcp2_path pointer. + operator ngtcp2_path*() + { + return &path; + } + operator const ngtcp2_path*() const + { + return &path; + } + + std::string + to_string() const; + }; + + std::ostream& + operator<<(std::ostream& o, const Address& a); + std::ostream& + operator<<(std::ostream& o, const Path& p); + +} // namespace llarp::quic diff --git a/llarp/quic/client.cpp b/llarp/quic/client.cpp new file mode 100644 index 000000000..f9ec1f7fb --- /dev/null +++ b/llarp/quic/client.cpp @@ -0,0 +1,99 @@ + +#include "client.hpp" +#include "log.hpp" + +#include + +namespace llarp::quic +{ + Client::Client( + Address remote, + std::shared_ptr loop_, + uint16_t tunnel_port, + std::optional
local_) + : Endpoint{std::move(local_), std::move(loop_)} + { + // Our UDP socket is now set up, so now we initiate contact with the remote QUIC + Path path{local, remote}; + Debug("Connecting to ", remote); + + if (tunnel_port == 0) + throw std::logic_error{"Cannot tunnel to port 0"}; + + // TODO: need timers for: + // + // - timeout (to disconnect if idle for too longer) + // + // - probably don't need for lokinet tunnel: change local addr -- attempts to re-bind the local + // socket + // + // - key_update_timer + // + // - delay_stream_timer + + auto connptr = + std::make_shared(*this, ConnectionID::random(rng), path, tunnel_port); + auto& conn = *connptr; + conns.emplace(conn.base_cid, connptr); + + /* Debug("set crypto ctx"); + + null_crypto.client_initial(conn); + + auto x = ngtcp2_conn_get_max_data_left(conn); + Debug("mdl = ", x); + */ + + conn.io_ready(); + + /* + Debug("Opening bidi stream"); + int64_t stream_id; + if (auto rv = ngtcp2_conn_open_bidi_stream(conn, &stream_id, nullptr); + rv != 0) { + Debug("Opening bidi stream failed: ", ngtcp2_strerror(rv)); + assert(rv == NGTCP2_ERR_STREAM_ID_BLOCKED); + } + else { Debug("Opening bidi stream good"); } + */ + } + + std::shared_ptr + Client::get_connection() + { + // A client only has one outgoing connection, so everything in conns should either be a + // shared_ptr or weak_ptr to that same outgoing connection so we can just use the first one. + auto it = conns.begin(); + if (it == conns.end()) + return nullptr; + if (auto* wptr = std::get_if(&it->second)) + return wptr->lock(); + return std::get(it->second); + } + + void + Client::handle_packet(const Packet& p) + { + Debug("Handling incoming client packet: ", buffer_printer{p.data}); + auto maybe_dcid = handle_packet_init(p); + if (!maybe_dcid) + return; + auto& dcid = *maybe_dcid; + + Debug("Incoming connection id ", dcid); + auto [connptr, alias] = get_conn(dcid); + if (!connptr) + { + Debug("CID is ", alias ? "expired alias" : "unknown/expired", "; dropping"); + return; + } + auto& conn = *connptr; + if (alias) + Debug("CID is alias for primary CID ", conn.base_cid); + else + Debug("CID is primary CID"); + + handle_conn_packet(conn, p); + } + +} // namespace llarp::quic diff --git a/llarp/quic/client.hpp b/llarp/quic/client.hpp new file mode 100644 index 000000000..a7e44b36a --- /dev/null +++ b/llarp/quic/client.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include "endpoint.hpp" + +#include + +namespace llarp::quic +{ + class Client : public Endpoint + { + public: + // Constructs a client that establishes an outgoing connection to `remote` to tunnel packets to + // `tunnel_port` on the remote's lokinet address. `local` can be used to optionally bind to a + // local IP and/or port for the connection. + Client( + Address remote, + std::shared_ptr loop, + uint16_t tunnel_port, + std::optional
local = std::nullopt); + + // Returns a reference to the client's connection to the server. Returns a nullptr if there is + // no connection. + std::shared_ptr + get_connection(); + + private: + void + handle_packet(const Packet& p) override; + }; + +} // namespace llarp::quic diff --git a/llarp/quic/connection.cpp b/llarp/quic/connection.cpp new file mode 100644 index 000000000..42495d0f2 --- /dev/null +++ b/llarp/quic/connection.cpp @@ -0,0 +1,1214 @@ +#include "connection.hpp" +#include "client.hpp" +#include "log.hpp" +#include "server.hpp" + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace llarp::quic +{ + ConnectionID::ConnectionID(const uint8_t* cid, size_t length) + { + assert(length <= max_size()); + datalen = length; + std::memmove(data, cid, datalen); + } + + std::ostream& + operator<<(std::ostream& o, const ConnectionID& c) + { + return o << oxenmq::to_hex(c.data, c.data + c.datalen); + } + + namespace + { +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" + + constexpr int FAIL = NGTCP2_ERR_CALLBACK_FAILURE; + + int + client_initial(ngtcp2_conn* conn_, void* user_data) + { + Debug("######################", __func__); + + // Initialization the connection and send our transport parameters to the server. This will + // put the connection into NGTCP2_CS_CLIENT_WAIT_HANDSHAKE state. + return static_cast(user_data)->init_client(); + } + int + recv_client_initial(ngtcp2_conn* conn_, const ngtcp2_cid* dcid, void* user_data) + { + Debug("######################", __func__); + + // New incoming connection from a client: our server connection starts out here in state + // NGTCP2_CS_SERVER_INITIAL, but we should immediately get into recv_crypto_data because the + // initial client packet should contain the client's transport parameters. + + auto& conn = *static_cast(user_data); + assert(conn_ == conn.conn.get()); + + if (0 != conn.setup_server_crypto_initial()) + return FAIL; + + return 0; + } + int + recv_crypto_data( + ngtcp2_conn* conn_, + ngtcp2_crypto_level crypto_level, + uint64_t offset, + const uint8_t* rawdata, + size_t rawdatalen, + void* user_data) + { + std::basic_string_view data{rawdata, rawdatalen}; + Debug("\e[32;1mReceiving crypto data @ level ", crypto_level, "\e[0m ", buffer_printer{data}); + + auto& conn = *static_cast(user_data); + switch (crypto_level) + { + case NGTCP2_CRYPTO_LEVEL_EARLY: + // We don't currently use or support 0rtt + Warn("Invalid EARLY crypto level"); + return FAIL; + + case NGTCP2_CRYPTO_LEVEL_INITIAL: + // "Initial" level means we are still handshaking; if we are server then we receive + // the client's transport params (sent in client_initial, above) and blast ours + // back. If we are a client then getting here means we received a response from the + // server, which is that returned server transport params. + + if (auto rv = conn.recv_initial_crypto(data); rv != 0) + return rv; + + if (ngtcp2_conn_is_server(conn)) + { + if (auto rv = conn.send_magic(NGTCP2_CRYPTO_LEVEL_INITIAL); rv != 0) + return rv; + if (auto rv = conn.send_transport_params(NGTCP2_CRYPTO_LEVEL_HANDSHAKE); rv != 0) + return rv; + } + + break; + + case NGTCP2_CRYPTO_LEVEL_HANDSHAKE: + if (!ngtcp2_conn_is_server(conn)) + { + if (auto rv = conn.recv_transport_params(data); rv != 0) + return rv; + // At this stage of the protocol with TLS the client sends back TLS info so that + // the server can install our rx key; we have to send *something* back to invoke + // the server's HANDSHAKE callback (so that it knows handshake is complete) so + // sent the magic again. + if (auto rv = conn.send_magic(NGTCP2_CRYPTO_LEVEL_HANDSHAKE); rv != 0) + return rv; + } + else + { + // Check that we received the above as expected + if (data != handshake_magic) + { + Warn("Invalid handshake crypto frame from client: did not find expected magic"); + return NGTCP2_ERR_CALLBACK_FAILURE; + } + } + + conn.complete_handshake(); + break; + + case NGTCP2_CRYPTO_LEVEL_APPLICATION: + // if (!conn.init_tx_key()) + // return FAIL; + break; + + default: + Warn("Unhandled crypto_level ", crypto_level); + return FAIL; + } + conn.io_ready(); + return 0; + } + int + encrypt( + uint8_t* dest, + const ngtcp2_crypto_aead* aead, + const ngtcp2_crypto_aead_ctx* aead_ctx, + const uint8_t* plaintext, + size_t plaintextlen, + const uint8_t* nonce, + size_t noncelen, + const uint8_t* ad, + size_t adlen) + { + Debug("######################", __func__); + Debug("Lengths: ", plaintextlen, "+", noncelen, "+", adlen); + if (dest != plaintext) + std::memmove(dest, plaintext, plaintextlen); + return 0; + } + int + decrypt( + uint8_t* dest, + const ngtcp2_crypto_aead* aead, + const ngtcp2_crypto_aead_ctx* aead_ctx, + const uint8_t* ciphertext, + size_t ciphertextlen, + const uint8_t* nonce, + size_t noncelen, + const uint8_t* ad, + size_t adlen) + { + Debug("######################", __func__); + Debug("Lengths: ", ciphertextlen, "+", noncelen, "+", adlen); + if (dest != ciphertext) + std::memmove(dest, ciphertext, ciphertextlen); + return 0; + } + int + hp_mask( + uint8_t* dest, + const ngtcp2_crypto_cipher* hp, + const ngtcp2_crypto_cipher_ctx* hp_ctx, + const uint8_t* sample) + { + Debug("######################", __func__); + memset(dest, 0, NGTCP2_HP_MASKLEN); + return 0; + } + int + recv_stream_data( + ngtcp2_conn* conn, + uint32_t flags, + int64_t stream_id, + uint64_t offset, + const uint8_t* data, + size_t datalen, + void* user_data, + void* stream_user_data) + { + Debug("######################", __func__); + return static_cast(user_data)->stream_receive( + {stream_id}, + {reinterpret_cast(data), datalen}, + flags & NGTCP2_STREAM_DATA_FLAG_FIN); + } + + int + acked_stream_data_offset( + ngtcp2_conn* conn_, + int64_t stream_id, + uint64_t offset, + uint64_t datalen, + void* user_data, + void* stream_user_data) + { + Debug("######################", __func__); + Debug("Ack [", offset, ",", offset + datalen, ")"); + return static_cast(user_data)->stream_ack({stream_id}, datalen); + } + + int + stream_open(ngtcp2_conn* conn, int64_t stream_id, void* user_data) + { + Debug("######################", __func__); + return static_cast(user_data)->stream_opened({stream_id}); + } + int + stream_reset_cb( + ngtcp2_conn* conn, + int64_t stream_id, + uint64_t final_size, + uint64_t app_error_code, + void* user_data, + void* stream_user_data) + { + Debug("######################", __func__); + return static_cast(user_data)->stream_reset({stream_id}, app_error_code); + } + + // (client only) + int + recv_retry(ngtcp2_conn* conn, const ngtcp2_pkt_hd* hd, void* user_data) + { + Debug("######################", __func__); + Error("FIXME UNIMPLEMENTED ", __func__); + // FIXME + return 0; + } + int + rand( + uint8_t* dest, + size_t destlen, + const ngtcp2_rand_ctx* rand_ctx, + [[maybe_unused]] ngtcp2_rand_usage usage) + { + Debug("######################", __func__); + auto& rng = *static_cast(rand_ctx->native_handle); + random_bytes(dest, destlen, rng); + return 0; + } + int + get_new_connection_id( + ngtcp2_conn* conn_, ngtcp2_cid* cid_, uint8_t* token, size_t cidlen, void* user_data) + { + Debug("######################", __func__); + + auto& conn = *static_cast(user_data); + auto cid = conn.make_alias_id(cidlen); + assert(cid.datalen == cidlen); + *cid_ = cid; + + conn.endpoint.make_stateless_reset_token(cid, token); + Debug( + "make stateless reset token ", + oxenmq::to_hex(token, token + NGTCP2_STATELESS_RESET_TOKENLEN)); + + return 0; + } + int + remove_connection_id(ngtcp2_conn* conn, const ngtcp2_cid* cid, void* user_data) + { + Debug("######################", __func__); + Error("FIXME UNIMPLEMENTED ", __func__); + // FIXME + return 0; + } + int + update_key( + ngtcp2_conn* conn, + uint8_t* rx_secret, + uint8_t* tx_secret, + ngtcp2_crypto_aead_ctx* rx_aead_ctx, + uint8_t* rx_iv, + ngtcp2_crypto_aead_ctx* tx_aead_ctx, + uint8_t* tx_iv, + const uint8_t* current_rx_secret, + const uint8_t* current_tx_secret, + size_t secretlen, + void* user_data) + { + // This is a no-op since we don't encrypt anything in the first place + return 0; + } + /* + int recv_new_token(ngtcp2_conn* conn, const ngtcp2_vec* token, void* user_data) { + Debug("######################", __func__); + Error("FIXME UNIMPLEMENTED ", __func__); + // FIXME + return 0; + } + */ +#pragma GCC diagnostic pop + } // namespace + +#ifndef NDEBUG + extern "C" inline void + debug_logger([[maybe_unused]] void* user_data, const char* fmt, ...) + { + va_list ap; + va_start(ap, fmt); + vfprintf(stderr, fmt, ap); + va_end(ap); + fprintf(stderr, "\n"); + } +#endif + + io_result + Connection::send() + { + assert(send_buffer_size <= send_buffer.size()); + io_result rv{}; + bstring_view send_data{send_buffer.data(), send_buffer_size}; + + if (!send_data.empty()) + { + Debug("Sending packet: ", buffer_printer{send_data}); + rv = endpoint.send_packet(path.remote, send_data, send_pkt_info.ecn); + if (rv.blocked()) + { + if (!wpoll) + { + wpoll = endpoint.loop->resource(endpoint.socket_fd()); + wpoll->on([this](const auto&, auto&) { send(); }); + } + if (!wpoll_active) + { + wpoll->start(uvw::PollHandle::Event::WRITABLE); + wpoll_active = true; + } + } + else if (!rv) + { + // FIXME: disconnect here? + Warn("packet send failed: ", rv.str()); + Error("FIXME - should disconnect"); + } + else if (wpoll_active) + { + wpoll->stop(); + wpoll_active = false; + } + } + return rv; + + // We succeeded + // + // FIXME2: probably don't want to do these things *here*, because this is called from the stream + // checking code. + // + // FIXME: check and send other pending streams + // + // FIXME: schedule retransmit? + // return true; + } + + std::tuple + Connection::init() + { + io_trigger = endpoint.loop->resource(); + io_trigger->on([this](auto&, auto&) { on_io_ready(); }); + + retransmit_timer = endpoint.loop->resource(); + retransmit_timer->on([this](auto&, auto&) { + Debug("Retransmit timer fired!"); + if (auto rv = ngtcp2_conn_handle_expiry(*this, get_timestamp()); rv != 0) + { + Warn("expiry handler invocation returned an error: ", ngtcp2_strerror(rv)); + endpoint.close_connection(*this, ngtcp2_err_infer_quic_transport_error_code(rv), false); + } + else + { + flush_streams(); + } + }); + retransmit_timer->start(0ms, 0ms); + + auto result = std::tuple{}; + auto& [settings, tparams, cb] = result; + cb.recv_crypto_data = recv_crypto_data; + cb.encrypt = encrypt; + cb.decrypt = decrypt; + cb.hp_mask = hp_mask; + cb.recv_stream_data = recv_stream_data; + cb.acked_stream_data_offset = acked_stream_data_offset; + cb.stream_open = stream_open; + cb.stream_reset = stream_reset_cb; + cb.rand = rand; + cb.get_new_connection_id = get_new_connection_id; + cb.remove_connection_id = remove_connection_id; + cb.update_key = update_key; + + ngtcp2_settings_default(&settings); + +#ifndef NDEBUG + settings.log_printf = debug_logger; +#endif + settings.initial_ts = get_timestamp(); + // FIXME: IPv6 + settings.max_udp_payload_size = NGTCP2_MAX_PKTLEN_IPV4; + settings.cc_algo = NGTCP2_CC_ALGO_CUBIC; + // settings.initial_rtt = ???; # NGTCP2's default is 333ms + + ngtcp2_transport_params_default(&tparams); + + // Connection level flow control window: + tparams.initial_max_data = CONNECTION_BUFFER; + // Max send buffer for a streams (local is for streams we initiate, remote is for replying on + // streams they initiate to us): + tparams.initial_max_stream_data_bidi_local = STREAM_BUFFER; + tparams.initial_max_stream_data_bidi_remote = STREAM_BUFFER; + // Max *cumulative* streams we support on a connection: + tparams.initial_max_streams_bidi = STREAM_LIMIT; + tparams.initial_max_streams_uni = 0; + tparams.max_idle_timeout = std::chrono::nanoseconds(IDLE_TIMEOUT).count(); + tparams.active_connection_id_limit = 8; + + Debug("Done basic connection initialization"); + + return result; + } + + Connection::Connection( + Server& s, const ConnectionID& base_cid_, ngtcp2_pkt_hd& header, const Path& path) + : endpoint{s}, base_cid{base_cid_}, dest_cid{header.scid}, path{path} + { + auto [settings, tparams, cb] = init(); + + cb.recv_client_initial = recv_client_initial; + + Debug("header.type = ", +header.type); + + // ConnectionIDs are a little complicated: + // - when a client creates a new connection to us, it creates a random source connection ID + // *and* a random destination connection id. The server won't have that connection ID, of + // course, but we use it to recognize that we should try accepting it as a new connection. + // - When we talk to the client we use the random source connection ID that it generated as our + // destination connection ID. + // - We choose our own source ID, however: we *don't* use the random one the client picked for + // us. Instead we generate a random one and sent it back as *our* source connection ID in the + // reply to the client. + // - the client still needs to match up that reply with that request, and so we include the + // destination connection ID that the client generated for us in the transport parameters as + // the original_dcid: this lets the client match up the request, after which it can't promptly + // forget about it and start using the source CID that we gave it. + // + // So, in other words, the conversation goes like this: + // - Client: [SCID:clientid, DCID:randomid, TRANSPORT_PARAMS] + // - Server: [SCID:serverid, DCID:clientid TRANSPORT_PARAMS(origid=randomid)] + // + // - For the client, .base_cid={clientid} and .dest_cid={randomid} initially but gets updated to + // .dest_cid={serverid} when we hear back from the server. + // - For the server, .base_cid={serverid} and .dest_cid={clientid} + + tparams.original_dcid = header.dcid; + + Debug("original_dcid is now set to ", ConnectionID(tparams.original_dcid)); + + settings.token = header.token; + + // FIXME is this required? + random_bytes( + std::begin(tparams.stateless_reset_token), sizeof(tparams.stateless_reset_token), s.rng); + tparams.stateless_reset_token_present = 1; + + ngtcp2_conn* connptr; + Debug("server_new, path=", path); + if (auto rv = ngtcp2_conn_server_new( + &connptr, + &dest_cid, + &base_cid, + path, + header.version, + &cb, + &settings, + &tparams, + nullptr /*default mem allocator*/, + this); + rv != 0) + throw std::runtime_error{"Failed to initialize server connection: "s + ngtcp2_strerror(rv)}; + conn.reset(connptr); + + Debug("Created new server conn ", base_cid); + } + + Connection::Connection( + Client& c, const ConnectionID& scid, const Path& path, uint16_t tunnel_port) + : tunnel_port{tunnel_port} + , endpoint{c} + , base_cid{scid} + , dest_cid{ConnectionID::random(c.rng)} + , path{path} + { + auto [settings, tparams, cb] = init(); + + assert(tunnel_port != 0); + + cb.client_initial = client_initial; + cb.recv_retry = recv_retry; + // cb.extend_max_local_streams_bidi = extend_max_local_streams_bidi; + // cb.recv_new_token = recv_new_token; + + ngtcp2_conn* connptr; + + if (auto rv = ngtcp2_conn_client_new( + &connptr, + &dest_cid, + &scid, + path, + NGTCP2_PROTO_VER_V1, + &cb, + &settings, + &tparams, + nullptr, + this); + rv != 0) + throw std::runtime_error{"Failed to initialize client connection: "s + ngtcp2_strerror(rv)}; + conn.reset(connptr); + + Debug("Created new client conn ", scid); + } + + Connection::~Connection() + { + if (wpoll) + wpoll->close(); + if (io_trigger) + io_trigger->close(); + } + + void + Connection::io_ready() + { + io_trigger->send(); + } + + void + Connection::on_io_ready() + { + Debug(__func__); + flush_streams(); + Debug("done ", __func__); + } + + void + Connection::flush_streams() + { + // conn, path, pi, dest, destlen, and ts + std::optional ts; + + send_pkt_info = {}; + + auto add_stream_data = + [&](StreamID stream_id, const ngtcp2_vec* datav, size_t datalen, uint32_t flags = 0) { + std::array result; + auto& [nwrite, consumed] = result; + if (!ts) + ts = get_timestamp(); + + Debug("send_buffer size = ", send_buffer.size()); + Debug("datalen = ", datalen); + Debug("flags = ", flags); + nwrite = ngtcp2_conn_writev_stream( + conn.get(), + &path.path, + &send_pkt_info, + u8data(send_buffer), + send_buffer.size(), + &consumed, + NGTCP2_WRITE_STREAM_FLAG_MORE | flags, + stream_id.id, + datav, + datalen, + *ts); + return result; + }; + + auto send_packet = [&](auto nwrite) -> bool { + send_buffer_size = nwrite; + Debug("Sending ", send_buffer_size, "B packet"); + + // FIXME: update remote addr? ecn? + auto sent = send(); + if (sent.blocked()) + return false; // We'll get called again when the socket becomes writable + + send_buffer_size = 0; + if (!sent) + { + Warn("I/O error while trying to send packet: ", sent.str()); + // FIXME: disconnect? + return false; + } + Debug("packet away!"); + return true; + }; + + std::list strs; + for (auto& [stream_id, stream_ptr] : streams) + if (stream_ptr) + strs.push_back(stream_ptr.get()); + + // Maximum number of stream data packets to send out at once; if we reach this then we'll + // schedule another event loop call of ourselves (so that we don't starve the loop). + constexpr int max_stream_packets = 15; + int stream_packets = 0; + while (!strs.empty() && stream_packets < max_stream_packets) + { + for (auto it = strs.begin(); it != strs.end();) + { + auto& stream = **it; + auto bufs = stream.pending(); + if (stream.is_shutdown + || (bufs.empty() && !stream.is_new && !(stream.is_closing && !stream.sent_fin))) + { + it = strs.erase(it); + continue; + } + std::vector vecs; + vecs.reserve(bufs.size()); + std::transform(bufs.begin(), bufs.end(), std::back_inserter(vecs), [](const auto& buf) { + return ngtcp2_vec{const_cast(u8data(buf)), buf.size()}; + }); + +#ifndef NDEBUG + { + std::string buf_sizes; + for (auto& b : bufs) + { + if (!buf_sizes.empty()) + buf_sizes += '+'; + buf_sizes += std::to_string(b.size()); + } + Debug("Sending ", buf_sizes.empty() ? "no" : buf_sizes, " data for ", stream.id()); + } +#endif + + uint32_t extra_flags = 0; + if (stream.is_closing && !stream.sent_fin) + { + Debug("Sending FIN"); + extra_flags |= NGTCP2_WRITE_STREAM_FLAG_FIN; + stream.sent_fin = true; + } + else if (stream.is_new) + { + stream.is_new = false; + } + + auto [nwrite, consumed] = + add_stream_data(stream.id(), vecs.data(), vecs.size(), extra_flags); + Debug( + "add_stream_data for stream ", stream.id(), " returned [", nwrite, ",", consumed, "]"); + + if (nwrite > 0) + { + if (consumed >= 0) + { + Debug("consumed ", consumed, " bytes from stream ", stream.id()); + stream.wrote(consumed); + } + + Debug("Sending stream data packet"); + if (!send_packet(nwrite)) + return; + ++stream_packets; + ++it; + continue; + } + + switch (nwrite) + { + case 0: + Debug( + "Done stream writing to ", + stream.id(), + " (either stream is congested or we have nothing else to send right now)"); + assert(consumed <= 0); + break; + case NGTCP2_ERR_WRITE_MORE: + Debug( + "consumed ", consumed, " bytes from stream ", stream.id(), " and have space left"); + stream.wrote(consumed); + if (stream.unsent() > 0) + { + // We have more to send on this stream, so keep us in the queue + ++it; + continue; + } + break; + case NGTCP2_ERR_STREAM_DATA_BLOCKED: + Debug("cannot add to stream ", stream.id(), " right now: stream is blocked"); + break; + case NGTCP2_ERR_STREAM_SHUT_WR: + Debug("cannot write to ", stream.id(), ": stream is shut down"); + break; + default: + assert(consumed <= 0); + Warn("Error writing to stream ", stream.id(), ": ", ngtcp2_strerror(nwrite)); + break; + } + it = strs.erase(it); + } + } + + // Now try more with stream id -1 and no data: this takes care of things like initial handshake + // packets, and also finishes off any partially-filled packet from above. + for (;;) + { + auto [nwrite, consumed] = add_stream_data(StreamID{}, nullptr, 0); + Debug("add_stream_data for non-stream returned [", nwrite, ",", consumed, "]"); + assert(consumed <= 0); + if (nwrite == NGTCP2_ERR_WRITE_MORE) + { + Debug("Writing non-stream data, and have space left"); + continue; + } + if (nwrite < 0) + { + Warn("Error writing non-stream data: ", ngtcp2_strerror(nwrite)); + break; + } + if (nwrite == 0) + { + Debug("Nothing else to write for non-stream data for now (or we are congested)"); + ngtcp2_conn_stat cstat; + ngtcp2_conn_get_conn_stat(*this, &cstat); + Debug("Current unacked bytes in flight: ", cstat.bytes_in_flight); + break; + } + + Debug("Sending non-stream data packet"); + if (!send_packet(nwrite)) + return; + } + + schedule_retransmit(); + } + + void + Connection::schedule_retransmit() + { + auto expiry = std::chrono::nanoseconds{ngtcp2_conn_get_expiry(*this)}; + Debug("SCHEDULE RETRANSMIT exp ", expiry.count()); + if (expiry < 0ns) + { + retransmit_timer->repeat(0ms); + return; + } + auto expires_in = std::chrono::duration_cast( + expiry - get_time().time_since_epoch()); + Debug("Next retransmit in ", expires_in.count(), "ms"); + if (expires_in < 1ms) + expires_in = 1ms; + retransmit_timer->repeat(expires_in); + retransmit_timer->again(); + } + + int + Connection::stream_opened(StreamID id) + { + Debug("New stream ", id); + auto* serv = server(); + if (!serv) + { + Warn("We are a client, incoming streams are not accepted"); + return NGTCP2_ERR_CALLBACK_FAILURE; + } + + std::shared_ptr stream{new Stream{*this, id, endpoint.default_stream_buffer_size}}; + stream->stream_id = id; + bool good = true; + if (serv->stream_open_callback) + good = serv->stream_open_callback(*serv, *stream, tunnel_port); + if (!good) + { + Debug("stream_open_callback returned failure, dropping stream ", id); + ngtcp2_conn_shutdown_stream(*this, id.id, 1); + io_ready(); + return NGTCP2_ERR_CALLBACK_FAILURE; + } + + [[maybe_unused]] auto [it, ins] = streams.emplace(id, std::move(stream)); + assert(ins); + Debug("Created new incoming stream ", id); + return 0; + } + + int + Connection::stream_receive(StreamID id, const bstring_view data, bool fin) + { + auto str = get_stream(id); + if (!str->data_callback) + Debug("Dropping incoming data on stream ", str->id(), ": stream has no data callback set"); + else + { + bool good = false; + try + { + str->data_callback(*str, data); + good = true; + } + catch (const std::exception& e) + { + Warn( + "Stream ", + str->id(), + " data callback raised exception (", + e.what(), + "); closing stream with app code ", + STREAM_EXCEPTION_ERROR_CODE); + } + catch (...) + { + Warn( + "Stream ", + str->id(), + " data callback raised an unknown exception; closing stream with app code ", + STREAM_EXCEPTION_ERROR_CODE); + } + if (!good) + { + str->close(STREAM_EXCEPTION_ERROR_CODE); + return NGTCP2_ERR_CALLBACK_FAILURE; + } + } + if (fin) + { + if (str->close_callback) + str->close_callback(*str, std::nullopt); + streams.erase(id); + io_ready(); + } + else + { + ngtcp2_conn_extend_max_stream_offset(*this, id.id, data.size()); + ngtcp2_conn_extend_max_offset(*this, data.size()); + } + return 0; + } + + int + Connection::stream_reset(StreamID id, uint64_t app_code) + { + Debug(id, " reset with code ", app_code); + auto it = streams.find(id); + if (it == streams.end()) + return NGTCP2_ERR_CALLBACK_FAILURE; + auto& stream = *it->second; + const bool was_closing = stream.is_closing; + stream.is_closing = true; + if (!was_closing && stream.close_callback) + { + Debug("Invoke stream close callback"); + stream.close_callback(stream, app_code); + } + + streams.erase(it); + return 0; + } + + int + Connection::stream_ack(StreamID id, size_t size) + { + if (auto it = streams.find(id); it != streams.end()) + { + it->second->acknowledge(size); + return 0; + } + return NGTCP2_ERR_CALLBACK_FAILURE; + } + + Server* + Connection::server() + { + return dynamic_cast(&endpoint); + } + + Client* + Connection::client() + { + return dynamic_cast(&endpoint); + } + + int + Connection::setup_server_crypto_initial() + { + auto* s = server(); + assert(s); + s->null_crypto.server_initial(*this); + io_ready(); + return 0; + } + + ConnectionID + Connection::make_alias_id(size_t cidlen) + { + return endpoint.add_connection_id(*this, cidlen); + } + + const std::shared_ptr& + Connection::open_stream(Stream::data_callback_t data_cb, Stream::close_callback_t close_cb) + { + std::shared_ptr stream{new Stream{ + *this, std::move(data_cb), std::move(close_cb), endpoint.default_stream_buffer_size}}; + if (int rv = ngtcp2_conn_open_bidi_stream(*this, &stream->stream_id.id, stream.get()); rv != 0) + { + Warn("Creating stream failed: ", ngtcp2_strerror(rv)); + throw std::runtime_error{"Stream creation failed: "s + ngtcp2_strerror(rv)}; + } + + auto& str = streams[stream->stream_id]; + str = std::move(stream); + + return str; + } + + const std::shared_ptr& + Connection::get_stream(StreamID s) const + { + return streams.at(s); + } + + int + Connection::init_client() + { + endpoint.null_crypto.client_initial(*this); + + if (int rv = send_magic(NGTCP2_CRYPTO_LEVEL_INITIAL); rv != 0) + return rv; + if (int rv = send_transport_params(NGTCP2_CRYPTO_LEVEL_INITIAL); rv != 0) + return rv; + + io_ready(); + return 0; + } + + int + Connection::recv_initial_crypto(std::basic_string_view data) + { + if (data.substr(0, handshake_magic.size()) != handshake_magic) + { + Warn("Invalid initial crypto frame: did not find expected magic prefix"); + return NGTCP2_ERR_CALLBACK_FAILURE; + } + data.remove_prefix(handshake_magic.size()); + + const bool is_server = ngtcp2_conn_is_server(*this); + if (is_server) + { + // For a server, we receive the transport parameters in the initial packet (prepended by the + // magic that we just removed): + if (auto rv = recv_transport_params(data); rv != 0) + return rv; + } + else + { + // For a client our initial crypto data should be just the magic string (the packet also + // contains transport parameters, but they are at HANDSHAKE crypto level and so will result + // in a second callback to handle them). + if (!data.empty()) + { + Warn("Invalid initial crypto frame: unexpected post-magic data found"); + return NGTCP2_ERR_CALLBACK_FAILURE; + } + } + + endpoint.null_crypto.install_rx_handshake_key(*this); + endpoint.null_crypto.install_tx_handshake_key(*this); + if (is_server) + endpoint.null_crypto.install_tx_key(*this); + + return 0; + } + + void + Connection::complete_handshake() + { + endpoint.null_crypto.install_rx_key(*this); + if (!ngtcp2_conn_is_server(*this)) + endpoint.null_crypto.install_tx_key(*this); + ngtcp2_conn_handshake_completed(*this); + } + + // ngtcp2 doesn't expose the varint encoding, but it's fairly simple: + // 0bXXyyyyyy -- XX indicates the encoded size (00=1, 01=2, 10=4, 11=8) and the rest of the bits + // (6, 14, 30, or 62) are the number, with bytes in network order for >6-bit values. + + // Returns {value, consumed} where consumed is the number of bytes consumed, or 0 on failure. + static constexpr std::pair + decode_varint(std::basic_string_view data) + { + std::pair result = {0, 0}; + auto& [val, enc_size] = result; + if (data.empty()) + return result; + enc_size = 1 << (data[0] >> 6); // first two bits are log₂ of the length + if (data.size() < enc_size) + { + enc_size = 0; + return result; + } + val = data[0] & 0b0011'1111; + for (size_t i = 1; i < enc_size; i++) + val = (val << 8) | data[i]; + return result; + } + + // Encodes an unsigned integer in QUIC encoding format; return the bytes and the length (bytes + // beyond `length` are uninitialized). + static constexpr std::pair, uint8_t> + encode_varint(uint64_t val) + { + assert(val < (1ULL << 62)); + std::pair, uint8_t> result; + uint8_t size = val < (1ULL << 6) ? 0 : val < (1ULL << 14) ? 1 : val < (1ULL << 30) ? 2 : 3; + auto& [enc, len] = result; + len = 1 << size; + for (uint8_t i = 1; i <= len; i++) + { + enc[len - i] = val & 0xff; + val >>= 8; + } + enc[0] = (enc[0] & 0b00'111111) | (size << 6); + enc[0] |= size << 6; + return result; + } + + // We add some lokinet-specific data into the transport request and *always* as the first + // transport parameter, but we do it in a way that the parameter gets ignored by the QUIC + // protocol, which encodes as {varint[code], varint[length], data}, and requires a code value + // 31*N+27 (for integer N). Naturally we use N=42, which gives us 1329=0b10100110001 which + // encodes in QUIC as 0b01000101 0b00110001 (the first two bits of the first byte give the integer + // size, and the rest are the value in network order). + static constexpr uint64_t lokinet_transport_param_N = 42; + static constexpr auto lokinet_metadata_code_raw = + encode_varint(31 * lokinet_transport_param_N + 27); + static constexpr std::basic_string_view lokinet_metadata_code{ + lokinet_metadata_code_raw.first.data(), lokinet_metadata_code_raw.second}; + static_assert( + lokinet_metadata_code.size() == 2 && lokinet_metadata_code[0] == 0b01000101 + && lokinet_metadata_code[1] == 0b00110001); + + int + Connection::recv_transport_params(std::basic_string_view data) + { + if (data.substr(0, lokinet_metadata_code.size()) != lokinet_metadata_code) + { + Warn("transport params did not begin with expected lokinet metadata"); + return NGTCP2_ERR_TRANSPORT_PARAM; + } + auto [meta_len, meta_len_bytes] = decode_varint(data.substr(lokinet_metadata_code.size())); + if (meta_len_bytes == 0) + { + Warn("transport params lokinet metadata has truncated size"); + return NGTCP2_ERR_MALFORMED_TRANSPORT_PARAM; + } + std::string_view lokinet_metadata{ + reinterpret_cast( + data.substr(lokinet_metadata_code.size() + meta_len_bytes).data()), + meta_len}; + Debug("Received bencoded lokinet metadata: ", buffer_printer{lokinet_metadata}); + + uint16_t port; + try + { + oxenmq::bt_dict_consumer meta{lokinet_metadata}; + // '#' contains the port the client wants us to forward to + if (!meta.skip_until("#")) + { + Warn("transport params # (port) is missing but required"); + return NGTCP2_ERR_TRANSPORT_PARAM; + } + port = meta.consume_integer(); + if (port == 0) + { + Warn("transport params tunnel port (#) is invalid: 0 is not permitted"); + return NGTCP2_ERR_TRANSPORT_PARAM; + } + Debug("decoded lokinet tunnel port = ", port); + } + catch (const oxenmq::bt_deserialize_invalid& c) + { + Warn("transport params lokinet metadata is invalid: ", c.what()); + return NGTCP2_ERR_TRANSPORT_PARAM; + } + + const bool is_server = ngtcp2_conn_is_server(*this); + + if (is_server) + { + tunnel_port = port; + } + else + { + // Make sure the server reflected the proper port + if (tunnel_port != port) + { + Warn("server returned invalid port; expected ", tunnel_port, ", got ", port); + return NGTCP2_ERR_TRANSPORT_PARAM; + } + } + + ngtcp2_transport_params params; + + auto exttype = is_server ? NGTCP2_TRANSPORT_PARAMS_TYPE_CLIENT_HELLO + : NGTCP2_TRANSPORT_PARAMS_TYPE_ENCRYPTED_EXTENSIONS; + + auto rv = ngtcp2_decode_transport_params(¶ms, exttype, data.data(), data.size()); + Debug("Decode transport params ", rv == 0 ? "success" : "fail: "s + ngtcp2_strerror(rv)); + Debug("params orig dcid = ", ConnectionID(params.original_dcid)); + Debug("params init scid = ", ConnectionID(params.initial_scid)); + if (rv == 0) + { + rv = ngtcp2_conn_set_remote_transport_params(*this, ¶ms); + Debug("Set remote transport params ", rv == 0 ? "success" : "fail: "s + ngtcp2_strerror(rv)); + } + + if (rv != 0) + { + ngtcp2_conn_set_tls_error(*this, rv); + return rv; + } + + return 0; + } + + // Sends our magic string at the given level. This fixed magic string is taking the place of TLS + // parameters in full QUIC. + int + Connection::send_magic(ngtcp2_crypto_level level) + { + return ngtcp2_conn_submit_crypto_data( + *this, level, handshake_magic.data(), handshake_magic.size()); + } + + template + static void + copy_and_advance(uint8_t*& buf, const String& s) + { + static_assert(sizeof(typename String::value_type) == 1, "not a char-compatible type"); + std::memcpy(buf, s.data(), s.size()); + buf += s.size(); + } + + // Sends transport parameters. `level` is expected to be INITIAL for clients (which send the + // transport parameters in the initial packet), or HANDSHAKE for servers. + int + Connection::send_transport_params(ngtcp2_crypto_level level) + { + ngtcp2_transport_params tparams; + ngtcp2_conn_get_local_transport_params(*this, &tparams); + + assert(conn_buffer.empty()); + static_assert(NGTCP2_MAX_PKTLEN_IPV4 > NGTCP2_MAX_PKTLEN_IPV6); + conn_buffer.resize(NGTCP2_MAX_PKTLEN_IPV4); + + auto* buf = u8data(conn_buffer); + auto* bufend = buf + conn_buffer.size(); + { + // Send our first parameter, the lokinet metadata, in a QUIC-compatible way (by using a + // reserved field code that QUIC parsers must ignore); currently we only include the port in + // here (from the client to tell the server what it's trying to reach, and reflected from + // the server for the client to verify). + std::string lokinet_metadata = bt_serialize(oxenmq::bt_dict{ + {"#", tunnel_port}, + }); + copy_and_advance(buf, lokinet_metadata_code); + auto [bytes, size] = encode_varint(lokinet_metadata.size()); + copy_and_advance(buf, std::basic_string_view{bytes.data(), size}); + copy_and_advance(buf, lokinet_metadata); + assert(buf < bufend); + } + + const bool is_server = ngtcp2_conn_is_server(*this); + auto exttype = is_server ? NGTCP2_TRANSPORT_PARAMS_TYPE_ENCRYPTED_EXTENSIONS + : NGTCP2_TRANSPORT_PARAMS_TYPE_CLIENT_HELLO; + + if (ngtcp2_ssize nwrite = ngtcp2_encode_transport_params(buf, bufend - buf, exttype, &tparams); + nwrite >= 0) + { + assert(nwrite > 0); + conn_buffer.resize(buf - u8data(conn_buffer) + nwrite); + } + else + { + conn_buffer.clear(); + return nwrite; + } + Debug("encoded transport params: ", buffer_printer{conn_buffer}); + return ngtcp2_conn_submit_crypto_data(*this, level, u8data(conn_buffer), conn_buffer.size()); + } + +} // namespace llarp::quic diff --git a/llarp/quic/connection.hpp b/llarp/quic/connection.hpp new file mode 100644 index 000000000..b4547fa27 --- /dev/null +++ b/llarp/quic/connection.hpp @@ -0,0 +1,311 @@ +#pragma once + +#include "address.hpp" +#include "random.hpp" +#include "stream.hpp" +#include "io_result.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace llarp::quic +{ + // We send and verify this in the initial connection and handshake; this is designed to allow + // future changes (by either breaking or handling backwards compat). + constexpr const std::array handshake_magic_bytes{ + 'l', 'o', 'k', 'i', 'n', 'e', 't', 0x01}; + constexpr std::basic_string_view handshake_magic{ + handshake_magic_bytes.data(), handshake_magic_bytes.size()}; + + // Flow control window sizes for a buffer and individual streams: + constexpr uint64_t CONNECTION_BUFFER = 1024 * 1024; + constexpr uint64_t STREAM_BUFFER = 64 * 1024; + // Max number of simultaneous streams we support on a connection + constexpr uint64_t STREAM_LIMIT = 100; + + using bstring_view = std::basic_string_view; + + class Endpoint; + class Server; + class Client; + + struct alignas(size_t) ConnectionID : ngtcp2_cid + { + ConnectionID() = default; + ConnectionID(const uint8_t* cid, size_t length); + ConnectionID(const ConnectionID& c) = default; + ConnectionID(ngtcp2_cid c) : ConnectionID(c.data, c.datalen) + {} + ConnectionID& + operator=(const ConnectionID& c) = default; + + static constexpr size_t + max_size() + { + return NGTCP2_MAX_CIDLEN; + } + static_assert(NGTCP2_MAX_CIDLEN <= std::numeric_limits::max()); + + bool + operator==(const ConnectionID& other) const + { + return datalen == other.datalen && std::memcmp(data, other.data, datalen) == 0; + } + bool + operator!=(const ConnectionID& other) const + { + return !(*this == other); + } + + template + static ConnectionID + random(RNG&& rng, size_t size = ConnectionID::max_size()) + { + ConnectionID r; + r.datalen = std::min(size, ConnectionID::max_size()); + random_bytes(r.data, r.datalen, rng); + return r; + } + }; + std::ostream& + operator<<(std::ostream& o, const ConnectionID& c); + +} // namespace llarp::quic +namespace std +{ + template <> + struct hash + { + // We pick our own source_cid randomly, so it's a perfectly good hash already. + size_t + operator()(const llarp::quic::ConnectionID& c) const + { + static_assert( + alignof(llarp::quic::ConnectionID) >= alignof(size_t) + && offsetof(llarp::quic::ConnectionID, data) % sizeof(size_t) == 0); + return *reinterpret_cast(c.data); + } + }; +} // namespace std +namespace llarp::quic +{ + /// Returns the current (monotonic) time as a time_point + inline auto + get_time() + { + return std::chrono::steady_clock::now(); + } + + /// Converts a time_point as returned by get_time to a nanosecond timestamp (as ngtcp2 expects). + inline uint64_t + get_timestamp(const std::chrono::steady_clock::time_point& t = get_time()) + { + return std::chrono::duration_cast(t.time_since_epoch()).count(); + } + + // Stores an established connection between server/client. + class Connection : public std::enable_shared_from_this + { + private: + struct connection_deleter + { + void + operator()(ngtcp2_conn* c) const + { + ngtcp2_conn_del(c); + } + }; + + // Packet data storage for a packet we are currently sending + std::array send_buffer{}; + size_t send_buffer_size = 0; + ngtcp2_pkt_info send_pkt_info{}; + + // Attempts to send the packet in `send_buffer`. If sending blocks then we set up a write poll + // on the socket to wait for it to become available, and return an io_result with `.blocked()` + // set to true. On other I/O errors we return the errno, and on successful sending we return a + // "true" (i.e. no error code) io_result. + io_result + send(); + + // Poll for writability; activated if we block while trying to send a packet. + std::shared_ptr wpoll; + bool wpoll_active = false; + + // Internal base method called invoked during construction to set up common client/server + // settings. dest_cid and path must already be set. + std::tuple + init(); + + // Event trigger used to queue packet processing for this connection + std::shared_ptr io_trigger; + + // Schedules a retransmit in the event loop (according to when ngtcp2 tells us we should) + void + schedule_retransmit(); + std::shared_ptr retransmit_timer; + + // The port the client wants to connect to on the server + uint16_t tunnel_port = 0; + + public: + // The endpoint that owns this connection + Endpoint& endpoint; + + /// The primary connection id of this Connection. This is the key of endpoint.conns that stores + /// the actual shared_ptr (everything else in `conns` is a weak_ptr alias). + const ConnectionID base_cid; + + /// The destination connection id we use to send to the other end; the remote end sets this as + /// the source cid in the header. + ConnectionID dest_cid; + + /// The underlying ngtcp2 connection object + std::unique_ptr conn; + + /// The most recent Path we have to/from the remote + Path path; + + /// True if we are draining (that is, we recently received a connection close from the other end + /// and should discard everything that comes in on this connection). Do not set this directly: + /// instead call Endpoint::start_draining(conn). + bool draining = false; + + /// True when we are closing; conn_buffer will contain the closing stanza. + bool closing = false; + + /// Buffer where we store non-stream connection data, e.g. for initial transport params during + /// connection and the closing stanza when disconnecting. + std::basic_string conn_buffer; + + // Stores callbacks of active streams, indexed by our local source connection ID that we assign + // when the connection is initiated. + std::map> streams; + + /// Constructs and initializes a new connection received by a Server + /// + /// \param s - the Server object on which the connection was initiated + /// \param base_cid - the local "primary" ConnectionID we use for this connection, typically + /// random \param header - packet header that initiated the connection \param path - the network + /// path to reach the remote + Connection(Server& s, const ConnectionID& base_cid, ngtcp2_pkt_hd& header, const Path& path); + + /// Establishes a connection from the local Client to a remote Server + /// \param c - the Client object from which the connection is being made + /// \param base_cid - the client's source (i.e. local) connection ID, typically random + /// \param path - the network path to reach the remote + /// \param tunnel_port - the port that this connection should tunnel to on the remote end + Connection(Client& c, const ConnectionID& scid, const Path& path, uint16_t tunnel_port); + + // Non-movable, non-copyable: + Connection(Connection&&) = delete; + Connection& + operator=(Connection&&) = delete; + Connection(const Connection&) = delete; + Connection& + operator=(const Connection&) = delete; + + ~Connection(); + + operator const ngtcp2_conn*() const + { + return conn.get(); + } + operator ngtcp2_conn*() + { + return conn.get(); + } + + // If this connection's endpoint is a server, returns a pointer to it. Otherwise returns + // nullptr. + Server* + server(); + + // If this connection's endpoint is a client, returns a pointer to it. Otherwise returs + // nullptr. + Client* + client(); + + // Called to signal libuv that this connection has stuff to do + void + io_ready(); + // Called (via libuv) when it wants us to do our stuff. Call io_ready() to schedule this. + void + on_io_ready(); + + int + setup_server_crypto_initial(); + + // Flush any streams with pending data. Note that, depending on available ngtcp2 state, we may + // not fully flush all streams -- some streams can individually block while waiting for + // confirmation. + void + flush_streams(); + + // Called when a new stream is opened + int + stream_opened(StreamID id); + + // Called when data is received for a stream + int + stream_receive(StreamID id, bstring_view data, bool fin); + + // Called when a stream is closed/reset + int + stream_reset(StreamID id, uint64_t app_error_code); + + // Called when stream data has been acknoledged and can be freed + int + stream_ack(StreamID id, size_t size); + + // Asks the endpoint for a new connection ID alias to use for this connection. cidlen can be + // used to specify the size of the cid (default is full size). + ConnectionID + make_alias_id(size_t cidlen = ConnectionID::max_size()); + + // Opens a stream over this connection; when the server receives this it attempts to establish a + // TCP connection to the tunnel configured in the connection. The data callback is invoked as + // data is received on this stream. The close callback is called if the stream is closed + // (either by the remote, or locally after a stream->close() call). + // + // \param data_cb -- callback to invoke when data is received + // \param close_cb -- callback to invoke when the connection is closed + // + // Throws a `std::runtime_error` if the stream creation fails (e.g. because the connection has + // no free stream capacity). + // + // Returns a const reference to the stored Stream shared_ptr (so that the caller can decide + // whether they want a copy or not). + const std::shared_ptr& + open_stream(Stream::data_callback_t data_cb, Stream::close_callback_t close_cb); + + // Accesses the stream via its StreamID; throws std::out_of_range if the stream doesn't exist. + const std::shared_ptr& + get_stream(StreamID s) const; + + // Internal methods that need to be publicly callable because we call them from C functions: + int + init_client(); + int + recv_initial_crypto(std::basic_string_view data); + int + recv_transport_params(std::basic_string_view data); + int + send_magic(ngtcp2_crypto_level level); + int + send_transport_params(ngtcp2_crypto_level level); + void + complete_handshake(); + }; + +} // namespace llarp::quic diff --git a/llarp/quic/endpoint.cpp b/llarp/quic/endpoint.cpp new file mode 100644 index 000000000..33a529455 --- /dev/null +++ b/llarp/quic/endpoint.cpp @@ -0,0 +1,526 @@ +#include "endpoint.hpp" +#include "client.hpp" +#include "log.hpp" +#include "server.hpp" + +#include +#include + +#include +#include + +#include + +#include + +// DEBUG: +extern "C" +{ +#include "../ngtcp2_conn.h" +} + +namespace llarp::quic +{ + Endpoint::Endpoint(std::optional
addr, std::shared_ptr loop_) + : loop{std::move(loop_)} + { + random_bytes(static_secret.data(), static_secret.size(), rng); + + // Create and bind the UDP socket. We can't use libuv's UDP socket here because it doesn't + // give us the ability to set up the ECN field as QUIC requires. + auto fd = socket(AF_INET, SOCK_DGRAM | SOCK_NONBLOCK, 0); + if (fd == -1) + throw std::runtime_error{"Failed to open socket: "s + strerror(errno)}; + + if (addr) + { + assert(addr->sockaddr_size() == sizeof(sockaddr_in)); // FIXME: IPv4-only for now + auto rv = bind(fd, *addr, addr->sockaddr_size()); + if (rv == -1) + throw std::runtime_error{ + "Failed to bind UDP socket to " + addr->to_string() + ": " + strerror(errno)}; + } + + // Get our address via the socket in case `addr` is using anyaddr/anyport. + sockaddr_any sa; + socklen_t salen = sizeof(sa); + // FIXME: if I didn't call bind above then do I need to call bind() before this (with + // anyaddr/anyport)? + getsockname(fd, &sa.sa, &salen); + assert(salen == sizeof(sockaddr_in)); // FIXME: IPv4-only for now + local = {&sa, salen}; + Debug("Bound to ", local, addr ? "" : " (auto-selected)"); + + // Set up the socket to provide us with incoming ECN (IP_TOS) info + // NB: This is for IPv4; on AF_INET6 this would be IPPROTO_IPV6, IPV6_RECVTCLASS + if (uint8_t want_tos = 1; + - 1 + == setsockopt( + fd, IPPROTO_IP, IP_RECVTOS, &want_tos, static_cast(sizeof(want_tos)))) + throw std::runtime_error{"Failed to set ECN on socket: "s + strerror(errno)}; + + // Wire up our recv buffer structures into what recvmmsg() wants + buf.resize(max_buf_size * msgs.size()); + for (size_t i = 0; i < msgs.size(); i++) + { + auto& iov = msgs_iov[i]; + iov.iov_base = buf.data() + max_buf_size * i; + iov.iov_len = max_buf_size; +#ifdef LOKINET_HAVE_RECVMMSG + auto& mh = msgs[i].msg_hdr; +#else + auto& mh = msgs[i]; +#endif + mh.msg_name = &msgs_addr[i]; + mh.msg_namelen = sizeof(msgs_addr[i]); + mh.msg_iov = &iov; + mh.msg_iovlen = 1; + mh.msg_control = msgs_cmsg[i].data(); + mh.msg_controllen = msgs_cmsg[i].size(); + } + + // Let uv do its stuff + poll = loop->resource(fd); + poll->on([this](const auto&, auto&) { on_readable(); }); + poll->start(uvw::PollHandle::Event::READABLE); + + // Set up a callback every 250ms to clean up stale sockets, etc. + expiry_timer = loop->resource(); + expiry_timer->on([this](const auto&, auto&) { check_timeouts(); }); + expiry_timer->start(250ms, 250ms); + + Debug("Created endpoint"); + } + + Endpoint::~Endpoint() + { + if (poll) + poll->close(); + if (expiry_timer) + expiry_timer->close(); + } + + int + Endpoint::socket_fd() const + { + return poll->fd(); + } + + void + Endpoint::on_readable() + { + Debug("poll callback on readable"); + +#ifdef LOKINET_HAVE_RECVMMSG + // NB: recvmmsg is linux-specific but ought to offer some performance benefits + int n_msg = recvmmsg(socket_fd(), msgs.data(), msgs.size(), 0, nullptr); + if (n_msg == -1) + { + if (errno != EAGAIN && errno != ENOTCONN) + Warn("Error recv'ing from ", local.to_string(), ": ", strerror(errno)); + return; + } + + Debug("Recv'd ", n_msg, " messages"); + for (int i = 0; i < n_msg; i++) + { + auto& [msg_hdr, msg_len] = msgs[i]; + bstring_view data{buf.data() + i * max_buf_size, msg_len}; +#else + for (size_t i = 0; i < N_msgs; i++) + { + auto& msg_hdr = msgs[0]; + auto n_bytes = recvmsg(socket_fd(), &msg_hdr, 0); + if (n_bytes == -1 && errno != EAGAIN && errno != ENOTCONN) + Warn("Error recv'ing from ", local.to_string(), ": ", strerror(errno)); + if (n_bytes <= 0) + return; + auto msg_len = static_cast(n_bytes); + bstring_view data{buf.data(), msg_len}; +#endif + + Debug( + "header [", + msg_hdr.msg_namelen, + "]: ", + buffer_printer{reinterpret_cast(msg_hdr.msg_name), msg_hdr.msg_namelen}); + + if (!msg_hdr.msg_name || msg_hdr.msg_namelen != sizeof(sockaddr_in)) + { // FIXME: IPv6 support? + Warn("Invalid/unknown source address, dropping packet"); + continue; + } + + Packet pkt{ + Path{local, reinterpret_cast(msg_hdr.msg_name), msg_hdr.msg_namelen}, + data, + ngtcp2_pkt_info{.ecn = 0}}; + + // Go look for the ECN header field on the incoming packet + for (auto cmsg = CMSG_FIRSTHDR(&msg_hdr); cmsg; cmsg = CMSG_NXTHDR(&msg_hdr, cmsg)) + { + // IPv4; for IPv6 these would be IPPROTO_IPV6 and IPV6_TCLASS + if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_TOS && cmsg->cmsg_len) + { + pkt.info.ecn = *reinterpret_cast(CMSG_DATA(cmsg)); + } + } + + Debug( + i, + "[", + pkt.path, + ",ecn=0x", + std::hex, + +pkt.info.ecn, + std::dec, + "]: received ", + msg_len, + " bytes"); + + handle_packet(pkt); + + Debug("Done handling packet"); + +#ifdef LOKINET_HAVE_RECVMMSG // Help editor's { } matching: + } +#else + } +#endif + } + + std::optional + Endpoint::handle_packet_init(const Packet& p) + { + version_info vi; + auto rv = ngtcp2_pkt_decode_version_cid( + &vi.version, + &vi.dcid, + &vi.dcid_len, + &vi.scid, + &vi.scid_len, + u8data(p.data), + p.data.size(), + NGTCP2_MAX_CIDLEN); + if (rv == 1) + { // 1 means Version Negotiation should be sent and otherwise the packet should be ignored + send_version_negotiation(vi, p.path.remote); + return std::nullopt; + } + else if (rv != 0) + { + Warn("QUIC packet header decode failed: ", ngtcp2_strerror(rv)); + return std::nullopt; + } + + if (vi.dcid_len > ConnectionID::max_size()) + { + Warn("Internal error: destination ID is longer than should be allowed"); + return std::nullopt; + } + + return std::make_optional(vi.dcid, vi.dcid_len); + } + void + Endpoint::handle_conn_packet(Connection& conn, const Packet& p) + { + if (ngtcp2_conn_is_in_closing_period(conn)) + { + Debug("Connection is in closing period, dropping"); + close_connection(conn); + return; + } + if (conn.draining) + { + Debug("Connection is draining, dropping"); + // "draining" state means we received a connection close and we're keeping the + // connection alive just to catch (and discard) straggling packets that arrive + // out of order w.r.t to connection close. + return; + } + + if (auto result = read_packet(p, conn); !result) + { + Warn("Read packet failed! ", ngtcp2_strerror(result.error_code)); + } + + // FIXME - reset idle timer? + Debug("Done with incoming packet"); + } + + io_result + Endpoint::read_packet(const Packet& p, Connection& conn) + { + Debug("Reading packet from ", p.path); + Debug("Conn state before reading: ", conn.conn->state); + auto rv = + ngtcp2_conn_read_pkt(conn, p.path, &p.info, u8data(p.data), p.data.size(), get_timestamp()); + Debug("Conn state after reading: ", conn.conn->state); + + if (rv == 0) + conn.io_ready(); + else + Warn("read pkt error: ", ngtcp2_strerror(rv)); + + if (rv == NGTCP2_ERR_DRAINING) + start_draining(conn); + else if (rv == NGTCP2_ERR_DROP_CONN) + delete_conn(conn.base_cid); + + return {rv}; + } + + void + Endpoint::update_ecn(uint32_t ecn) + { + assert(ecn <= std::numeric_limits::max()); + if (ecn_curr != ecn) + { + if (-1 + == setsockopt(socket_fd(), IPPROTO_IP, IP_TOS, &ecn, static_cast(sizeof(ecn)))) + Warn("setsockopt failed to set IP_TOS: ", strerror(errno)); + + // IPv6 version: + // int tclass = this->ecn; + // setsockopt(socket_fd(), IPPROTO_IPV6, IPV6_TCLASS, &tclass, + // static_cast(sizeof(tclass))); + + ecn_curr = ecn; + } + } + + io_result + Endpoint::send_packet(const Address& to, bstring_view data, uint32_t ecn) + { + iovec msg_iov; + msg_iov.iov_base = const_cast(data.data()); + msg_iov.iov_len = data.size(); + + msghdr msg{}; + msg.msg_name = &const_cast(reinterpret_cast(to)); + msg.msg_namelen = sizeof(sockaddr_in); + msg.msg_iov = &msg_iov; + msg.msg_iovlen = 1; + + auto fd = socket_fd(); + + update_ecn(ecn); + ssize_t nwrite = 0; + do + { + nwrite = sendmsg(fd, &msg, 0); + } while (nwrite == -1 && errno == EINTR); + + if (nwrite == -1) + { + Warn("sendmsg failed: ", strerror(errno)); + return {errno}; + } + + Debug( + "[", + to.to_string(), + ",ecn=0x", + std::hex, + +ecn_curr, + std::dec, + "]: sent ", + nwrite, + " bytes"); + return {}; + } + + void + Endpoint::send_version_negotiation(const version_info& vi, const Address& source) + { + std::array buf; + std::array versions; + std::iota(versions.begin() + 1, versions.end(), NGTCP2_PROTO_VER_MIN); + // we're supposed to send some 0x?a?a?a?a version to trigger version negotiation + versions[0] = 0x1a2a3a4au; + + auto nwrote = ngtcp2_pkt_write_version_negotiation( + u8data(buf), + buf.size(), + std::uniform_int_distribution{0, 255}(rng), + vi.dcid, + vi.dcid_len, + vi.scid, + vi.scid_len, + versions.data(), + versions.size()); + if (nwrote < 0) + Warn("Failed to construct version negotiation packet: ", ngtcp2_strerror(nwrote)); + if (nwrote <= 0) + return; + + send_packet(source, bstring_view{buf.data(), static_cast(nwrote)}, 0); + } + + void + Endpoint::close_connection(Connection& conn, uint64_t code, bool application) + { + Debug("Closing connection ", conn.base_cid); + if (!conn.closing) + { + conn.conn_buffer.resize(max_pkt_size_v4); + Path path; + ngtcp2_pkt_info pi; + + auto write_close_func = + application ? ngtcp2_conn_write_application_close : ngtcp2_conn_write_connection_close; + auto written = write_close_func( + conn, + path, + &pi, + u8data(conn.conn_buffer), + conn.conn_buffer.size(), + code, + get_timestamp()); + if (written <= 0) + { + Warn( + "Failed to write connection close packet: ", + written < 0 ? ngtcp2_strerror(written) : "unknown error: closing is 0 bytes??"); + return; + } + assert(written <= (long)conn.conn_buffer.size()); + conn.conn_buffer.resize(written); + conn.closing = true; + + // FIXME: ipv6 + assert(path.local.sockaddr_size() == sizeof(sockaddr_in)); + assert(path.remote.sockaddr_size() == sizeof(sockaddr_in)); + + conn.path = path; + } + assert(conn.closing && !conn.conn_buffer.empty()); + + if (auto sent = send_packet(conn.path.remote, conn.conn_buffer, 0); !sent) + { + Warn( + "Failed to send packet: ", + strerror(sent.error_code), + "; removing connection ", + conn.base_cid); + delete_conn(conn.base_cid); + return; + } + } + + /// Puts a connection into draining mode (i.e. after getting a connection close). This will + /// keep the connection registered for the recommended 3*Probe Timeout, during which we drop + /// packets that use the connection id and after which we will forget about it. + void + Endpoint::start_draining(Connection& conn) + { + if (conn.draining) + return; + Debug("Putting ", conn.base_cid, " into draining mode"); + conn.draining = true; + // Recommended draining time is 3*Probe Timeout + draining.emplace(conn.base_cid, get_time() + ngtcp2_conn_get_pto(conn) * 3 * 1ns); + } + + void + Endpoint::check_timeouts() + { + auto now = get_time(); + uint64_t now_ts = get_timestamp(now); + + // Destroy any connections that are finished draining + bool cleanup = false; + while (!draining.empty() && draining.front().second < now) + { + if (auto it = conns.find(draining.front().first); it != conns.end()) + { + if (std::holds_alternative(it->second)) + cleanup = true; + Debug("Deleting connection ", it->first); + conns.erase(it); + } + draining.pop(); + } + if (cleanup) + clean_alias_conns(); + + for (auto it = conns.begin(); it != conns.end(); ++it) + { + if (auto* conn_ptr = std::get_if(&it->second)) + { + Connection& conn = **conn_ptr; + auto exp = ngtcp2_conn_get_idle_expiry(conn); + if (exp >= now_ts || conn.draining) + continue; + start_draining(conn); + } + } + } + + std::pair, bool> + Endpoint::get_conn(const ConnectionID& cid) + { + if (auto it = conns.find(cid); it != conns.end()) + { + if (auto* wptr = std::get_if(&it->second)) + return {wptr->lock(), true}; + return {var::get(it->second), false}; + } + return {nullptr, false}; + } + + bool + Endpoint::delete_conn(const ConnectionID& cid) + { + auto it = conns.find(cid); + if (it == conns.end()) + { + Debug("Cannot delete connection ", cid, ": cid not found"); + return false; + } + + bool primary = std::holds_alternative(it->second); + Debug("Deleting ", primary ? "primary" : "alias", " connection ", cid); + conns.erase(it); + if (primary) + clean_alias_conns(); + return true; + } + + void + Endpoint::clean_alias_conns() + { + for (auto it = conns.begin(); it != conns.end();) + { + if (auto* conn_wptr = std::get_if(&it->second); + conn_wptr && conn_wptr->expired()) + it = conns.erase(it); + else + ++it; + } + } + + ConnectionID + Endpoint::add_connection_id(Connection& conn, size_t cid_length) + { + ConnectionID cid; + for (bool inserted = false; !inserted;) + { + cid = ConnectionID::random(rng, cid_length); + inserted = conns.emplace(cid, conn.weak_from_this()).second; + } + Debug("Created cid ", cid, " alias for ", conn.base_cid); + return cid; + } + + void + Endpoint::make_stateless_reset_token(const ConnectionID& cid, unsigned char* dest) + { + crypto_generichash_state state; + crypto_generichash_init(&state, nullptr, 0, NGTCP2_STATELESS_RESET_TOKENLEN); + crypto_generichash_update(&state, u8data(static_secret), static_secret.size()); + crypto_generichash_update(&state, cid.data, cid.datalen); + crypto_generichash_final(&state, dest, NGTCP2_STATELESS_RESET_TOKENLEN); + } + +} // namespace llarp::quic diff --git a/llarp/quic/endpoint.hpp b/llarp/quic/endpoint.hpp new file mode 100644 index 000000000..46b53deee --- /dev/null +++ b/llarp/quic/endpoint.hpp @@ -0,0 +1,241 @@ +#pragma once + +#include "address.hpp" +#include "connection.hpp" +#include "io_result.hpp" +#include "null_crypto.hpp" +#include "packet.hpp" +#include "stream.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +// True if we support recvmmsg/sendmmsg +#if defined(__linux__) && !defined(LOKINET_NO_RECVMMSG) +#define LOKINET_HAVE_RECVMMSG +#endif + +namespace llarp::quic +{ + using namespace std::literals; + + inline constexpr auto IDLE_TIMEOUT = 5min; + + class Endpoint + { + protected: + // Address we are listening on + Address local; + // The current outgoing IP ecn value for the socket + uint8_t ecn_curr = 0; + + std::shared_ptr poll; + std::shared_ptr expiry_timer; + std::shared_ptr loop; + + // How many messages (at most) we recv per callback: + static constexpr int N_msgs = 8; +#ifdef LOKINET_HAVE_RECVMMSG + static constexpr int N_mmsg = N_msgs; + std::array msgs; +#else + static constexpr int N_mmsg = 1; + std::array msgs; +#endif + + std::array msgs_iov; + std::array msgs_addr; + std::array, N_mmsg> msgs_cmsg; + std::vector buf; + // Max theoretical size of a UDP packet is 2^16-1 minus IP/UDP header overhead + static constexpr size_t max_buf_size = 64 * 1024; + // Max size of a UDP packet that we'll send + static constexpr size_t max_pkt_size_v4 = NGTCP2_MAX_PKTLEN_IPV4; + static constexpr size_t max_pkt_size_v6 = NGTCP2_MAX_PKTLEN_IPV6; + + std::mt19937_64 rng = seeded(); + + using primary_conn_ptr = std::shared_ptr; + using alias_conn_ptr = std::weak_ptr; + + // Connections. When a client establishes a new connection it chooses its own source connection + // ID and a destination connection ID and sends them to the server. + // + // This container stores the primary Connection instance as a shared_ptr, and any connection + // aliases as weak_ptrs referencing the primary instance (so that we don't have to double a + // double-hash lookup on incoming packets, since those frequently use aliases). + // + // The destination connection ID should be entirely random and can be up to 160 bits, but the + // source connection ID does not have to be (i.e. it can encode some information, if desired). + // + // The server is going to include in the response: + // - destination connection ID equal to the client's source connection ID + // - a new random source connection ID. (We don't use the client's destination ID but generate + // our own). Like the clients source ID, this can contain embedded info. + // + // The client stores this, and so we end up with client-scid == server-dcid, and client-dcid == + // server-scid, where each side chose its own source connection ID. + // + // Ultimately, we store here our own {source connection ID -> Connection} pairs (or + // equivalently, on incoming packets, the key will be the packet's dest conn ID). + std::unordered_map> conns; + + using conns_iterator = decltype(conns)::iterator; + + // Connections that are draining (i.e. we are dropping, but need to keep around for a while + // to catch and drop lagged packets). The time point is the scheduled removal time. + std::queue> draining; + + NullCrypto null_crypto; + + // Random data that we hash together with a CID to make a stateless reset token + std::array static_secret; + + friend class Connection; + + // Wires up an endpoint connection. + // + // `bind` - address we should bind to. Required for a server, optional for a client. If + // omitted, no explicit bind is performed (which means the socket will be implicitly bound to + // some OS-determined random high bind port). + // `loop` - the uv loop pointer managing polling of this endpoint + Endpoint(std::optional
bind, std::shared_ptr loop); + + virtual ~Endpoint(); + + int + socket_fd() const; + + void + on_readable(); + + // Version & connection id info that we can potentially extract when decoding a packet + struct version_info + { + uint32_t version; + const uint8_t* dcid; + size_t dcid_len; + const uint8_t* scid; + size_t scid_len; + }; + + // Called to handle an incoming packet + virtual void + handle_packet(const Packet& p) = 0; + + // Internal method: handles initial common packet decoding, returns the connection ID or nullopt + // if decoding failed. + std::optional + handle_packet_init(const Packet& p); + // Internal method: handles a packet sent to the given connection + void + handle_conn_packet(Connection& c, const Packet& p); + + // Reads a packet and handles various error conditions. Returns an io_result. Note that it is + // possible for the conn_it to be erased from `conns` if the error code is anything other than + // success (0) or NGTCP2_ERR_RETRY. + io_result + read_packet(const Packet& p, Connection& conn); + + // Sets up the ECN IP field (IP_TOS for IPv4) for the next outgoing packet sent via + // send_packet(). This does the actual syscall (if ECN is different than currently set), and is + // typically called implicitly via send_packet(). + void + update_ecn(uint32_t ecn); + + // Sends a packet to `to` containing `data`. Returns a non-error io_result on success, + // an io_result with .error_code set to the errno of the failure on failure. + io_result + send_packet(const Address& to, bstring_view data, uint32_t ecn); + + // Wrapper around the above that takes a regular std::string_view (i.e. of chars) and recasts + // it to an string_view of std::bytes. + io_result + send_packet(const Address& to, std::string_view data, uint32_t ecn) + { + return send_packet( + to, bstring_view{reinterpret_cast(data.data()), data.size()}, ecn); + } + + // Another wrapper taking a vector + io_result + send_packet(const Address& to, const std::vector& data, uint32_t ecn) + { + return send_packet(to, bstring_view{data.data(), data.size()}, ecn); + } + + void + send_version_negotiation(const version_info& vi, const Address& source); + + // Looks up a connection. Returns a shared_ptr (either copied for a primary connection, or + // locked from an alias's weak pointer) if the connection was found or nullptr if not; and a + // bool indicating whether this connection ID was an alias (true) or not (false). [Note: the + // alias value can be true even if the shared_ptr is null in the case of an expired alias that + // hasn't yet been cleaned up]. + std::pair, bool> + get_conn(const ConnectionID& cid); + + // Called to start closing (or continue closing) a connection by sending a connection close + // response to any incoming packets. + // + // Takes the iterator to the connection pair from `conns` and optional error parameters: if + // `application` is false (the default) then we do a hard connection close because of transport + // error, if true we do a graceful application close. For application closes the code is + // application-defined; for hard closes the code should be one of the NGTCP2_*_ERROR values. + void + close_connection(Connection& conn, uint64_t code = NGTCP2_NO_ERROR, bool application = false); + + /// Puts a connection into draining mode (i.e. after getting a connection close). This will + /// keep the connection registered for the recommended 3*Probe Timeout, during which we drop + /// packets that use the connection id and after which we will forget about it. + void + start_draining(Connection& conn); + + void + check_timeouts(); + + /// Deletes a connection from `conns`; if the connecion is a primary connection shared pointer + /// then it is removed and clean_alias_conns() is immediately called to remove any aliases to + /// the connection. If the given connection is an alias connection then it is removed but no + /// cleanup is performed. Returns true if something was removed, false if the connection was + /// not found. + bool + delete_conn(const ConnectionID& cid); + + /// Removes any connection id aliases that no longer have associated Connections. + void + clean_alias_conns(); + + /// Creates a new, unused connection ID alias for the given connection; adds the alias to + /// `conns` and returns the ConnectionID. + ConnectionID + add_connection_id(Connection& conn, size_t cid_length = ConnectionID::max_size()); + + public: + // Makes a deterministic stateless reset token for the given connection ID. Writes it to dest + // (which must have NGTCP2_STATELESS_RESET_TOKENLEN bytes available). + void + make_stateless_reset_token(const ConnectionID& cid, unsigned char* dest); + + // Default stream buffer size for streams opened through this endpoint. + size_t default_stream_buffer_size = 64 * 1024; + + // Gets a reference to the UV event loop + uvw::Loop& + get_loop() + { + return *loop; + } + }; + +} // namespace llarp::quic diff --git a/llarp/quic/io_result.hpp b/llarp/quic/io_result.hpp new file mode 100644 index 000000000..00eb87734 --- /dev/null +++ b/llarp/quic/io_result.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include + +namespace llarp::quic +{ + // Return result from a read or write operation that wraps an errno value. It is implicitly + // convertible to bool to test for "is not an error" (which is the inverse of casting a plain + // integer error code value to bool). + struct io_result + { + // An error code, typically an errno value + int error_code{0}; + // Returns true if this represent a successful result, i.e. an error_code of 0. + operator bool() const + { + return error_code == 0; + } + + // Returns true if this is an error value indicating a failure to write without blocking (only + // applied to io_result's capturing an errno). + bool + blocked() const + { + return error_code == EAGAIN || error_code == EWOULDBLOCK; + } + + // Returns the errno string for the given error code. + std::string_view + str() const + { + return strerror(error_code); + } + }; + +} // namespace llarp::quic diff --git a/llarp/quic/log.cpp b/llarp/quic/log.cpp new file mode 100644 index 000000000..3d28505a5 --- /dev/null +++ b/llarp/quic/log.cpp @@ -0,0 +1,45 @@ +#include "log.hpp" + +namespace llarp::quic +{ + std::ostream& + operator<<(std::ostream& o, const buffer_printer& bp) + { + auto& b = bp.buf; + auto oldfill = o.fill(); + o.fill('0'); + o << "Buffer[" << b.size() << "/0x" << std::hex << b.size() << " bytes]:"; + for (size_t i = 0; i < b.size(); i += 32) + { + o << "\n" << std::setw(4) << i << " "; + + size_t stop = std::min(b.size(), i + 32); + for (size_t j = 0; j < 32; j++) + { + auto k = i + j; + if (j % 4 == 0) + o << ' '; + if (k >= stop) + o << " "; + else + o << std::setw(2) << std::to_integer(b[k]); + } + o << u8" ┃"; + for (size_t j = i; j < stop; j++) + { + auto c = std::to_integer(b[j]); + if (c == 0x00) + o << u8"∅"; + else if (c < 0x20 || c > 0x7e) + o << u8"·"; + else + o << c; + } + o << u8"┃"; + } + o << std::dec; + o.fill(oldfill); + return o; + } + +} // namespace llarp::quic diff --git a/llarp/quic/log.hpp b/llarp/quic/log.hpp new file mode 100644 index 000000000..cf536b418 --- /dev/null +++ b/llarp/quic/log.hpp @@ -0,0 +1,146 @@ +#pragma once + +#include +#include +#include +#include +#include + +// Temporary logging code to be replaced with lokinet logging + +#include + +#ifdef __cpp_lib_source_location +#include +namespace slns = std; +#else +#include +namespace slns = std::experimental; +#endif + +namespace llarp::quic +{ + struct buffer_printer + { + std::basic_string_view buf; + + template > + explicit buffer_printer(std::basic_string_view buf) + : buf{reinterpret_cast(buf.data()), buf.size()} + {} + + template > + explicit buffer_printer(const std::basic_string& buf) + : buffer_printer(std::basic_string_view{buf}) + {} + + template > + explicit buffer_printer(std::basic_string&& buf) = delete; + + template > + explicit buffer_printer(const T* data, size_t size) + : buffer_printer(std::basic_string_view{data, size}) + {} + }; + std::ostream& + operator<<(std::ostream& o, const buffer_printer& bp); + + namespace detail + { + template + constexpr bool is_same_any_v = (std::is_same_v || ...); + + template + void + log_print_vals(T&& val, More&&... more) + { + using PlainT = std::remove_reference_t; + if constexpr (is_same_any_v) + std::cerr + << +val; // Promote chars to int so that they get printed as numbers, not literal chars + else + std::cerr << val; + if constexpr (sizeof...(More)) + log_print_vals(std::forward(more)...); + } + + template + void + log_print(const slns::source_location& location, T&&... args) + { + std::string_view filename{location.file_name()}; + if (auto pos = filename.rfind('/'); pos != std::string::npos + && (pos = filename.substr(0, pos).rfind('/')) != std::string::npos) + { + filename.remove_prefix(pos + 1); + } + std::cerr << "\e[3m[" << filename << ':' << location.line() << "]\e[23m"; + if constexpr (sizeof...(T)) + { + std::cerr << ": "; + detail::log_print_vals(std::forward(args)...); + } + std::cerr << '\n'; + } + + } // namespace detail + +#ifndef NDEBUG + template + struct Debug + { + Debug(T&&... args, const slns::source_location& location = slns::source_location::current()) + { + std::cerr << "DBG"; + detail::log_print(location, std::forward(args)...); + } + }; + template + Debug(T&&...) -> Debug; +#else + template + void + Debug(T&&...) + {} +#endif + + template + struct Info + { + Info(T&&... args, const slns::source_location& location = slns::source_location::current()) + { + std::cerr << "\e[32mNFO"; + detail::log_print(location, std::forward(args)...); + std::cerr << "\e[0m"; + } + }; + template + Info(T&&...) -> Info; + + template + struct Warn + { + Warn(T&&... args, const slns::source_location& location = slns::source_location::current()) + { + std::cerr << "\e[33;1mWRN"; + detail::log_print(location, std::forward(args)...); + std::cerr << "\e[0m"; + } + }; + template + Warn(T&&...) -> Warn; + + template + struct Error + { + Error(T&&... args, const slns::source_location& location = slns::source_location::current()) + { + std::cerr << "\e[31;1mWRN"; + detail::log_print(location, std::forward(args)...); + std::cerr << "\e[0m"; + } + }; + template + Error(T&&...) -> Error; + +} // namespace llarp::quic diff --git a/llarp/quic/null_crypto.cpp b/llarp/quic/null_crypto.cpp new file mode 100644 index 000000000..98f74b6a7 --- /dev/null +++ b/llarp/quic/null_crypto.cpp @@ -0,0 +1,93 @@ +#include "null_crypto.hpp" +#include "log.hpp" + +#include + +namespace llarp::quic +{ + // Cranks a value to "11", i.e. set it to its maximum + template + void + crank_to_eleven(T& val) + { + val = std::numeric_limits::max(); + } + + NullCrypto::NullCrypto() + { + crank_to_eleven(null_ctx.max_encryption); + crank_to_eleven(null_ctx.max_decryption_failure); + null_ctx.aead.max_overhead = 1; // Fails an assertion if 0 + null_aead.max_overhead = 1; // FIXME - can this be 0? + } + + void + NullCrypto::client_initial(Connection& conn) + { + ngtcp2_conn_set_initial_crypto_ctx(conn, &null_ctx); + ngtcp2_conn_install_initial_key( + conn, + &null_aead_ctx, + null_iv.data(), + &null_cipher_ctx, + &null_aead_ctx, + null_iv.data(), + &null_cipher_ctx, + null_iv.size()); + ngtcp2_conn_set_retry_aead(conn, &null_aead, &null_aead_ctx); + ngtcp2_conn_set_crypto_ctx(conn, &null_ctx); + } + + void + NullCrypto::server_initial(Connection& conn) + { + Debug("Server initial null crypto setup"); + ngtcp2_conn_set_initial_crypto_ctx(conn, &null_ctx); + ngtcp2_conn_install_initial_key( + conn, + &null_aead_ctx, + null_iv.data(), + &null_cipher_ctx, + &null_aead_ctx, + null_iv.data(), + &null_cipher_ctx, + null_iv.size()); + ngtcp2_conn_set_crypto_ctx(conn, &null_ctx); + } + + bool + NullCrypto::install_tx_handshake_key(Connection& conn) + { + return ngtcp2_conn_install_tx_handshake_key( + conn, &null_aead_ctx, null_iv.data(), null_iv.size(), &null_cipher_ctx) + == 0; + } + bool + NullCrypto::install_rx_handshake_key(Connection& conn) + { + return ngtcp2_conn_install_rx_handshake_key( + conn, &null_aead_ctx, null_iv.data(), null_iv.size(), &null_cipher_ctx) + == 0; + } + bool + NullCrypto::install_tx_key(Connection& conn) + { + return ngtcp2_conn_install_tx_key( + conn, + null_iv.data(), + null_iv.size(), + &null_aead_ctx, + null_iv.data(), + null_iv.size(), + &null_cipher_ctx) + == 0; + } + bool + NullCrypto::install_rx_key(Connection& conn) + { + return ngtcp2_conn_install_rx_key( + conn, nullptr, 0, &null_aead_ctx, null_iv.data(), null_iv.size(), &null_cipher_ctx) + == 0; + } + +} // namespace llarp::quic diff --git a/llarp/quic/null_crypto.hpp b/llarp/quic/null_crypto.hpp new file mode 100644 index 000000000..9948b2564 --- /dev/null +++ b/llarp/quic/null_crypto.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include "connection.hpp" + +#include +#include + +#include + +namespace llarp::quic +{ + // Class providing do-nothing stubs for quic crypto operations: everything over lokinet is already + // encrypted so we just no-op QUIC's built in crypto operations. + struct NullCrypto + { + NullCrypto(); + + void + client_initial(Connection& conn); + + void + server_initial(Connection& conn); + + bool + install_tx_handshake_key(Connection& conn); + bool + install_tx_key(Connection& conn); + + bool + install_rx_handshake_key(Connection& conn); + bool + install_rx_key(Connection& conn); + + private: + std::array null_iv{}; + // std::array null_data{}; + + ngtcp2_crypto_ctx null_ctx{}; + ngtcp2_crypto_aead null_aead{}; + ngtcp2_crypto_aead_ctx null_aead_ctx{}; + ngtcp2_crypto_cipher_ctx null_cipher_ctx{}; + }; + +} // namespace llarp::quic diff --git a/llarp/quic/packet.hpp b/llarp/quic/packet.hpp new file mode 100644 index 000000000..a50d283a9 --- /dev/null +++ b/llarp/quic/packet.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "connection.hpp" + +namespace llarp::quic +{ + // Encapsulates a packet, i.e. the remote addr, packet data, plus metadata. + struct Packet + { + Path path; + bstring_view data; + ngtcp2_pkt_info info; + }; + +} // namespace llarp::quic diff --git a/llarp/quic/random.hpp b/llarp/quic/random.hpp new file mode 100644 index 000000000..01f040486 --- /dev/null +++ b/llarp/quic/random.hpp @@ -0,0 +1,37 @@ +#pragma once + +// TODO: replace with llarp + +#include +#include +#include +#include + +template +void +random_bytes(void* dest, size_t length, Gen&& rng) +{ + using RNG = std::remove_reference_t; + using UInt = typename RNG::result_type; + static_assert(std::is_same_v || std::is_same_v); + static_assert(RNG::min() == 0 && RNG::max() == std::numeric_limits::max()); + auto* d = reinterpret_cast(dest); + for (size_t o = 0; o < length; o += sizeof(UInt)) + { + UInt x = rng(); + std::memcpy(d + o, &x, std::min(sizeof(UInt), length - o)); + } +} + +// Returns an RNG with a fully seeded state from std::random_device +template +RNG +seeded() +{ + constexpr size_t rd_draws = + ((RNG::state_size * sizeof(typename RNG::result_type) - 1) / sizeof(unsigned int) + 1); + std::array seed_data; + std::generate(seed_data.begin(), seed_data.end(), std::random_device{}); + std::seed_seq seed(seed_data.begin(), seed_data.end()); + return RNG{seed}; +} diff --git a/llarp/quic/server.cpp b/llarp/quic/server.cpp new file mode 100644 index 000000000..5dae1f0b5 --- /dev/null +++ b/llarp/quic/server.cpp @@ -0,0 +1,117 @@ +#include "server.hpp" +#include "log.hpp" + +#include +#include +#include + +#include +#include +#include + +namespace llarp::quic +{ + Server::Server( + Address listen, std::shared_ptr loop, stream_open_callback_t stream_open) + : Endpoint{std::move(listen), std::move(loop)}, stream_open_callback{std::move(stream_open)} + {} + + void + Server::handle_packet(const Packet& p) + { + Debug("Handling incoming server packet: ", buffer_printer{p.data}); + auto maybe_dcid = handle_packet_init(p); + if (!maybe_dcid) + return; + auto& dcid = *maybe_dcid; + + // See if we have an existing connection already established for it + Debug("Incoming connection id ", dcid); + primary_conn_ptr connptr; + if (auto conn_it = conns.find(dcid); conn_it != conns.end()) + { + if (auto* wptr = std::get_if(&conn_it->second)) + { + connptr = wptr->lock(); + if (!connptr) + Debug("CID is an expired alias"); + else + Debug("CID is an alias for primary CID ", connptr->base_cid); + } + else + { + connptr = var::get(conn_it->second); + Debug("CID is primary"); + } + } + else + { + connptr = accept_connection(p); + } + + if (!connptr) + { + Warn("invalid or expired connection, ignoring"); + return; + } + + handle_conn_packet(*connptr, p); + } + + std::shared_ptr + Server::accept_connection(const Packet& p) + { + Debug("Accepting new connection"); + // This is a new incoming connection + ngtcp2_pkt_hd hd; + auto rv = ngtcp2_accept(&hd, u8data(p.data), p.data.size()); + + if (rv == -1) + { // Invalid packet + Warn("Invalid packet received, length=", p.data.size()); +#ifndef NDEBUG + Debug("packet body:"); + for (size_t i = 0; i < p.data.size(); i += 50) + Debug(" ", oxenmq::to_hex(p.data.substr(i, 50))); +#endif + return nullptr; + } + + if (rv == 1) + { // Invalid/unexpected version, send a version negotiation + Debug("Invalid/unsupported version; sending version negotiation"); + send_version_negotiation( + version_info{hd.version, hd.dcid.data, hd.dcid.datalen, hd.scid.data, hd.scid.datalen}, + p.path.remote); + return nullptr; + } + + /* + ngtcp2_cid ocid; + ngtcp2_cid *pocid = nullptr; + */ + if (hd.type == NGTCP2_PKT_0RTT) + { + Warn("Received 0-RTT packet, which shouldn't happen in our implementation; dropping"); + return nullptr; + } + else if (hd.type == NGTCP2_PKT_INITIAL && hd.token.len) + { + // This is a normal QUIC thing, but we don't do it: + Warn("Unexpected token in initial packet"); + } + + // create and store Connection + for (;;) + { + if (auto [it, ins] = conns.emplace(ConnectionID::random(rng), primary_conn_ptr{}); ins) + { + auto connptr = std::make_shared(*this, it->first, hd, p.path); + it->second = connptr; + Debug("Created local Connection ", it->first, " for incoming connection"); + return connptr; + } + } + } + +} // namespace llarp::quic diff --git a/llarp/quic/server.hpp b/llarp/quic/server.hpp new file mode 100644 index 000000000..e6973ed55 --- /dev/null +++ b/llarp/quic/server.hpp @@ -0,0 +1,40 @@ +#pragma once + +#include "endpoint.hpp" + +#include + +namespace llarp::quic +{ + class Server : public Endpoint + { + public: + using stream_open_callback_t = + std::function; + + Server(Address listen, std::shared_ptr loop, stream_open_callback_t stream_opened); + + // Stream callback: takes the server, the (just-created) stream, and the connection port. + // Returns true if the stream should be allowed or false to reject the stream. The callback + // should set up the data_callback and close_callback on the stream: they will default to null + // (which means incoming data will simply be dropped). + stream_open_callback_t stream_open_callback; + + int + setup_null_crypto(ngtcp2_conn* conn); + + private: + // Handles an incoming packet by figuring out and handling the connection id; if necessary we + // send back a version negotiation or a connection close frame, or drop the packet (if in the + // draining state). If we get through all of the above then it's a packet to read, in which + // case we pass it on to read_packet(). + void + handle_packet(const Packet& p) override; + + // Creates a new connection from an incoming packet. Returns a nullptr if the connection can't + // be created. + std::shared_ptr + accept_connection(const Packet& p); + }; + +} // namespace llarp::quic diff --git a/llarp/quic/stream.cpp b/llarp/quic/stream.cpp new file mode 100644 index 000000000..1b473923e --- /dev/null +++ b/llarp/quic/stream.cpp @@ -0,0 +1,336 @@ +#include "stream.hpp" +#include "connection.hpp" +#include "endpoint.hpp" +#include "log.hpp" + +#include +#include + +// We use a single circular buffer with a pointer to the starting byte (denoted `á` or `ŕ`), the +// overall size, and the number of sent-but-unacked bytes (denoted `a`). `r` denotes an unsent +// byte. +// [ áaaaaaaarrrr ] +// ^ == start +// ------------ == size (== unacked + unsent bytes) +// -------- == unacked_size +// ^ -- the next write starts here +// ^^^^^^^ ^^^^^^^ -- unused buffer space +// +// we give ngtcp2 direct control over the unacked part of this buffer (it will let us know once the +// buffered data is no longer needed, i.e. once it is acknowledged by the remote side). +// +// The complication is that this buffer wraps, so if we write a bunch of data to the above it would +// end up looking like this: +// +// [rrr áaaaaaaarrrrrrrrrrr] +// +// This complicates things a bit, especially when returning the buffer to be written because we +// might have to return two separate string_views (the first would contain [rrrrrrrrrrr] and the +// second would contain [rrr]). As soon as we pass those buffer pointers off to ngtcp2 then our +// buffer looks like: +// +// [aaa áaaaaaaaaaaaaaaaaaa] +// +// Once we get an acknowledgement from the other end of the QUIC connection we can move up B (the +// beginning of the buffer); for example, suppose it acknowledges the next 10 bytes and then the +// following 10; we'll have: +// +// [aaa áaaaaaaaa] -- first 10 acked +// [ áa ] -- next 10 acked +// +// As a special case, if the buffer completely empties (i.e. all data is sent and acked) then we +// reset the starting bytes to the beginning of the buffer. + +namespace llarp::quic +{ + std::ostream& + operator<<(std::ostream& o, const StreamID& s) + { + return o << u8"Str❰" << s.id << u8"❱"; + } + + Stream::Stream( + Connection& conn, + data_callback_t data_cb, + close_callback_t close_cb, + size_t buffer_size, + StreamID id) + : data_callback{std::move(data_cb)} + , close_callback{std::move(close_cb)} + , conn{conn} + , stream_id{std::move(id)} + , buffer{buffer_size} + , avail_trigger{conn.endpoint.get_loop().resource()} + { + avail_trigger->on([this](auto&, auto&) { handle_unblocked(); }); + } + + Stream::Stream(Connection& conn, StreamID id, size_t buffer_size) + : Stream{conn, nullptr, nullptr, buffer_size, std::move(id)} + {} + + void + Stream::set_buffer_size(size_t size) + { + if (used() != 0) + throw std::runtime_error{"Cannot update buffer size while buffer is in use"}; + if (size > 0 && size < 2048) + size = 2048; + + buffer.resize(size); + buffer.shrink_to_fit(); + start = size = unacked_size = 0; + } + + size_t + Stream::buffer_size() const + { + return buffer.empty() ? size + start // start is the acked amount of the first buffer + : buffer.size(); + } + + bool + Stream::append(bstring_view data) + { + assert(!buffer.empty()); + + if (data.size() > available()) + return false; + + // When we are appending we have three cases: + // - data doesn't fit -- we simply abort (return false, above). + // - data fits between the buffer end and `]` -- simply append it and update size + // - data is larger -- copy from the end up to `]`, then copy the rest into the beginning of the + // buffer (i.e. after `[`). + + size_t wpos = (start + size) % buffer.size(); + if (wpos + data.size() > buffer.size()) + { + // We are wrapping + auto data_split = data.begin() + (buffer.size() - wpos); + std::copy(data.begin(), data_split, buffer.begin() + wpos); + std::copy(data_split, data.end(), buffer.begin()); + Debug( + "Wrote ", + data.size(), + " bytes to buffer ranges [", + wpos, + ",", + buffer.size(), + ")+[0,", + data.end() - data_split, + ")"); + } + else + { + // No wrap needs, it fits before the end: + std::copy(data.begin(), data.end(), buffer.begin() + wpos); + Debug("Wrote ", data.size(), " bytes to buffer range [", wpos, ",", wpos + data.size(), ")"); + } + size += data.size(); + Debug("New stream buffer: ", size, "/", buffer.size(), " bytes beginning at ", start); + conn.io_ready(); + return true; + } + size_t + Stream::append_any(bstring_view data) + { + if (size_t avail = available(); data.size() > avail) + data.remove_suffix(data.size() - avail); + [[maybe_unused]] bool appended = append(data); + assert(appended); + return data.size(); + } + + void + Stream::append_buffer(const std::byte* buffer, size_t length) + { + assert(this->buffer.empty()); + user_buffers.emplace_back(buffer, length); + size += length; + conn.io_ready(); + } + + void + Stream::acknowledge(size_t bytes) + { + // Frees bytes; e.g. acknowledge(3) changes: + // [ áaaaaarr ] to [ áaarr ] + // [aaarr áa] to [ áarr ] + // [ áaarrr ] to [ ŕrr ] + // [ áaa ] to [´ ] (i.e. empty buffer *and* reset start pos) + // + assert(bytes <= unacked_size && unacked_size <= size); + + Debug("Acked ", bytes, " bytes of ", unacked_size, "/", size, " unacked/total"); + + unacked_size -= bytes; + size -= bytes; + if (!buffer.empty()) + start = size == 0 ? 0 + : (start + bytes) + % buffer.size(); // reset start to 0 (to reduce wrapping buffers) if empty + else if (size == 0) + { + user_buffers.clear(); + start = 0; + } + else + { + while (bytes) + { + assert(!user_buffers.empty()); + assert(start < user_buffers.front().second); + if (size_t remaining = user_buffers.front().second - start; bytes >= remaining) + { + user_buffers.pop_front(); + start = 0; + bytes -= remaining; + } + else + { + start += bytes; + bytes = 0; + } + } + } + + if (!unblocked_callbacks.empty()) + available_ready(); + } + + auto + get_buffer_it( + std::deque, size_t>>& bufs, size_t offset) + { + auto it = bufs.begin(); + while (offset >= it->second) + { + offset -= it->second; + it++; + } + return std::make_pair(std::move(it), offset); + } + + std::vector + Stream::pending() + { + std::vector bufs; + size_t rsize = unsent(); + if (!rsize) + return bufs; + if (!buffer.empty()) + { + size_t rpos = (start + unacked_size) % buffer.size(); + if (size_t rend = rpos + rsize; rend <= buffer.size()) + { + bufs.emplace_back(buffer.data() + rpos, rsize); + } + else + { // wrapping + bufs.reserve(2); + bufs.emplace_back(buffer.data() + rpos, buffer.size() - rpos); + bufs.emplace_back(buffer.data(), rend % buffer.size()); + } + } + else + { + assert(!user_buffers.empty()); // If empty then unsent() should have been 0 + auto [it, offset] = get_buffer_it(user_buffers, start + unacked_size); + bufs.reserve(std::distance(it, user_buffers.end())); + assert(it != user_buffers.end()); + bufs.emplace_back(it->first.get() + offset, it->second - offset); + for (++it; it != user_buffers.end(); ++it) + bufs.emplace_back(it->first.get(), it->second); + } + return bufs; + } + + void + Stream::when_available(unblocked_callback_t unblocked_cb) + { + assert(available() == 0); + unblocked_callbacks.push(std::move(unblocked_cb)); + } + + void + Stream::handle_unblocked() + { + if (buffer.empty()) + { + while (!unblocked_callbacks.empty() && unblocked_callbacks.front()(*this)) + unblocked_callbacks.pop(); + } + while (!unblocked_callbacks.empty() && available() > 0) + { + if (unblocked_callbacks.front()(*this)) + unblocked_callbacks.pop(); + else + assert(available() == 0); + } + conn.io_ready(); + } + + void + Stream::io_ready() + { + conn.io_ready(); + } + + void + Stream::available_ready() + { + avail_trigger->send(); + } + + void + Stream::wrote(size_t bytes) + { + // Called to tell us we sent some bytes off, e.g. wrote(3) changes: + // [ áaarrrrrr ] or [rr áaar] + // to: + // [ áaaaaarrr ] or [aa áaaa] + Debug("wrote ", bytes, ", unsent=", unsent()); + assert(bytes <= unsent()); + unacked_size += bytes; + } + + void + Stream::close(std::optional error_code) + { + Debug( + "Closing ", + stream_id, + error_code ? " immediately with code " + std::to_string(*error_code) : " gracefully"); + + if (is_shutdown) + Debug("Stream is already shutting down"); + else if (error_code) + { + is_closing = is_shutdown = true; + ngtcp2_conn_shutdown_stream(conn, stream_id.id, *error_code); + } + else if (is_closing) + Debug("Stream is already closing"); + else + is_closing = true; + + if (is_shutdown) + data_callback = {}; + + conn.io_ready(); + } + + void + Stream::data(std::shared_ptr data) + { + user_data = std::move(data); + } + + void + Stream::weak_data(std::weak_ptr data) + { + user_data = std::move(data); + } + +} // namespace llarp::quic diff --git a/llarp/quic/stream.hpp b/llarp/quic/stream.hpp new file mode 100644 index 000000000..60ba4b377 --- /dev/null +++ b/llarp/quic/stream.hpp @@ -0,0 +1,343 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace llarp::quic +{ + class Connection; + + using bstring_view = std::basic_string_view; + + // Shortcut for a const-preserving `reinterpret_cast`ing c.data() from a std::byte to a uint8_t + // pointer, because we need it all over the place in the ngtcp2 API and I'd rather deal with + // std::byte's out here for type safety. + template < + typename Container, + typename = std::enable_if_t< + sizeof(typename std::remove_reference_t::value_type) == sizeof(uint8_t)>> + inline auto* + u8data(Container&& c) + { + using u8_sameconst_t = std::conditional_t< + std::is_const_v>, + const uint8_t, + uint8_t>; + return reinterpret_cast(c.data()); + } + + // Type-safe wrapper around a int64_t stream id. Default construction is ngtcp2's special + // "no-stream" id. + struct StreamID + { + int64_t id{-1}; + bool + operator==(const StreamID& s) const + { + return s.id == id; + } + bool + operator!=(const StreamID& s) const + { + return s.id != id; + } + bool + operator<(const StreamID& s) const + { + return s.id < id; + } + bool + operator<=(const StreamID& s) const + { + return s.id <= id; + } + bool + operator>(const StreamID& s) const + { + return s.id > id; + } + bool + operator>=(const StreamID& s) const + { + return s.id >= id; + } + }; + + // Application error code we close with if the data handle throws + constexpr uint64_t STREAM_EXCEPTION_ERROR_CODE = (1ULL << 62) - 2; + + std::ostream& + operator<<(std::ostream& o, const StreamID& s); +} // namespace llarp::quic + +namespace std +{ + template <> + struct hash + { + size_t + operator()(const llarp::quic::StreamID& s) const + { + return std::hash{}(s.id); + } + }; +} // namespace std + +namespace llarp::quic +{ + // Class for an established stream (a single connection has multiple streams): we have a + // fixed-sized ring buffer for holding outgoing data, and a callback to invoke on received data. + // To construct a Stream call `conn.open_stream()`. + class Stream : public std::enable_shared_from_this + { + public: + // Returns the StreamID of this stream + const StreamID& + id() const + { + return stream_id; + } + + // Sets the size of the outgoing data buffer. This may *only* be used if the buffer is + // currently entirely empty; otherwise a runtime_error is thrown. The minimum buffer size is + // 2048, the default is 64kiB. A value of 0 puts the Stream into user-provided buffer mode + // where only the version of `append` taking ownership of a char* is permitted. + void + set_buffer_size(size_t size); + + // Returns the size of the buffer (including both pending and free space). If using + // user-provided buffer mode then this is the sum of all held buffers. + size_t + buffer_size() const; + + // Returns the number of free bytes available in the outgoing stream data buffer. Always 0 in + // user-provided buffer mode. + size_t + available() const + { + return is_closing || buffer.empty() ? 0 : buffer.size() - size; + } + + // Returns the number of bytes currently referenced in the buffer (i.e. pending or + // sent-but-unacknowledged). + size_t + used() const + { + return size; + } + + // Returns the number of bytes of the buffer that have been sent but not yet acknowledged and + // thus are still required. + size_t + unacked() const + { + return unacked_size; + } + + // Returns the number of bytes of the buffer that have not yet been sent. + size_t + unsent() const + { + return used() - unacked(); + } + + // Try to append all of the given bytes to the outgoing stream data buffer. Returns true if + // successful, false (without appending anything) if there is insufficient space. If you want + // to append as much as possible then use `append_any` instead. + bool + append(bstring_view data); + bool + append(std::string_view data) + { + return append(bstring_view{reinterpret_cast(data.data()), data.size()}); + } + + // Append bytes to the outgoing stream data buffer, allowing partial consumption of data if the + // entire provided data cannot be appended. Returns the number of appended bytes (which will be + // less than the total provided if the provided data is larger than `available()`). If you want + // an all-or-nothing append then use `append` instead. + size_t + append_any(bstring_view data); + size_t + append_any(std::string_view data) + { + return append_any(bstring_view{reinterpret_cast(data.data()), data.size()}); + } + + // Takes ownership of the given buffer pointer, queuing it to be sent after any existing buffers + // and freed once fully acked. You *must* have called `set_buffer_size(0)` (or set the + // endpoints default_stream_buffer_size to 0) in order to use this. + void + append_buffer(const std::byte* buf, size_t length); + + // Starting closing the stream and prevent any more outgoing data from being appended. If + // `error_code` is provided then we close immediately with the given code; if std::nullopt (the + // default) we close gracefully by sending a FIN bit. + void + close(std::optional error_code = std::nullopt); + + // Returns true if this Stream is closing (or already closed). + bool + closing() const + { + return is_closing; + } + + // Callback invoked when data is received + using data_callback_t = std::function; + + // Callback invoked when the stream is closed + using close_callback_t = std::function error_code)>; + + // Callback invoked when free stream buffer space becomes available. Should return true if the + // callback is finished and can be discarded, false if the callback is still needed. If + // returning false then it *must* have filled the stream's outgoing buffer (this is asserted in + // a debug build). + using unblocked_callback_t = std::function; + + // Callback to invoke when we receive some incoming data; there's no particular guarantee on the + // size of the data, just that this will always be called in sequential order. + data_callback_t data_callback; + + // Callback to invoke when the connection has closed. If the close was an abrupt stream close + // initiated by the remote then `error_code` will be set to whatever code the remote side + // provided; for graceful closing or locally initiated closing the error code will be null. + close_callback_t close_callback; + + // Queues a callback to be invoked when space becomes available for writing in the buffer. The + // callback should true if it completed, false if it still needs more buffer space. If multiple + // callbacks are queued they are invoked in order, space permitting. The stored std::function + // will not be moved or copied after being invoked (i.e. if invoked multiple times it will + // always be invoked on the same instance). + // + // Available callbacks should only be used when the buffer is full, typically immediately after + // an `append_any` call that returns less than the full write. Similarly a false return from an + // unblock function (which keeps the callback alive) should satisfy the same condition. + // + // In user-provided buffer mode the callback will be invoked after any data has been acked: it + // is up to the caller to look at used()/buffer_size()/etc. to decide what to do. As described + // above, return true to remove this callback, false to keep it and try again after the next + // ack. + void + when_available(unblocked_callback_t unblocked_cb); + + // Calls io_ready() on the stream's connection to scheduling sending outbound data + void + io_ready(); + + // Schedules processing of the "when_available" callbacks + void + available_ready(); + + // Lets you stash some arbitrary data in a shared_ptr; this is not used internally. + void + data(std::shared_ptr data); + + // Variation of data() that holds the pointer in a weak_ptr instead of a shared_ptr. + void + weak_data(std::weak_ptr data); + + // Retrieves the stashed data, with a static_cast to the desired type. This is used for + // retrieval of both shared or weak data types (if held as a weak_ptr it is lock()ed first). + template + std::shared_ptr + data() const + { + return std::static_pointer_cast( + std::holds_alternative>(user_data) + ? std::get>(user_data) + : std::get>(user_data).lock()); + } + + private: + friend class Connection; + + Stream( + Connection& conn, + data_callback_t data_cb, + close_callback_t close_cb, + size_t buffer_size, + StreamID id = {-1}); + Stream(Connection& conn, StreamID id, size_t buffer_size); + + // Non-copyable, non-movable; we manage it via a unique_ptr held by its Connection + Stream(const Stream&) = delete; + const Stream& + operator=(const Stream&) = delete; + Stream(Stream&&) = delete; + Stream& + operator=(Stream&&) = delete; + + Connection& conn; + + // Callback(s) to invoke once we have the requested amount of space available in the buffer. + std::queue unblocked_callbacks; + void + handle_unblocked(); // Processes the above if space is available + + // Called to advance the number of acknowledged bytes (freeing up that space in the buffer for + // appending data). + void + acknowledge(size_t bytes); + + // Returns a view into unwritten stream data. This returns a vector of string_views of the data + // to write, in order. After writing any of the provided data you must call `wrote()` to signal + // how much of the given data was consumed (to advance the next pending() call). + std::vector + pending(); + + // Called to signal that bytes have been written and should now be considered sent (but still + // unacknowledged), thereby advancing the initial data position returned by the next `pending()` + // call. Should typically be called after `pending()` to signal how much of the pending data + // was actually used. + void + wrote(size_t bytes); + + // ngtcp2 stream_id, assigned during stream creation + StreamID stream_id{-1}; + + // ring buffer of outgoing stream data that has not yet been acknowledged. This cannot be + // resized once used as ngtcp2 will have pointers into the data. If this is empty then we are + // in user-provided buffer mode. + std::vector buffer{65536}; + + // user-provided buffers; only used when `buffer` is empty (via a `set_buffer_size(0)` or a 0 + // size given in the constructor). + std::deque, size_t>> user_buffers; + + // Offset of the first used byte in the circular buffer, will always be in [0, buffer.size()). + // For user-provided buffers this is the starting offset in the currently sending user-provided + // buffer. + size_t start{0}; + + // Number of sent-but-unacked packets in the buffer (i.e. [start, start+unacked_size) are sent + // but not yet acked). + size_t unacked_size{0}; + + // Number of used bytes in the buffer; thus start+size is the next write location and + // [start+unacked_size, start+size) is the range of not-yet-sent bytes. (Note that this + // description is ignoring the circularity of the buffer). + size_t size{0}; + + bool is_new{true}; + bool is_closing{false}; + bool sent_fin{false}; + bool is_shutdown{false}; + + // Async trigger we use to schedule when_available callbacks (so that we can make them happen in + // batches rather than after each and every packet ack). + std::shared_ptr avail_trigger; + + std::variant, std::weak_ptr> user_data; + }; + +} // namespace llarp::quic diff --git a/llarp/quic/tunnel.cpp b/llarp/quic/tunnel.cpp new file mode 100644 index 000000000..072686913 --- /dev/null +++ b/llarp/quic/tunnel.cpp @@ -0,0 +1,111 @@ +#include "tunnel.hpp" +#include "log.hpp" +#include "stream.hpp" + +namespace llarp::quic::tunnel +{ + // Takes data from the tcp connection and pushes it down the quic tunnel + void + on_outgoing_data(uvw::DataEvent& event, uvw::TCPHandle& client) + { + auto stream = client.data(); + assert(stream); + std::string_view data{event.data.get(), event.length}; + auto peer = client.peer(); + llarp::quic::Debug(peer.ip, ":", peer.port, " → lokinet ", llarp::quic::buffer_printer{data}); + // Steal the buffer from the DataEvent's unique_ptr: + stream->append_buffer(reinterpret_cast(event.data.release()), event.length); + if (stream->used() >= PAUSE_SIZE) + { + llarp::quic::Debug( + "quic tunnel is congested (have ", + stream->used(), + " bytes in flight); pausing local tcp connection reads"); + client.stop(); + stream->when_available([](llarp::quic::Stream& s) { + auto client = s.data(); + if (s.used() < PAUSE_SIZE) + { + llarp::quic::Debug("quic tunnel is no longer congested; resuming tcp connection reading"); + client->read(); + return true; + } + return false; + }); + } + else + { + llarp::quic::Debug("Queued ", event.length, " bytes"); + } + } + + // Received data from the quic tunnel and sends it to the TCP connection + void + on_incoming_data(llarp::quic::Stream& stream, llarp::quic::bstring_view bdata) + { + auto tcp = stream.data(); + assert(tcp); + std::string_view data{reinterpret_cast(bdata.data()), bdata.size()}; + auto peer = tcp->peer(); + llarp::quic::Debug(peer.ip, ":", peer.port, " ← lokinet ", llarp::quic::buffer_printer{data}); + + if (data.empty()) + return; + + // Try first to write immediately from the existing buffer to avoid needing an + // allocation and copy: + auto written = tcp->tryWrite(const_cast(data.data()), data.size()); + if (written < (int)data.size()) + { + data.remove_prefix(written); + + auto wdata = std::make_unique(data.size()); + std::copy(data.begin(), data.end(), wdata.get()); + tcp->write(std::move(wdata), data.size()); + } + } + + void + install_stream_forwarding(uvw::TCPHandle& tcp, llarp::quic::Stream& stream) + { + tcp.data(stream.shared_from_this()); + stream.weak_data(tcp.weak_from_this()); + + tcp.on([](auto&, uvw::TCPHandle& c) { + // This fires sometime after we call `close()` to signal that the close is done. + llarp::quic::Error( + "Connection with ", + c.peer().ip, + ":", + c.peer().port, + " closed directly, closing quic stream"); + c.data()->close(); + }); + tcp.on([](auto&, uvw::TCPHandle& c) { + // This fires on eof, most likely because the other side of the TCP connection closed it. + llarp::quic::Error( + "EOF on connection with ", c.peer().ip, ":", c.peer().port, ", closing quic stream"); + c.data()->close(); + }); + tcp.on([](const uvw::ErrorEvent& e, uvw::TCPHandle& tcp) { + llarp::quic::Error( + "ErrorEvent[", + e.name(), + ": ", + e.what(), + "] on connection with ", + tcp.peer().ip, + ":", + tcp.peer().port, + ", shutting down quic stream"); + // Failed to open connection, so close the quic stream + auto stream = tcp.data(); + if (stream) + stream->close(ERROR_TCP); + tcp.close(); + }); + tcp.on(tunnel::on_outgoing_data); + stream.data_callback = on_incoming_data; + } + +} // namespace tunnel diff --git a/llarp/quic/tunnel.hpp b/llarp/quic/tunnel.hpp new file mode 100644 index 000000000..c712b70df --- /dev/null +++ b/llarp/quic/tunnel.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include "stream.hpp" +#include "log.hpp" + +#include +#include +#include +#include + +#include + +namespace llarp::quic::tunnel +{ + // The server sends back a 0x00 to signal that the remote TCP connection was established and that + // it is now accepting stream data; the client is not allowed to send any other data down the + // stream until this comes back (any data sent down the stream before then is discarded.) + inline constexpr std::byte CONNECT_INIT{0x00}; + // QUIC application error codes we sent on failures: + // Failure to establish an initial connection: + inline constexpr uint64_t ERROR_CONNECT{0x5471907}; + // Error if we receive something other than CONNECT_INIT as the initial stream data from the + // server + inline constexpr uint64_t ERROR_BAD_INIT{0x5471908}; + // Close error code sent if we get an error on the TCP socket (other than an initial connect + // failure) + inline constexpr uint64_t ERROR_TCP{0x5471909}; + + // We pause reading from the local TCP socket if we have more than this amount of outstanding + // unacked data in the quic tunnel, then resume once it drops below this. + inline constexpr size_t PAUSE_SIZE = 64 * 1024; + + // Callbacks for network events. The uvw::TCPHandle client must contain a shared pointer to the + // associated llarp::quic::Stream in its data, and the llarp::quic::Stream must contain a weak + // pointer to the uvw::TCPHandle. + + // Callback when we receive data to go out over lokinet, i.e. read from the local TCP socket + void + on_outgoing_data(uvw::DataEvent& event, uvw::TCPHandle& client); + + // Callback when we receive data from lokinet to write to the local TCP socket + void + on_incoming_data(llarp::quic::Stream& stream, llarp::quic::bstring_view bdata); + + // Callback to handle and discard the first incoming 0x00 byte that initiates the stream + void + on_init_incoming_data(llarp::quic::Stream& stream, llarp::quic::bstring_view bdata); + + // Creates a new tcp handle that forwards incoming data/errors/closes into appropriate actions on + // the given quic stream. + void + install_stream_forwarding(uvw::TCPHandle& tcp, llarp::quic::Stream& stream); + +} // namespace llarp::quic::tunnel diff --git a/llarp/quic/tunnel_client.cpp b/llarp/quic/tunnel_client.cpp new file mode 100644 index 000000000..29c98280f --- /dev/null +++ b/llarp/quic/tunnel_client.cpp @@ -0,0 +1,139 @@ +#include "connection.hpp" +#include "client.hpp" +#include "log.hpp" +#include "stream.hpp" +#include "tunnel.hpp" + +#include + +#include +#include + +#include + +using namespace std::literals; + +namespace llarp::quic::tunnel +{ + // When we receive a new incoming connection we immediately initiate a new quic stream. This quic + // stream in turn causes the other end to initiate a TCP connection on whatever port we specified + // in the connection; if the connection is established, it sends back a single byte 0x00 + // (CONNECT_INIT); otherwise it shuts down the stream with an error code. + void + on_new_connection(const uvw::ListenEvent&, uvw::TCPHandle& server) + { + llarp::quic::Debug("New connection!\n"); + auto client = server.loop().resource(); + server.accept(*client); + + auto conn = server.data(); + std::shared_ptr stream; + try + { + llarp::quic::Debug("open stream"); + stream = conn->open_stream( + [client](llarp::quic::Stream& stream, llarp::quic::bstring_view bdata) { + if (bdata.empty()) + return; + if (auto b0 = bdata[0]; b0 == tunnel::CONNECT_INIT) + { + // Set up callbacks, which replaces both of these initial callbacks + client->read(); + tunnel::install_stream_forwarding(*client, stream); + + if (bdata.size() > 1) + { + bdata.remove_prefix(1); + stream.data_callback(stream, std::move(bdata)); + } + llarp::quic::Debug("starting client reading"); + } + else + { + llarp::quic::Warn( + "Remote connection returned invalid initial byte (0x", + oxenmq::to_hex(bdata.begin(), bdata.begin() + 1), + "); dropping connection"); + client->closeReset(); + stream.close(tunnel::ERROR_BAD_INIT); + } + stream.io_ready(); + }, + [client](llarp::quic::Stream&, std::optional error_code) mutable { + if (error_code && *error_code == tunnel::ERROR_CONNECT) + llarp::quic::Debug("Remote TCP connection failed, closing local connection"); + else + llarp::quic::Warn( + "Stream connection closed ", + error_code ? "with error " + std::to_string(*error_code) : "gracefully", + "; closing local TCP connection."); + auto peer = client->peer(); + llarp::quic::Debug("Closing connection to ", peer.ip, ":", peer.port); + if (error_code) + client->closeReset(); + else + client->close(); + }); + stream->io_ready(); + } + catch (const std::exception& e) + { + llarp::quic::Debug("open stream failed"); + client->closeReset(); + return; + } + + llarp::quic::Debug("setup stream"); + conn->io_ready(); + } + + int + usage(std::string_view arg0, std::string_view msg) + { + std::cerr << msg << "\n\n" + << "Usage: " << arg0 + << " [DESTPORT [SERVERPORT [LISTENPORT]]]\n\nDefaults to ports 4444 4242 5555\n"; + return 1; + } + + int + main(int argc, char* argv[]) + { + auto loop = uvw::Loop::create(); + + std::array ports{{4444, 4242, 5555}}; + for (size_t i = 0; i < ports.size(); i++) + { + if (argc < 2 + (int)i) + break; + if (!parse_int(argv[1 + i], ports[i])) + return usage(argv[0], "Invalid port "s + argv[1 + i]); + } + auto& [dest_port, server_port, listen_port] = ports; + std::cout << "Connecting to quic server at localhost:" << server_port + << " to reach tunneled port " << dest_port + << ", listening on localhost:" << listen_port << "\n"; + + signal(SIGPIPE, SIG_IGN); + + llarp::quic::Debug("Initializing client"); + auto tunnel_client = std::make_shared( + llarp::quic::Address{{127, 0, 0, 1}, server_port}, // server addr + loop, + dest_port // tunnel destination port + ); + tunnel_client->default_stream_buffer_size = 0; // We steal uvw's provided buffers + llarp::quic::Debug("Initialized client"); + + // Start listening for TCP connections: + auto server = loop->resource(); + server->data(tunnel_client->get_connection()); + server->on(llarp::quic::tunnel::on_new_connection); + + server->bind("127.0.0.1", listen_port); + server->listen(); + + loop->run(); + } + +} // namespace llarp::quic::tunnel diff --git a/llarp/quic/tunnel_server.cpp b/llarp/quic/tunnel_server.cpp new file mode 100644 index 000000000..3bd2493b6 --- /dev/null +++ b/llarp/quic/tunnel_server.cpp @@ -0,0 +1,174 @@ +#include "tunnel_server.hpp" +#include "tunnel.hpp" +#include "connection.hpp" +#include "server.hpp" +#include "log.hpp" + +#include + +#include + +using namespace std::literals; + +namespace llarp::quic::tunnel +{ + IncomingTunnel::IncomingTunnel(uint16_t localhost_port) + : IncomingTunnel{ + [localhost_port]( + [[maybe_unused]] const auto& remote, uint16_t port, SockAddr& connect_to) { + if (port != localhost_port) + return AcceptResult::DECLINE; + connect_to.setIPv4(127, 0, 0, 1); + connect_to.setPort(port); + return AcceptResult::ACCEPT; + }} + {} + + int + usage(std::string_view arg0, std::string_view msg) + { + std::cerr << msg << "\n\n" + << "Usage: " << arg0 + << " [LISTENPORT [ALLOWED ...]]\n\nDefaults to listening on 4242 and allowing " + "22,80,4444,8080\n"; + return 1; + } + + int + main(int argc, char* argv[]) + { + uint16_t listen_port = 4242; + std::set allowed_ports{{22, 80, 4444, 8080}}; + + if (argc >= 2 && !parse_int(argv[1], listen_port)) + return usage(argv[0], "Invalid port "s + argv[1]); + if (argc >= 3) + { + allowed_ports.clear(); + for (int i = 2; i < argc; i++) + { + if (argv[i] == "all"sv) + { + allowed_ports.clear(); + break; + } + uint16_t port; + if (!parse_int(argv[i], port)) + return usage(argv[0], "Invalid port "s + argv[i]); + allowed_ports.insert(port); + } + } + + auto loop = uvw::Loop::create(); + + Address listen_addr{{0, 0, 0, 0}, listen_port}; + + signal(SIGPIPE, SIG_IGN); + + // The local address we connect to for incoming connections. (localhost for this demo, should + // be the localhost.loki address for lokinet). + std::string localhost = "127.0.0.1"; + + llarp::quic::Debug("Initializing server"); + llarp::quic::Server s{ + listen_addr, + loop, + [loop, localhost, allowed_ports]( + llarp::quic::Server&, llarp::quic::Stream& stream, uint16_t port) { + llarp::quic::Debug( + "\e[33mNew incoming quic stream ", + stream.id(), + " to reach ", + localhost, + ":", + port, + "\e[0m"); + if (port == 0 || !(allowed_ports.empty() || allowed_ports.count(port))) + { + llarp::quic::Warn( + "quic stream denied by configuration: ", port, " is not a permitted local port"); + return false; + } + /* + stream.data_callback = [init_seen=false](llarp::quic::Stream& stream, + llarp::quic::bstring_view bdata) mutable { if (init_seen) { llarp::quic::Warn("Invalid + remote data: received multiple bytes before connection confirmation"); + } + }; + */ + stream.close_callback = [](llarp::quic::Stream& strm, + std::optional error_code) { + llarp::quic::Debug( + error_code ? "Remote side" : "We", + " closed the quic stream, closing localhost tcp connection"); + if (error_code && *error_code > 0) + llarp::quic::Warn("Remote quic stream was closed with error code ", *error_code); + auto tcp = strm.data(); + if (!tcp) + llarp::quic::Debug("Local TCP connection already closed"); + else + tcp->close(); + }; + // Try to open a TCP connection to the configured localhost port; if we establish a + // connection then we immediately send a CONNECT_INIT back down the stream; if we fail + // then we send a fail-to-connect error code. Once we successfully connect both of + // these handlers get replaced with the normal tunnel handlers. + auto tcp = loop->resource(); + auto error_handler = tcp->once( + [&stream, localhost, port](const uvw::ErrorEvent&, uvw::TCPHandle&) { + llarp::quic::Error( + "Failed to connect to ", localhost, ":", port, ", shutting down quic stream"); + stream.close(tunnel::ERROR_CONNECT); + }); + tcp->once( + [streamw = stream.weak_from_this(), error_handler = std::move(error_handler)]( + const uvw::ConnectEvent&, uvw::TCPHandle& tcp) { + auto peer = tcp.peer(); + auto stream = streamw.lock(); + if (!stream) + { + llarp::quic::Warn( + "Connected to ", + peer.ip, + ":", + peer.port, + " but quic stream has gone away; resetting local connection"); + tcp.closeReset(); + return; + } + llarp::quic::Debug( + "\e[32mConnected to ", + peer.ip, + ":", + peer.port, + " for quic ", + stream->id(), + "\e[0m"); + tcp.erase(error_handler); + tunnel::install_stream_forwarding(tcp, *stream); + assert(stream->used() == 0); + + stream->append_buffer(new std::byte[1]{tunnel::CONNECT_INIT}, 1); + tcp.read(); + }); + + tcp->connect("127.0.0.1", port); + + return true; + }}; + s.default_stream_buffer_size = 0; // We steal uvw's provided buffers + llarp::quic::Debug("Initialized server"); + std::cout << "Listening on localhost:" << listen_port + << " with tunnel(s) to localhost port(s):"; + if (allowed_ports.empty()) + std::cout << " (any)"; + for (auto p : allowed_ports) + std::cout << ' ' << p; + std::cout << '\n'; + + loop->run(); + + return 0; + } + +} // namespace llarp::quic::tunnel diff --git a/llarp/quic/tunnel_server.hpp b/llarp/quic/tunnel_server.hpp new file mode 100644 index 000000000..2a2ab78d4 --- /dev/null +++ b/llarp/quic/tunnel_server.hpp @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include + +#include + +namespace llarp::quic::tunnel +{ + enum class AcceptResult : int + { + ACCEPT = 0, // Accepts a connection + DECLINE = -1, // Declines a connection (try other callbacks, refuse if all decline) + REFUSE = -2, // Refuses a connection (don't try any more callbacks) + }; + + // Class that wraps an incoming connection acceptance callback (to allow for callback removal). + // This is not directly constructible: you must construct it via the TunnelServer instance. + class IncomingTunnel final + { + public: + using AcceptCallback = std::function; + + private: + AcceptCallback accept; + + friend class TunnelServer; + + // Constructor with a full callback; invoked via TunnelServer::add_incoming_tunnel + explicit IncomingTunnel(AcceptCallback accept) : accept{std::move(accept)} + {} + + // Constructor for a simple forwarding to a single localhost port. E.g. IncomingTunnel(22) + // allows incoming connections to reach port 22 and forwards them to localhost:22. + explicit IncomingTunnel(uint16_t localhost_port); + + // Constructor for forwarding everything to the same port; this is used by full clients by + // default. + IncomingTunnel(); + }; + + // Class that handles incoming quic connections. This class sets itself up in the llarp event + // loop on construction and maintains a list of incoming acceptor callbacks. When a new incoming + // quic connections is being established we try the callbacks one by one to determine the local + // TCP port the tunnel should be connected to until: + // - a callback sets connect_to and returns AcceptResult::ACCEPT - we connect it to the returned + // address + // - a callback returns AcceptResult::REFUSE - we reject the connection + // + // If a callback returns AcceptResult::DECLINE then we skip that callback and try the next one; if + // all callbacks decline (or we have no callbacks at all) then we reject the connection. + // + // Note that tunnel operations and initialization are done in the event loop thread and so will + // not take effect until the next event loop tick when called from some other thread. + class TunnelServer : public std::enable_shared_from_this + { + public: + explicit TunnelServer(EventLoop_ptr ev); + + // Appends a new tunnel to the end of the queue; all arguments are forwarded to private + // constructor(s) of IncomingTunnel. + template + std::shared_ptr + add_incoming_tunnel(Args&&... args) + { + return std::shared_ptr{new IncomingTunnel{std::forward(args)...}}; + } + + // Removes a tunnel acceptor from the acceptor queue. + void + remove_incoming_tunnel(std::weak_ptr tunnel); + + private: + EventLoop_ptr ev; + std::vector> tunnels; + }; + +} // namespace llarp::quic::tunnel diff --git a/llarp/service/protocol_type.hpp b/llarp/service/protocol_type.hpp index 37521c520..756c04d18 100644 --- a/llarp/service/protocol_type.hpp +++ b/llarp/service/protocol_type.hpp @@ -13,5 +13,6 @@ namespace llarp::service TrafficV6 = 2UL, Exit = 3UL, Auth = 4UL, + QUIC = 5UL, }; } // namespace llarp::service