mirror of
https://github.com/oxen-io/lokinet.git
synced 2024-11-11 07:10:36 +00:00
1231 lines
38 KiB
C++
1231 lines
38 KiB
C++
#include "connection.hpp"
|
|
#include "client.hpp"
|
|
#include "server.hpp"
|
|
#include <llarp/util/logging/logger.hpp>
|
|
#include <llarp/util/logging/buffer.hpp>
|
|
|
|
#include <cassert>
|
|
#include <charconv>
|
|
#include <cstring>
|
|
#include <iostream>
|
|
|
|
#include <uvw/async.h>
|
|
#include <uvw/poll.h>
|
|
#include <uvw/timer.h>
|
|
|
|
#include <iterator>
|
|
#include <oxenmq/hex.h>
|
|
#include <oxenmq/bt_serialize.h>
|
|
|
|
extern "C"
|
|
{
|
|
#include <sodium/randombytes.h>
|
|
}
|
|
|
|
namespace llarp::quic
|
|
{
|
|
// Builds a connection ID by copying `length` bytes from `cid` into the fixed-size `data` buffer.
//
// `length` is expected to be at most max_size(); this is asserted in debug builds.  When asserts
// are compiled out an oversized length is clamped to max_size() (the same clamping that
// ConnectionID::random applies) instead of writing past the end of `data`.  A null `cid` yields
// an empty ID rather than handing a null pointer to memmove.
ConnectionID::ConnectionID(const uint8_t* cid, size_t length)
{
  assert(length <= max_size());
  datalen = cid ? std::min(length, max_size()) : 0;
  if (datalen > 0)
    std::memmove(data, cid, datalen);
}
|
|
|
|
std::ostream&
|
|
operator<<(std::ostream& o, const ConnectionID& c)
|
|
{
|
|
return o << oxenmq::to_hex(c.data, c.data + c.datalen);
|
|
}
|
|
|
|
// Returns a connection ID filled with random bytes.  The requested `size` is capped at
// ConnectionID::max_size().
ConnectionID
ConnectionID::random(size_t size)
{
  ConnectionID result;
  const size_t limit = ConnectionID::max_size();
  result.datalen = size < limit ? size : limit;
  randombytes_buf(result.data, result.datalen);
  return result;
}
|
|
|
|
namespace
|
|
{
|
|
#pragma GCC diagnostic push
|
|
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
|
|
|
constexpr int FAIL = NGTCP2_ERR_CALLBACK_FAILURE;
|
|
|
|
int
|
|
client_initial(ngtcp2_conn* conn_, void* user_data)
|
|
{
|
|
LogTrace("######################", __func__);
|
|
|
|
// Initialization the connection and send our transport parameters to the server. This will
|
|
// put the connection into NGTCP2_CS_CLIENT_WAIT_HANDSHAKE state.
|
|
return static_cast<Connection*>(user_data)->init_client();
|
|
}
|
|
int
|
|
recv_client_initial(ngtcp2_conn* conn_, const ngtcp2_cid* dcid, void* user_data)
|
|
{
|
|
LogTrace("######################", __func__);
|
|
|
|
// New incoming connection from a client: our server connection starts out here in state
|
|
// NGTCP2_CS_SERVER_INITIAL, but we should immediately get into recv_crypto_data because the
|
|
// initial client packet should contain the client's transport parameters.
|
|
|
|
auto& conn = *static_cast<Connection*>(user_data);
|
|
assert(conn_ == conn.conn.get());
|
|
|
|
if (0 != conn.setup_server_crypto_initial())
|
|
return FAIL;
|
|
|
|
return 0;
|
|
}
|
|
int
|
|
recv_crypto_data(
|
|
ngtcp2_conn* conn_,
|
|
ngtcp2_crypto_level crypto_level,
|
|
uint64_t offset,
|
|
const uint8_t* rawdata,
|
|
size_t rawdatalen,
|
|
void* user_data)
|
|
{
|
|
std::basic_string_view data{rawdata, rawdatalen};
|
|
LogTrace("Receiving crypto data @ level ", crypto_level, " ", buffer_printer{data});
|
|
|
|
auto& conn = *static_cast<Connection*>(user_data);
|
|
switch (crypto_level)
|
|
{
|
|
case NGTCP2_CRYPTO_LEVEL_EARLY:
|
|
// We don't currently use or support 0rtt
|
|
LogWarn("Invalid EARLY crypto level");
|
|
return FAIL;
|
|
|
|
case NGTCP2_CRYPTO_LEVEL_INITIAL:
|
|
// "Initial" level means we are still handshaking; if we are server then we receive
|
|
// the client's transport params (sent in client_initial, above) and blast ours
|
|
// back. If we are a client then getting here means we received a response from the
|
|
// server, which is that returned server transport params.
|
|
|
|
if (auto rv = conn.recv_initial_crypto(data); rv != 0)
|
|
return rv;
|
|
|
|
if (ngtcp2_conn_is_server(conn))
|
|
{
|
|
if (auto rv = conn.send_magic(NGTCP2_CRYPTO_LEVEL_INITIAL); rv != 0)
|
|
return rv;
|
|
if (auto rv = conn.send_transport_params(NGTCP2_CRYPTO_LEVEL_HANDSHAKE); rv != 0)
|
|
return rv;
|
|
}
|
|
|
|
break;
|
|
|
|
case NGTCP2_CRYPTO_LEVEL_HANDSHAKE:
|
|
if (!ngtcp2_conn_is_server(conn))
|
|
{
|
|
if (auto rv = conn.recv_transport_params(data); rv != 0)
|
|
return rv;
|
|
// At this stage of the protocol with TLS the client sends back TLS info so that
|
|
// the server can install our rx key; we have to send *something* back to invoke
|
|
// the server's HANDSHAKE callback (so that it knows handshake is complete) so
|
|
// send the magic again.
|
|
if (auto rv = conn.send_magic(NGTCP2_CRYPTO_LEVEL_HANDSHAKE); rv != 0)
|
|
return rv;
|
|
}
|
|
else
|
|
{
|
|
// Check that we received the above as expected
|
|
if (data != handshake_magic)
|
|
{
|
|
LogWarn("Invalid handshake crypto frame from client: did not find expected magic");
|
|
return NGTCP2_ERR_CALLBACK_FAILURE;
|
|
}
|
|
}
|
|
|
|
conn.complete_handshake();
|
|
break;
|
|
|
|
case NGTCP2_CRYPTO_LEVEL_APPLICATION:
|
|
// if (!conn.init_tx_key())
|
|
// return FAIL;
|
|
break;
|
|
|
|
default:
|
|
LogWarn("Unhandled crypto_level ", crypto_level);
|
|
return FAIL;
|
|
}
|
|
conn.io_ready();
|
|
return 0;
|
|
}
|
|
int
|
|
encrypt(
|
|
uint8_t* dest,
|
|
const ngtcp2_crypto_aead* aead,
|
|
const ngtcp2_crypto_aead_ctx* aead_ctx,
|
|
const uint8_t* plaintext,
|
|
size_t plaintextlen,
|
|
const uint8_t* nonce,
|
|
size_t noncelen,
|
|
const uint8_t* ad,
|
|
size_t adlen)
|
|
{
|
|
LogTrace("######################", __func__);
|
|
LogTrace("Lengths: ", plaintextlen, "+", noncelen, "+", adlen);
|
|
if (dest != plaintext)
|
|
std::memmove(dest, plaintext, plaintextlen);
|
|
return 0;
|
|
}
|
|
int
|
|
decrypt(
|
|
uint8_t* dest,
|
|
const ngtcp2_crypto_aead* aead,
|
|
const ngtcp2_crypto_aead_ctx* aead_ctx,
|
|
const uint8_t* ciphertext,
|
|
size_t ciphertextlen,
|
|
const uint8_t* nonce,
|
|
size_t noncelen,
|
|
const uint8_t* ad,
|
|
size_t adlen)
|
|
{
|
|
LogTrace("######################", __func__);
|
|
LogTrace("Lengths: ", ciphertextlen, "+", noncelen, "+", adlen);
|
|
if (dest != ciphertext)
|
|
std::memmove(dest, ciphertext, ciphertextlen);
|
|
return 0;
|
|
}
|
|
int
|
|
hp_mask(
|
|
uint8_t* dest,
|
|
const ngtcp2_crypto_cipher* hp,
|
|
const ngtcp2_crypto_cipher_ctx* hp_ctx,
|
|
const uint8_t* sample)
|
|
{
|
|
LogTrace("######################", __func__);
|
|
memset(dest, 0, NGTCP2_HP_MASKLEN);
|
|
return 0;
|
|
}
|
|
int
|
|
recv_stream_data(
|
|
ngtcp2_conn* conn,
|
|
uint32_t flags,
|
|
int64_t stream_id,
|
|
uint64_t offset,
|
|
const uint8_t* data,
|
|
size_t datalen,
|
|
void* user_data,
|
|
void* stream_user_data)
|
|
{
|
|
LogTrace("######################", __func__);
|
|
return static_cast<Connection*>(user_data)->stream_receive(
|
|
{stream_id},
|
|
{reinterpret_cast<const std::byte*>(data), datalen},
|
|
flags & NGTCP2_STREAM_DATA_FLAG_FIN);
|
|
}
|
|
|
|
int
|
|
acked_stream_data_offset(
|
|
ngtcp2_conn* conn_,
|
|
int64_t stream_id,
|
|
uint64_t offset,
|
|
uint64_t datalen,
|
|
void* user_data,
|
|
void* stream_user_data)
|
|
{
|
|
LogTrace("######################", __func__);
|
|
LogTrace("Ack [", offset, ",", offset + datalen, ")");
|
|
return static_cast<Connection*>(user_data)->stream_ack({stream_id}, datalen);
|
|
}
|
|
|
|
int
|
|
stream_open(ngtcp2_conn* conn, int64_t stream_id, void* user_data)
|
|
{
|
|
LogTrace("######################", __func__);
|
|
return static_cast<Connection*>(user_data)->stream_opened({stream_id});
|
|
}
|
|
int
|
|
stream_reset_cb(
|
|
ngtcp2_conn* conn,
|
|
int64_t stream_id,
|
|
uint64_t final_size,
|
|
uint64_t app_error_code,
|
|
void* user_data,
|
|
void* stream_user_data)
|
|
{
|
|
LogTrace("######################", __func__);
|
|
return static_cast<Connection*>(user_data)->stream_reset({stream_id}, app_error_code);
|
|
}
|
|
|
|
// (client only)
|
|
int
|
|
recv_retry(ngtcp2_conn* conn, const ngtcp2_pkt_hd* hd, void* user_data)
|
|
{
|
|
LogTrace("######################", __func__);
|
|
LogError("FIXME UNIMPLEMENTED ", __func__);
|
|
// FIXME
|
|
return 0;
|
|
}
|
|
int
|
|
extend_max_local_streams_bidi(ngtcp2_conn* conn_, uint64_t max_streams, void* user_data)
|
|
{
|
|
LogTrace("######################", __func__);
|
|
auto& conn = *static_cast<Connection*>(user_data);
|
|
if (conn.on_stream_available)
|
|
if (uint64_t left = ngtcp2_conn_get_streams_bidi_left(conn); left > 0)
|
|
conn.on_stream_available(conn);
|
|
|
|
return 0;
|
|
}
|
|
|
|
int
|
|
rand(
|
|
uint8_t* dest,
|
|
size_t destlen,
|
|
const ngtcp2_rand_ctx* rand_ctx,
|
|
[[maybe_unused]] ngtcp2_rand_usage usage)
|
|
{
|
|
LogTrace("######################", __func__);
|
|
randombytes_buf(dest, destlen);
|
|
return 0;
|
|
}
|
|
int
|
|
get_new_connection_id(
|
|
ngtcp2_conn* conn_, ngtcp2_cid* cid_, uint8_t* token, size_t cidlen, void* user_data)
|
|
{
|
|
LogTrace("######################", __func__);
|
|
|
|
auto& conn = *static_cast<Connection*>(user_data);
|
|
auto cid = conn.make_alias_id(cidlen);
|
|
assert(cid.datalen == cidlen);
|
|
*cid_ = cid;
|
|
|
|
conn.endpoint.make_stateless_reset_token(cid, token);
|
|
LogDebug(
|
|
"make stateless reset token ",
|
|
oxenmq::to_hex(token, token + NGTCP2_STATELESS_RESET_TOKENLEN));
|
|
|
|
return 0;
|
|
}
|
|
int
|
|
remove_connection_id(ngtcp2_conn* conn, const ngtcp2_cid* cid, void* user_data)
|
|
{
|
|
LogTrace("######################", __func__);
|
|
LogError("FIXME UNIMPLEMENTED ", __func__);
|
|
// FIXME
|
|
return 0;
|
|
}
|
|
int
|
|
update_key(
|
|
ngtcp2_conn* conn,
|
|
uint8_t* rx_secret,
|
|
uint8_t* tx_secret,
|
|
ngtcp2_crypto_aead_ctx* rx_aead_ctx,
|
|
uint8_t* rx_iv,
|
|
ngtcp2_crypto_aead_ctx* tx_aead_ctx,
|
|
uint8_t* tx_iv,
|
|
const uint8_t* current_rx_secret,
|
|
const uint8_t* current_tx_secret,
|
|
size_t secretlen,
|
|
void* user_data)
|
|
{
|
|
// This is a no-op since we don't encrypt anything in the first place
|
|
return 0;
|
|
}
|
|
#pragma GCC diagnostic pop
|
|
} // namespace
|
|
|
|
#if 0
#ifndef NDEBUG
// (Currently compiled out.)  printf-style log hook that can be installed as ngtcp2's
// `settings.log_printf`: formats the message into a heap buffer and emits it at trace level.
extern "C" inline void
ngtcp_trace_logger([[maybe_unused]] void* user_data, const char* fmt, ...)
{
  va_list ap;
  va_start(ap, fmt);
  // vasprintf takes (char** out, const char* fmt, va_list); it allocates the buffer itself, so
  // no size argument is passed.  `msg` is only valid (and only needs freeing) on success.
  if (char* msg = nullptr; vasprintf(&msg, fmt, ap) >= 0)
  {
    LogTraceExplicit("external/ngtcp2/*.c", 0, msg);
    std::free(msg);
  }
  va_end(ap);
}
#endif
#endif
|
|
|
|
// Sends the first `send_buffer_size` bytes of `send_buffer` to the remote end of `path` via the
// endpoint.  Returns the endpoint's send result, or a default-constructed io_result when there
// is nothing to send.
io_result
Connection::send()
{
  assert(send_buffer_size <= send_buffer.size());

  bstring_view send_data{send_buffer.data(), send_buffer_size};
  if (send_data.empty())
    return io_result{};

  LogDebug("Sending packet: ", buffer_printer{send_data});
  io_result rv = endpoint.send_packet(path.remote, send_data, send_pkt_info.ecn);
  return rv;

  // FIXME2: probably don't want to do follow-up work *here* (checking and sending other pending
  // streams, scheduling a retransmit), because this is called from the stream checking code.
}
|
|
|
|
std::tuple<ngtcp2_settings, ngtcp2_transport_params, ngtcp2_callbacks>
|
|
Connection::init()
|
|
{
|
|
auto loop = endpoint.get_loop();
|
|
io_trigger = loop->resource<uvw::AsyncHandle>();
|
|
io_trigger->on<uvw::AsyncEvent>([this](auto&, auto&) { on_io_ready(); });
|
|
|
|
retransmit_timer = loop->resource<uvw::TimerHandle>();
|
|
retransmit_timer->on<uvw::TimerEvent>([this](auto&, auto&) {
|
|
LogTrace("Retransmit timer fired!");
|
|
if (auto rv = ngtcp2_conn_handle_expiry(*this, get_timestamp()); rv != 0)
|
|
{
|
|
LogWarn("expiry handler invocation returned an error: ", ngtcp2_strerror(rv));
|
|
endpoint.close_connection(*this, ngtcp2_err_infer_quic_transport_error_code(rv), false);
|
|
}
|
|
else
|
|
{
|
|
flush_streams();
|
|
}
|
|
});
|
|
retransmit_timer->start(0ms, 0ms);
|
|
|
|
auto result = std::tuple<ngtcp2_settings, ngtcp2_transport_params, ngtcp2_callbacks>{};
|
|
auto& [settings, tparams, cb] = result;
|
|
cb.recv_crypto_data = recv_crypto_data;
|
|
cb.encrypt = encrypt;
|
|
cb.decrypt = decrypt;
|
|
cb.hp_mask = hp_mask;
|
|
cb.recv_stream_data = recv_stream_data;
|
|
cb.acked_stream_data_offset = acked_stream_data_offset;
|
|
cb.stream_open = stream_open;
|
|
cb.stream_reset = stream_reset_cb;
|
|
cb.extend_max_local_streams_bidi = extend_max_local_streams_bidi;
|
|
cb.rand = rand;
|
|
cb.get_new_connection_id = get_new_connection_id;
|
|
cb.remove_connection_id = remove_connection_id;
|
|
cb.update_key = update_key;
|
|
|
|
ngtcp2_settings_default(&settings);
|
|
#if 0
|
|
#ifndef NDEBUG
|
|
settings.log_printf = ngtcp_trace_logger;
|
|
#endif
|
|
#endif
|
|
settings.initial_ts = get_timestamp();
|
|
// FIXME: IPv6
|
|
settings.max_udp_payload_size = NGTCP2_MAX_PKTLEN_IPV4;
|
|
settings.cc_algo = NGTCP2_CC_ALGO_CUBIC;
|
|
// settings.initial_rtt = ???; # NGTCP2's default is 333ms
|
|
|
|
ngtcp2_transport_params_default(&tparams);
|
|
|
|
// Connection level flow control window:
|
|
tparams.initial_max_data = CONNECTION_BUFFER;
|
|
// Max send buffer for a streams (local is for streams we initiate, remote is for replying on
|
|
// streams they initiate to us):
|
|
tparams.initial_max_stream_data_bidi_local = STREAM_BUFFER;
|
|
tparams.initial_max_stream_data_bidi_remote = STREAM_BUFFER;
|
|
// Max *cumulative* streams we support on a connection:
|
|
tparams.initial_max_streams_bidi = STREAM_LIMIT;
|
|
tparams.initial_max_streams_uni = 0;
|
|
tparams.max_idle_timeout = std::chrono::nanoseconds(IDLE_TIMEOUT).count();
|
|
tparams.active_connection_id_limit = 8;
|
|
|
|
LogDebug("Done basic connection initialization");
|
|
|
|
return result;
|
|
}
|
|
|
|
Connection::Connection(
|
|
Server& s, const ConnectionID& base_cid_, ngtcp2_pkt_hd& header, const Path& path)
|
|
: endpoint{s}, base_cid{base_cid_}, dest_cid{header.scid}, path{path}
|
|
{
|
|
auto [settings, tparams, cb] = init();
|
|
|
|
cb.recv_client_initial = recv_client_initial;
|
|
|
|
LogDebug("header.type = ", header.type);
|
|
|
|
// ConnectionIDs are a little complicated:
|
|
// - when a client creates a new connection to us, it creates a random source connection ID
|
|
// *and* a random destination connection id. The server won't have that connection ID, of
|
|
// course, but we use it to recognize that we should try accepting it as a new connection.
|
|
// - When we talk to the client we use the random source connection ID that it generated as our
|
|
// destination connection ID.
|
|
// - We choose our own source ID, however: we *don't* use the random one the client picked for
|
|
// us. Instead we generate a random one and sent it back as *our* source connection ID in the
|
|
// reply to the client.
|
|
// - the client still needs to match up that reply with that request, and so we include the
|
|
// destination connection ID that the client generated for us in the transport parameters as
|
|
// the original_dcid: this lets the client match up the request, after which it can't promptly
|
|
// forget about it and start using the source CID that we gave it.
|
|
//
|
|
// So, in other words, the conversation goes like this:
|
|
// - Client: [SCID:clientid, DCID:randomid, TRANSPORT_PARAMS]
|
|
// - Server: [SCID:serverid, DCID:clientid TRANSPORT_PARAMS(origid=randomid)]
|
|
//
|
|
// - For the client, .base_cid={clientid} and .dest_cid={randomid} initially but gets updated to
|
|
// .dest_cid={serverid} when we hear back from the server.
|
|
// - For the server, .base_cid={serverid} and .dest_cid={clientid}
|
|
|
|
tparams.original_dcid = header.dcid;
|
|
|
|
LogDebug("original_dcid is now set to ", ConnectionID(tparams.original_dcid));
|
|
|
|
settings.token = header.token;
|
|
|
|
// FIXME is this required?
|
|
randombytes_buf(tparams.stateless_reset_token, sizeof(tparams.stateless_reset_token));
|
|
tparams.stateless_reset_token_present = 1;
|
|
|
|
ngtcp2_conn* connptr;
|
|
LogDebug("server_new, path=", path);
|
|
if (auto rv = ngtcp2_conn_server_new(
|
|
&connptr,
|
|
&dest_cid,
|
|
&base_cid,
|
|
path,
|
|
header.version,
|
|
&cb,
|
|
&settings,
|
|
&tparams,
|
|
nullptr /*default mem allocator*/,
|
|
this);
|
|
rv != 0)
|
|
throw std::runtime_error{"Failed to initialize server connection: "s + ngtcp2_strerror(rv)};
|
|
conn.reset(connptr);
|
|
|
|
LogDebug("Created new server conn ", base_cid);
|
|
}
|
|
|
|
Connection::Connection(
|
|
Client& c, const ConnectionID& scid, const Path& path, uint16_t tunnel_port)
|
|
: tunnel_port{tunnel_port}
|
|
, endpoint{c}
|
|
, base_cid{scid}
|
|
, dest_cid{ConnectionID::random()}
|
|
, path{path}
|
|
{
|
|
auto [settings, tparams, cb] = init();
|
|
|
|
assert(tunnel_port != 0);
|
|
|
|
cb.client_initial = client_initial;
|
|
cb.recv_retry = recv_retry;
|
|
// cb.extend_max_local_streams_bidi = extend_max_local_streams_bidi;
|
|
// cb.recv_new_token = recv_new_token;
|
|
|
|
ngtcp2_conn* connptr;
|
|
|
|
if (auto rv = ngtcp2_conn_client_new(
|
|
&connptr,
|
|
&dest_cid,
|
|
&scid,
|
|
path,
|
|
NGTCP2_PROTO_VER_V1,
|
|
&cb,
|
|
&settings,
|
|
&tparams,
|
|
nullptr,
|
|
this);
|
|
rv != 0)
|
|
throw std::runtime_error{"Failed to initialize client connection: "s + ngtcp2_strerror(rv)};
|
|
conn.reset(connptr);
|
|
|
|
LogDebug("Created new client conn ", scid);
|
|
}
|
|
|
|
// Closes the event loop handles created in init().  Both handles' callbacks capture `this`, so
// each must be shut down here; otherwise a pending retransmit timer (or io trigger) could fire
// after this object has been destroyed.
Connection::~Connection()
{
  if (io_trigger)
    io_trigger->close();
  if (retransmit_timer)
  {
    retransmit_timer->stop();
    retransmit_timer->close();
  }
}
|
|
|
|
void
|
|
Connection::io_ready()
|
|
{
|
|
io_trigger->send();
|
|
}
|
|
|
|
void
|
|
Connection::on_io_ready()
|
|
{
|
|
LogTrace(__func__);
|
|
flush_streams();
|
|
LogTrace("done ", __func__);
|
|
}
|
|
|
|
void
|
|
Connection::flush_streams()
|
|
{
|
|
// conn, path, pi, dest, destlen, and ts
|
|
std::optional<uint64_t> ts;
|
|
|
|
send_pkt_info = {};
|
|
|
|
auto add_stream_data =
|
|
[&](StreamID stream_id, const ngtcp2_vec* datav, size_t datalen, uint32_t flags = 0) {
|
|
std::array<ngtcp2_ssize, 2> result;
|
|
auto& [nwrite, consumed] = result;
|
|
if (!ts)
|
|
ts = get_timestamp();
|
|
|
|
LogTrace(
|
|
"send_buffer size=", send_buffer.size(), ", datalen=", datalen, ", flags=", flags);
|
|
nwrite = ngtcp2_conn_writev_stream(
|
|
conn.get(),
|
|
&path.path,
|
|
&send_pkt_info,
|
|
u8data(send_buffer),
|
|
send_buffer.size(),
|
|
&consumed,
|
|
NGTCP2_WRITE_STREAM_FLAG_MORE | flags,
|
|
stream_id.id,
|
|
datav,
|
|
datalen,
|
|
*ts);
|
|
return result;
|
|
};
|
|
|
|
auto send_packet = [&](auto nwrite) -> bool {
|
|
send_buffer_size = nwrite;
|
|
LogDebug("Sending ", send_buffer_size, "B packet");
|
|
|
|
// FIXME: update remote addr? ecn?
|
|
auto sent = send();
|
|
if (sent.blocked())
|
|
return false; // We'll get called again when the socket becomes writable
|
|
|
|
send_buffer_size = 0;
|
|
if (!sent)
|
|
{
|
|
LogWarn("I/O error while trying to send packet: ", sent.str());
|
|
// FIXME: disconnect?
|
|
return false;
|
|
}
|
|
LogDebug("packet away!");
|
|
return true;
|
|
};
|
|
|
|
std::list<Stream*> strs;
|
|
for (auto& [stream_id, stream_ptr] : streams)
|
|
if (stream_ptr)
|
|
strs.push_back(stream_ptr.get());
|
|
|
|
// Maximum number of stream data packets to send out at once; if we reach this then we'll
|
|
// schedule another event loop call of ourselves (so that we don't starve the loop).
|
|
constexpr int max_stream_packets = 15;
|
|
int stream_packets = 0;
|
|
while (!strs.empty() && stream_packets < max_stream_packets)
|
|
{
|
|
for (auto it = strs.begin(); it != strs.end();)
|
|
{
|
|
auto& stream = **it;
|
|
auto bufs = stream.pending();
|
|
if (stream.is_shutdown
|
|
|| (bufs.empty() && !stream.is_new && !(stream.is_closing && !stream.sent_fin)))
|
|
{
|
|
it = strs.erase(it);
|
|
continue;
|
|
}
|
|
std::vector<ngtcp2_vec> vecs;
|
|
vecs.reserve(bufs.size());
|
|
std::transform(bufs.begin(), bufs.end(), std::back_inserter(vecs), [](const auto& buf) {
|
|
return ngtcp2_vec{const_cast<uint8_t*>(u8data(buf)), buf.size()};
|
|
});
|
|
|
|
#ifndef NDEBUG
|
|
{
|
|
std::string buf_sizes;
|
|
for (auto& b : bufs)
|
|
{
|
|
if (!buf_sizes.empty())
|
|
buf_sizes += '+';
|
|
buf_sizes += std::to_string(b.size());
|
|
}
|
|
LogDebug("Sending ", buf_sizes.empty() ? "no" : buf_sizes, " data for ", stream.id());
|
|
}
|
|
#endif
|
|
|
|
uint32_t extra_flags = 0;
|
|
if (stream.is_closing && !stream.sent_fin)
|
|
{
|
|
LogDebug("Sending FIN");
|
|
extra_flags |= NGTCP2_WRITE_STREAM_FLAG_FIN;
|
|
stream.sent_fin = true;
|
|
}
|
|
else if (stream.is_new)
|
|
{
|
|
stream.is_new = false;
|
|
}
|
|
|
|
auto [nwrite, consumed] =
|
|
add_stream_data(stream.id(), vecs.data(), vecs.size(), extra_flags);
|
|
LogDebug(
|
|
"add_stream_data for stream ", stream.id(), " returned [", nwrite, ",", consumed, "]");
|
|
|
|
if (nwrite > 0)
|
|
{
|
|
if (consumed >= 0)
|
|
{
|
|
LogDebug("consumed ", consumed, " bytes from stream ", stream.id());
|
|
stream.wrote(consumed);
|
|
}
|
|
|
|
LogDebug("Sending stream data packet");
|
|
if (!send_packet(nwrite))
|
|
return;
|
|
++stream_packets;
|
|
++it;
|
|
continue;
|
|
}
|
|
|
|
switch (nwrite)
|
|
{
|
|
case 0:
|
|
LogDebug(
|
|
"Done stream writing to ",
|
|
stream.id(),
|
|
" (either stream is congested or we have nothing else to send right now)");
|
|
assert(consumed <= 0);
|
|
break;
|
|
case NGTCP2_ERR_WRITE_MORE:
|
|
LogDebug(
|
|
"consumed ", consumed, " bytes from stream ", stream.id(), " and have space left");
|
|
stream.wrote(consumed);
|
|
if (stream.unsent() > 0)
|
|
{
|
|
// We have more to send on this stream, so keep us in the queue
|
|
++it;
|
|
continue;
|
|
}
|
|
break;
|
|
case NGTCP2_ERR_STREAM_DATA_BLOCKED:
|
|
LogDebug("cannot add to stream ", stream.id(), " right now: stream is blocked");
|
|
break;
|
|
case NGTCP2_ERR_STREAM_SHUT_WR:
|
|
LogDebug("cannot write to ", stream.id(), ": stream is shut down");
|
|
break;
|
|
default:
|
|
assert(consumed <= 0);
|
|
LogWarn("Error writing to stream ", stream.id(), ": ", ngtcp2_strerror(nwrite));
|
|
break;
|
|
}
|
|
it = strs.erase(it);
|
|
}
|
|
}
|
|
|
|
// Now try more with stream id -1 and no data: this takes care of things like initial handshake
|
|
// packets, and also finishes off any partially-filled packet from above.
|
|
for (;;)
|
|
{
|
|
auto [nwrite, consumed] = add_stream_data(StreamID{}, nullptr, 0);
|
|
LogDebug("add_stream_data for non-stream returned [", nwrite, ",", consumed, "]");
|
|
assert(consumed <= 0);
|
|
if (nwrite == NGTCP2_ERR_WRITE_MORE)
|
|
{
|
|
LogDebug("Writing non-stream data, and have space left");
|
|
continue;
|
|
}
|
|
if (nwrite < 0)
|
|
{
|
|
LogWarn("Error writing non-stream data: ", ngtcp2_strerror(nwrite));
|
|
break;
|
|
}
|
|
if (nwrite == 0)
|
|
{
|
|
LogDebug("Nothing else to write for non-stream data for now (or we are congested)");
|
|
ngtcp2_conn_stat cstat;
|
|
ngtcp2_conn_get_conn_stat(*this, &cstat);
|
|
LogDebug("Current unacked bytes in flight: ", cstat.bytes_in_flight);
|
|
break;
|
|
}
|
|
|
|
LogDebug("Sending non-stream data packet");
|
|
if (!send_packet(nwrite))
|
|
return;
|
|
}
|
|
|
|
schedule_retransmit();
|
|
}
|
|
|
|
void
|
|
Connection::schedule_retransmit()
|
|
{
|
|
auto expiry = std::chrono::nanoseconds{ngtcp2_conn_get_expiry(*this)};
|
|
if (expiry < 0ns)
|
|
{
|
|
retransmit_timer->repeat(0ms);
|
|
return;
|
|
}
|
|
auto expires_in = std::chrono::duration_cast<std::chrono::milliseconds>(
|
|
expiry - get_time().time_since_epoch());
|
|
LogDebug("Next retransmit in ", expires_in.count(), "ms");
|
|
if (expires_in < 1ms)
|
|
expires_in = 1ms;
|
|
retransmit_timer->repeat(expires_in);
|
|
retransmit_timer->again();
|
|
}
|
|
|
|
int
|
|
Connection::stream_opened(StreamID id)
|
|
{
|
|
LogDebug("New stream ", id);
|
|
auto* serv = server();
|
|
if (!serv)
|
|
{
|
|
LogWarn("We are a client, incoming streams are not accepted");
|
|
return NGTCP2_ERR_CALLBACK_FAILURE;
|
|
}
|
|
|
|
std::shared_ptr<Stream> stream{new Stream{*this, id, endpoint.default_stream_buffer_size}};
|
|
stream->stream_id = id;
|
|
bool good = true;
|
|
if (serv->stream_open_callback)
|
|
good = serv->stream_open_callback(*stream, tunnel_port);
|
|
if (!good)
|
|
{
|
|
LogDebug("stream_open_callback returned failure, dropping stream ", id);
|
|
ngtcp2_conn_shutdown_stream(*this, id.id, 1);
|
|
io_ready();
|
|
return NGTCP2_ERR_CALLBACK_FAILURE;
|
|
}
|
|
|
|
[[maybe_unused]] auto [it, ins] = streams.emplace(id, std::move(stream));
|
|
assert(ins);
|
|
LogDebug("Created new incoming stream ", id);
|
|
return 0;
|
|
}
|
|
|
|
int
|
|
Connection::stream_receive(StreamID id, const bstring_view data, bool fin)
|
|
{
|
|
auto str = get_stream(id);
|
|
if (!str->data_callback)
|
|
LogDebug("Dropping incoming data on stream ", str->id(), ": stream has no data callback set");
|
|
else
|
|
{
|
|
bool good = false;
|
|
try
|
|
{
|
|
str->data_callback(*str, data);
|
|
good = true;
|
|
}
|
|
catch (const std::exception& e)
|
|
{
|
|
LogWarn(
|
|
"Stream ",
|
|
str->id(),
|
|
" data callback raised exception (",
|
|
e.what(),
|
|
"); closing stream with app code ",
|
|
STREAM_EXCEPTION_ERROR_CODE);
|
|
}
|
|
catch (...)
|
|
{
|
|
LogWarn(
|
|
"Stream ",
|
|
str->id(),
|
|
" data callback raised an unknown exception; closing stream with app code ",
|
|
STREAM_EXCEPTION_ERROR_CODE);
|
|
}
|
|
if (!good)
|
|
{
|
|
str->close(STREAM_EXCEPTION_ERROR_CODE);
|
|
return NGTCP2_ERR_CALLBACK_FAILURE;
|
|
}
|
|
}
|
|
if (fin)
|
|
{
|
|
if (str->close_callback)
|
|
str->close_callback(*str, std::nullopt);
|
|
streams.erase(id);
|
|
io_ready();
|
|
}
|
|
else
|
|
{
|
|
ngtcp2_conn_extend_max_stream_offset(*this, id.id, data.size());
|
|
ngtcp2_conn_extend_max_offset(*this, data.size());
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
int
|
|
Connection::stream_reset(StreamID id, uint64_t app_code)
|
|
{
|
|
LogDebug(id, " reset with code ", app_code);
|
|
auto it = streams.find(id);
|
|
if (it == streams.end())
|
|
return NGTCP2_ERR_CALLBACK_FAILURE;
|
|
auto& stream = *it->second;
|
|
const bool was_closing = stream.is_closing;
|
|
stream.is_closing = true;
|
|
if (!was_closing && stream.close_callback)
|
|
{
|
|
LogDebug("Invoke stream close callback");
|
|
stream.close_callback(stream, app_code);
|
|
}
|
|
|
|
streams.erase(it);
|
|
return 0;
|
|
}
|
|
|
|
int
|
|
Connection::stream_ack(StreamID id, size_t size)
|
|
{
|
|
if (auto it = streams.find(id); it != streams.end())
|
|
{
|
|
it->second->acknowledge(size);
|
|
return 0;
|
|
}
|
|
return NGTCP2_ERR_CALLBACK_FAILURE;
|
|
}
|
|
|
|
// Returns the owning endpoint as a Server, or nullptr if the endpoint is not a Server.
Server*
Connection::server()
{
  auto* as_server = dynamic_cast<Server*>(&endpoint);
  return as_server;
}
|
|
|
|
// Returns the owning endpoint as a Client, or nullptr if the endpoint is not a Client.
Client*
Connection::client()
{
  auto* as_client = dynamic_cast<Client*>(&endpoint);
  return as_client;
}
|
|
|
|
int
|
|
Connection::setup_server_crypto_initial()
|
|
{
|
|
auto* s = server();
|
|
assert(s);
|
|
s->null_crypto.server_initial(*this);
|
|
io_ready();
|
|
return 0;
|
|
}
|
|
|
|
// Asks the endpoint to register an additional connection ID of length `cidlen` for this
// connection and returns it.
ConnectionID
Connection::make_alias_id(size_t cidlen)
{
  auto alias = endpoint.add_connection_id(*this, cidlen);
  return alias;
}
|
|
|
|
bool
|
|
Connection::get_handshake_completed()
|
|
{
|
|
return ngtcp2_conn_get_handshake_completed(*this) != 0;
|
|
}
|
|
|
|
int
|
|
Connection::get_streams_available()
|
|
{
|
|
uint64_t left = ngtcp2_conn_get_streams_bidi_left(*this);
|
|
constexpr int max_int = std::numeric_limits<int>::max();
|
|
if (left > static_cast<uint64_t>(max_int))
|
|
return max_int;
|
|
return static_cast<int>(left);
|
|
}
|
|
|
|
const std::shared_ptr<Stream>&
|
|
Connection::open_stream(Stream::data_callback_t data_cb, Stream::close_callback_t close_cb)
|
|
{
|
|
std::shared_ptr<Stream> stream{new Stream{
|
|
*this, std::move(data_cb), std::move(close_cb), endpoint.default_stream_buffer_size}};
|
|
if (int rv = ngtcp2_conn_open_bidi_stream(*this, &stream->stream_id.id, stream.get()); rv != 0)
|
|
throw std::runtime_error{"Stream creation failed: "s + ngtcp2_strerror(rv)};
|
|
|
|
auto& str = streams[stream->stream_id];
|
|
str = std::move(stream);
|
|
|
|
return str;
|
|
}
|
|
|
|
const std::shared_ptr<Stream>&
|
|
Connection::get_stream(StreamID s) const
|
|
{
|
|
return streams.at(s);
|
|
}
|
|
|
|
int
|
|
Connection::init_client()
|
|
{
|
|
endpoint.null_crypto.client_initial(*this);
|
|
|
|
if (int rv = send_magic(NGTCP2_CRYPTO_LEVEL_INITIAL); rv != 0)
|
|
return rv;
|
|
if (int rv = send_transport_params(NGTCP2_CRYPTO_LEVEL_INITIAL); rv != 0)
|
|
return rv;
|
|
|
|
io_ready();
|
|
return 0;
|
|
}
|
|
|
|
int
|
|
Connection::recv_initial_crypto(std::basic_string_view<uint8_t> data)
|
|
{
|
|
if (data.substr(0, handshake_magic.size()) != handshake_magic)
|
|
{
|
|
LogWarn("Invalid initial crypto frame: did not find expected magic prefix");
|
|
return NGTCP2_ERR_CALLBACK_FAILURE;
|
|
}
|
|
data.remove_prefix(handshake_magic.size());
|
|
|
|
const bool is_server = ngtcp2_conn_is_server(*this);
|
|
if (is_server)
|
|
{
|
|
// For a server, we receive the transport parameters in the initial packet (prepended by the
|
|
// magic that we just removed):
|
|
if (auto rv = recv_transport_params(data); rv != 0)
|
|
return rv;
|
|
}
|
|
else
|
|
{
|
|
// For a client our initial crypto data should be just the magic string (the packet also
|
|
// contains transport parameters, but they are at HANDSHAKE crypto level and so will result
|
|
// in a second callback to handle them).
|
|
if (!data.empty())
|
|
{
|
|
LogWarn("Invalid initial crypto frame: unexpected post-magic data found");
|
|
return NGTCP2_ERR_CALLBACK_FAILURE;
|
|
}
|
|
}
|
|
|
|
endpoint.null_crypto.install_rx_handshake_key(*this);
|
|
endpoint.null_crypto.install_tx_handshake_key(*this);
|
|
if (is_server)
|
|
endpoint.null_crypto.install_tx_key(*this);
|
|
|
|
return 0;
|
|
}
|
|
|
|
void
|
|
Connection::complete_handshake()
|
|
{
|
|
endpoint.null_crypto.install_rx_key(*this);
|
|
if (!ngtcp2_conn_is_server(*this))
|
|
endpoint.null_crypto.install_tx_key(*this);
|
|
ngtcp2_conn_handshake_completed(*this);
|
|
|
|
if (on_handshake_complete)
|
|
{
|
|
on_handshake_complete(*this);
|
|
on_handshake_complete = nullptr;
|
|
}
|
|
}
|
|
|
|
// ngtcp2 doesn't expose the varint encoding, but it's fairly simple:
// 0bXXyyyyyy -- XX indicates the encoded size (00=1, 01=2, 10=4, 11=8) and the rest of the bits
// (6, 14, 30, or 62) are the number, with bytes in network order for >6-bit values.

// Decodes a QUIC variable-length integer from the front of `data`.
//
// Returns {value, consumed} where consumed is the number of bytes consumed, or {0, 0} on failure
// (empty input, or fewer bytes available than the length prefix calls for).
static constexpr std::pair<uint64_t, size_t>
decode_varint(std::basic_string_view<uint8_t> data)
{
  if (data.empty())
    return {0, 0};

  // The top two bits of the first byte are log2 of the encoded length.
  const size_t length = size_t{1} << (data[0] >> 6);
  if (data.size() < length)
    return {0, 0};

  // The remaining six bits of the first byte are the most significant bits of the value.
  uint64_t value = data[0] & 0x3f;
  for (size_t pos = 1; pos < length; ++pos)
    value = (value << 8) | data[pos];

  return {value, length};
}
|
|
|
|
// Encodes an unsigned integer in QUIC variable-length integer format using the shortest of the
// 1, 2, 4 or 8 byte forms.
//
// `val` must be less than 2^62 (asserted).  Returns the bytes and the encoded length; bytes
// beyond `length` are zero.
static constexpr std::pair<std::array<uint8_t, 8>, uint8_t>
encode_varint(uint64_t val)
{
  assert(val < (1ULL << 62));
  std::pair<std::array<uint8_t, 8>, uint8_t> result{};
  auto& [enc, len] = result;

  // log2 of the encoded length; this is also the 2-bit tag stored in the top of the first byte.
  const uint8_t size = val < (1ULL << 6) ? 0 : val < (1ULL << 14) ? 1 : val < (1ULL << 30) ? 2 : 3;
  len = static_cast<uint8_t>(1u << size);

  // Write the value in network (big-endian) order, least significant byte last.
  for (uint8_t i = 1; i <= len; i++)
  {
    enc[len - i] = static_cast<uint8_t>(val & 0xff);
    val >>= 8;
  }

  // Stamp the size tag into the two most significant bits of the first byte.
  enc[0] = static_cast<uint8_t>((enc[0] & 0b0011'1111) | (size << 6));
  return result;
}
|
|
|
|
// We add some lokinet-specific data into the transport request and *always* as the first
|
|
// transport parameter, but we do it in a way that the parameter gets ignored by the QUIC
|
|
// protocol, which encodes as {varint[code], varint[length], data}, and requires a code value
|
|
// 31*N+27 (for integer N). Naturally we use N=42, which gives us 1329=0b10100110001 which
|
|
// encodes in QUIC as 0b01000101 0b00110001 (the first two bits of the first byte give the integer
|
|
// size, and the rest are the value in network order).
|
|
static constexpr uint64_t lokinet_transport_param_N = 42;
|
|
static constexpr auto lokinet_metadata_code_raw =
|
|
encode_varint(31 * lokinet_transport_param_N + 27);
|
|
static constexpr std::basic_string_view<uint8_t> lokinet_metadata_code{
|
|
lokinet_metadata_code_raw.first.data(), lokinet_metadata_code_raw.second};
|
|
static_assert(
|
|
lokinet_metadata_code.size() == 2 && lokinet_metadata_code[0] == 0b01000101
|
|
&& lokinet_metadata_code[1] == 0b00110001);
|
|
|
|
int
|
|
Connection::recv_transport_params(std::basic_string_view<uint8_t> data)
|
|
{
|
|
if (data.substr(0, lokinet_metadata_code.size()) != lokinet_metadata_code)
|
|
{
|
|
LogWarn("transport params did not begin with expected lokinet metadata");
|
|
return NGTCP2_ERR_TRANSPORT_PARAM;
|
|
}
|
|
auto [meta_len, meta_len_bytes] = decode_varint(data.substr(lokinet_metadata_code.size()));
|
|
if (meta_len_bytes == 0)
|
|
{
|
|
LogWarn("transport params lokinet metadata has truncated size");
|
|
return NGTCP2_ERR_MALFORMED_TRANSPORT_PARAM;
|
|
}
|
|
std::string_view lokinet_metadata{
|
|
reinterpret_cast<const char*>(
|
|
data.substr(lokinet_metadata_code.size() + meta_len_bytes).data()),
|
|
meta_len};
|
|
LogDebug("Received bencoded lokinet metadata: ", buffer_printer{lokinet_metadata});
|
|
|
|
uint16_t port;
|
|
try
|
|
{
|
|
oxenmq::bt_dict_consumer meta{lokinet_metadata};
|
|
// '#' contains the port the client wants us to forward to
|
|
if (!meta.skip_until("#"))
|
|
{
|
|
LogWarn("transport params # (port) is missing but required");
|
|
return NGTCP2_ERR_TRANSPORT_PARAM;
|
|
}
|
|
port = meta.consume_integer<uint16_t>();
|
|
if (port == 0)
|
|
{
|
|
LogWarn("transport params tunnel port (#) is invalid: 0 is not permitted");
|
|
return NGTCP2_ERR_TRANSPORT_PARAM;
|
|
}
|
|
LogDebug("decoded lokinet tunnel port = ", port);
|
|
}
|
|
catch (const oxenmq::bt_deserialize_invalid& c)
|
|
{
|
|
LogWarn("transport params lokinet metadata is invalid: ", c.what());
|
|
return NGTCP2_ERR_TRANSPORT_PARAM;
|
|
}
|
|
|
|
const bool is_server = ngtcp2_conn_is_server(*this);
|
|
|
|
if (is_server)
|
|
{
|
|
tunnel_port = port;
|
|
}
|
|
else
|
|
{
|
|
// Make sure the server reflected the proper port
|
|
if (tunnel_port != port)
|
|
{
|
|
LogWarn("server returned invalid port; expected ", tunnel_port, ", got ", port);
|
|
return NGTCP2_ERR_TRANSPORT_PARAM;
|
|
}
|
|
}
|
|
|
|
ngtcp2_transport_params params;
|
|
|
|
auto exttype = is_server ? NGTCP2_TRANSPORT_PARAMS_TYPE_CLIENT_HELLO
|
|
: NGTCP2_TRANSPORT_PARAMS_TYPE_ENCRYPTED_EXTENSIONS;
|
|
|
|
auto rv = ngtcp2_decode_transport_params(¶ms, exttype, data.data(), data.size());
|
|
LogDebug("Decode transport params ", rv == 0 ? "success" : "fail: "s + ngtcp2_strerror(rv));
|
|
LogTrace("params orig dcid = ", ConnectionID(params.original_dcid));
|
|
LogTrace("params init scid = ", ConnectionID(params.initial_scid));
|
|
if (rv == 0)
|
|
{
|
|
rv = ngtcp2_conn_set_remote_transport_params(*this, ¶ms);
|
|
LogDebug(
|
|
"Set remote transport params ", rv == 0 ? "success" : "fail: "s + ngtcp2_strerror(rv));
|
|
}
|
|
|
|
if (rv != 0)
|
|
{
|
|
ngtcp2_conn_set_tls_error(*this, rv);
|
|
return rv;
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
// Sends our magic string at the given level. This fixed magic string is taking the place of TLS
|
|
// parameters in full QUIC.
|
|
int
|
|
Connection::send_magic(ngtcp2_crypto_level level)
|
|
{
|
|
return ngtcp2_conn_submit_crypto_data(
|
|
*this, level, handshake_magic.data(), handshake_magic.size());
|
|
}
|
|
|
|
// Appends the bytes of `s` at `buf` and moves `buf` forward to just past the copied bytes.
// `String` may be any string-like type with single-byte elements exposing data() and size().
// No bounds checking is performed here: the caller must ensure the destination has room.
template <typename String>
static void
copy_and_advance(uint8_t*& buf, const String& s)
{
  static_assert(sizeof(typename String::value_type) == 1, "not a char-compatible type");
  const auto nbytes = s.size();
  std::memcpy(buf, s.data(), nbytes);
  buf += nbytes;
}
|
|
|
|
// Sends transport parameters. `level` is expected to be INITIAL for clients (which send the
|
|
// transport parameters in the initial packet), or HANDSHAKE for servers.
|
|
int
|
|
Connection::send_transport_params(ngtcp2_crypto_level level)
|
|
{
|
|
ngtcp2_transport_params tparams;
|
|
ngtcp2_conn_get_local_transport_params(*this, &tparams);
|
|
|
|
assert(conn_buffer.empty());
|
|
static_assert(NGTCP2_MAX_PKTLEN_IPV4 > NGTCP2_MAX_PKTLEN_IPV6);
|
|
conn_buffer.resize(NGTCP2_MAX_PKTLEN_IPV4);
|
|
|
|
auto* buf = u8data(conn_buffer);
|
|
auto* bufend = buf + conn_buffer.size();
|
|
{
|
|
// Send our first parameter, the lokinet metadata, in a QUIC-compatible way (by using a
|
|
// reserved field code that QUIC parsers must ignore); currently we only include the port in
|
|
// here (from the client to tell the server what it's trying to reach, and reflected from
|
|
// the server for the client to verify).
|
|
std::string lokinet_metadata = bt_serialize(oxenmq::bt_dict{
|
|
{"#", tunnel_port},
|
|
});
|
|
copy_and_advance(buf, lokinet_metadata_code);
|
|
auto [bytes, size] = encode_varint(lokinet_metadata.size());
|
|
copy_and_advance(buf, std::basic_string_view{bytes.data(), size});
|
|
copy_and_advance(buf, lokinet_metadata);
|
|
assert(buf < bufend);
|
|
}
|
|
|
|
const bool is_server = ngtcp2_conn_is_server(*this);
|
|
auto exttype = is_server ? NGTCP2_TRANSPORT_PARAMS_TYPE_ENCRYPTED_EXTENSIONS
|
|
: NGTCP2_TRANSPORT_PARAMS_TYPE_CLIENT_HELLO;
|
|
|
|
if (ngtcp2_ssize nwrite = ngtcp2_encode_transport_params(buf, bufend - buf, exttype, &tparams);
|
|
nwrite >= 0)
|
|
{
|
|
assert(nwrite > 0);
|
|
conn_buffer.resize(buf - u8data(conn_buffer) + nwrite);
|
|
}
|
|
else
|
|
{
|
|
conn_buffer.clear();
|
|
return nwrite;
|
|
}
|
|
LogDebug("encoded transport params: ", buffer_printer{conn_buffer});
|
|
return ngtcp2_conn_submit_crypto_data(*this, level, u8data(conn_buffer), conn_buffer.size());
|
|
}
|
|
|
|
} // namespace llarp::quic
|