lokinet/llarp/quic/endpoint.cpp
2021-04-19 06:57:46 -04:00

524 lines
15 KiB
C++

#include "endpoint.hpp"
#include "client.hpp"
#include "server.hpp"
#include <llarp/crypto/crypto.hpp>
#include <llarp/util/logging/buffer.hpp>
#include <iostream>
#include <random>
#include <variant>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>
#include <uvw/timer.h>
#include <oxenmq/variant.h>
extern "C"
{
#include <sodium/crypto_generichash.h>
#include <sodium/randombytes.h>
}
namespace llarp::quic
{
/// Constructs the endpoint. It opens a non-blocking IPv4 UDP socket and binds it, enables
/// receipt of the IP_TOS byte (for ECN), wires the recv buffers into the msghdr structures, and
/// registers the socket poll and a 250ms expiry timer with the uv loop.
///
/// \param addr  local address to bind. When empty, the socket is bound to 0.0.0.0 with an
///              OS-chosen port, so `local` reports the real port immediately.
/// \param loop_ uv loop that drives socket polling and the expiry timer.
/// \throws std::runtime_error if the socket cannot be created, bound, queried, or configured.
///         Any failure after the socket exists closes it before throwing, so it is not leaked.
Endpoint::Endpoint(std::optional<Address> addr, std::shared_ptr<uvw::Loop> loop_)
    : loop{std::move(loop_)}
{
  randombytes_buf(static_secret.data(), static_secret.size());

  // Create and bind the UDP socket. We can't use libuv's UDP socket here because it doesn't
  // give us the ability to set up the ECN field as QUIC requires.
  auto fd = socket(AF_INET, SOCK_DGRAM | SOCK_NONBLOCK, 0);
  if (fd == -1)
    throw std::runtime_error{"Failed to open socket: "s + strerror(errno)};

  // Builds the exception for a failure that happens after the socket exists, closing the
  // socket first. The message (including any strerror(errno)) is an argument, so it is fully
  // evaluated before close() runs and can clobber errno.
  auto fail = [fd](std::string msg) {
    ::close(fd);
    return std::runtime_error{std::move(msg)};
  };

  if (addr)
  {
    assert(addr->sockaddr_size() == sizeof(sockaddr_in));  // FIXME: IPv4-only for now
    if (bind(fd, *addr, addr->sockaddr_size()) == -1)
      throw fail("Failed to bind UDP socket to " + addr->to_string() + ": " + strerror(errno));
  }
  else
  {
    // An unbound UDP socket is only assigned a port on its first send, so getsockname() below
    // would report port 0. An all-zero sockaddr_in is 0.0.0.0:0, so binding to it makes the
    // kernel pick the port now.
    sockaddr_in any_addr{};
    any_addr.sin_family = AF_INET;
    if (bind(fd, reinterpret_cast<const sockaddr*>(&any_addr), sizeof(any_addr)) == -1)
      throw fail("Failed to bind UDP socket to an auto-selected port: "s + strerror(errno));
  }

  // Get our address via the socket in case `addr` is using anyaddr/anyport.
  sockaddr_any sa;
  socklen_t salen = sizeof(sa);
  if (getsockname(fd, &sa.sa, &salen) == -1)
    throw fail("Failed to get local UDP socket address: "s + strerror(errno));
  assert(salen == sizeof(sockaddr_in));  // FIXME: IPv4-only for now
  local = {&sa, salen};
  LogDebug("Bound to ", local, addr ? "" : " (auto-selected)");

  // Set up the socket to provide us with incoming ECN (IP_TOS) info. Boolean socket options
  // take an int flag.
  // NB: This is for IPv4; on AF_INET6 this would be IPPROTO_IPV6, IPV6_RECVTCLASS
  if (int want_tos = 1;
      setsockopt(fd, IPPROTO_IP, IP_RECVTOS, &want_tos, static_cast<socklen_t>(sizeof(want_tos)))
      == -1)
    throw fail("Failed to set ECN on socket: "s + strerror(errno));

  // Wire up our recv buffer structures into what recvmmsg() wants
  buf.resize(max_buf_size * msgs.size());
  for (size_t i = 0; i < msgs.size(); i++)
  {
    auto& iov = msgs_iov[i];
    iov.iov_base = buf.data() + max_buf_size * i;
    iov.iov_len = max_buf_size;
#ifdef LOKINET_HAVE_RECVMMSG
    auto& mh = msgs[i].msg_hdr;
#else
    auto& mh = msgs[i];
#endif
    mh.msg_name = &msgs_addr[i];
    mh.msg_namelen = sizeof(msgs_addr[i]);
    mh.msg_iov = &iov;
    mh.msg_iovlen = 1;
    mh.msg_control = msgs_cmsg[i].data();
    mh.msg_controllen = msgs_cmsg[i].size();
  }

  // Let uv do its stuff
  poll = loop->resource<uvw::PollHandle>(fd);
  poll->on<uvw::PollEvent>([this](const auto&, auto&) { on_readable(); });
  poll->start(uvw::PollHandle::Event::READABLE);

  // Set up a callback every 250ms to clean up stale sockets, etc.
  expiry_timer = loop->resource<uvw::TimerHandle>();
  expiry_timer->on<uvw::TimerEvent>([this](const auto&, auto&) { check_timeouts(); });
  expiry_timer->start(250ms, 250ms);

  LogDebug("Created endpoint");
}
/// Stops polling the socket, stops the expiry timer, and closes the UDP socket.
///
/// A uv poll handle does not take ownership of the descriptor it watches, so closing the
/// handle alone would leak the socket opened in the constructor. libuv documents that the fd
/// may be closed immediately after uv_close() on the poll handle, which uvw's close() issues.
Endpoint::~Endpoint()
{
  if (poll)
  {
    // Read the descriptor before closing the handle; querying a closing handle fails.
    const int fd = socket_fd();
    poll->close();
    ::close(fd);
  }
  if (expiry_timer)
    expiry_timer->close();
}
/// Returns the OS descriptor of the UDP socket, as reported by the uv poll handle that watches
/// it. `poll` must be set; the constructor assigns it.
int
Endpoint::socket_fd() const
{
  auto& handle = *poll;
  return handle.fd();
}
void
Endpoint::on_readable()
{
LogDebug("poll callback on readable");
#ifdef LOKINET_HAVE_RECVMMSG
// NB: recvmmsg is linux-specific but ought to offer some performance benefits
int n_msg = recvmmsg(socket_fd(), msgs.data(), msgs.size(), 0, nullptr);
if (n_msg == -1)
{
if (errno != EAGAIN && errno != ENOTCONN)
LogWarn("Error recv'ing from ", local.to_string(), ": ", strerror(errno));
return;
}
LogDebug("Recv'd ", n_msg, " messages");
for (int i = 0; i < n_msg; i++)
{
auto& [msg_hdr, msg_len] = msgs[i];
bstring_view data{buf.data() + i * max_buf_size, msg_len};
#else
for (size_t i = 0; i < N_msgs; i++)
{
auto& msg_hdr = msgs[0];
auto n_bytes = recvmsg(socket_fd(), &msg_hdr, 0);
if (n_bytes == -1 && errno != EAGAIN && errno != ENOTCONN)
LogWarn("Error recv'ing from ", local.to_string(), ": ", strerror(errno));
if (n_bytes <= 0)
return;
auto msg_len = static_cast<unsigned int>(n_bytes);
bstring_view data{buf.data(), msg_len};
#endif
LogDebug(
"header [",
msg_hdr.msg_namelen,
"]: ",
buffer_printer{reinterpret_cast<char*>(msg_hdr.msg_name), msg_hdr.msg_namelen});
if (!msg_hdr.msg_name || msg_hdr.msg_namelen != sizeof(sockaddr_in))
{ // FIXME: IPv6 support?
LogWarn("Invalid/unknown source address, dropping packet");
continue;
}
Packet pkt{
Path{local, reinterpret_cast<const sockaddr_any*>(msg_hdr.msg_name), msg_hdr.msg_namelen},
data,
ngtcp2_pkt_info{.ecn = 0}};
// Go look for the ECN header field on the incoming packet
for (auto cmsg = CMSG_FIRSTHDR(&msg_hdr); cmsg; cmsg = CMSG_NXTHDR(&msg_hdr, cmsg))
{
// IPv4; for IPv6 these would be IPPROTO_IPV6 and IPV6_TCLASS
if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_TOS && cmsg->cmsg_len)
{
pkt.info.ecn = *reinterpret_cast<uint8_t*>(CMSG_DATA(cmsg));
}
}
LogDebug(
i,
"[",
pkt.path,
",ecn=0x",
std::hex,
+pkt.info.ecn,
std::dec,
"]: received ",
msg_len,
" bytes");
handle_packet(pkt);
LogDebug("Done handling packet");
#ifdef LOKINET_HAVE_RECVMMSG // Help editor's { } matching:
}
#else
}
#endif
}
/// Decodes the version and connection IDs from an incoming packet's header.
///
/// Returns the packet's destination connection ID on success. Returns nullopt in three cases:
/// ngtcp2 reports that version negotiation is needed (a Version Negotiation packet is sent to
/// the sender), the header fails to decode, or the destination ID is longer than ConnectionID
/// allows.
std::optional<ConnectionID>
Endpoint::handle_packet_init(const Packet& p)
{
  version_info vi;
  const auto decode_rv = ngtcp2_pkt_decode_version_cid(
      &vi.version,
      &vi.dcid,
      &vi.dcid_len,
      &vi.scid,
      &vi.scid_len,
      u8data(p.data),
      p.data.size(),
      NGTCP2_MAX_CIDLEN);

  if (decode_rv == 0)
  {
    if (vi.dcid_len <= ConnectionID::max_size())
      return std::make_optional<ConnectionID>(vi.dcid, vi.dcid_len);
    LogWarn("Internal error: destination ID is longer than should be allowed");
    return std::nullopt;
  }

  // 1 means Version Negotiation should be sent and otherwise the packet should be ignored
  if (decode_rv == 1)
    send_version_negotiation(vi, p.path.remote);
  else
    LogWarn("QUIC packet header decode failed: ", ngtcp2_strerror(decode_rv));
  return std::nullopt;
}
/// Handles a packet addressed to an existing connection.
///
/// - Closing period (per ngtcp2): the packet is dropped and close_connection() is called,
///   which (re)sends the connection close packet.
/// - Draining: the packet is silently dropped.
/// - Otherwise: the packet is fed to read_packet(). A failure is logged here; read_packet()
///   itself reacts to draining and drop-connection errors.
void
Endpoint::handle_conn_packet(Connection& conn, const Packet& p)
{
  if (ngtcp2_conn_is_in_closing_period(conn))
  {
    LogDebug("Connection is in closing period, dropping");
    close_connection(conn);
  }
  else if (conn.draining)
  {
    LogDebug("Connection is draining, dropping");
    // "draining" state means we received a connection close and we're keeping the
    // connection alive just to catch (and discard) straggling packets that arrive
    // out of order w.r.t to connection close.
  }
  else
  {
    auto result = read_packet(p, conn);
    if (!result)
      LogWarn("Read packet failed! ", ngtcp2_strerror(result.error_code));
    // FIXME - reset idle timer?
    LogDebug("Done with incoming packet");
  }
}
/// Feeds one received packet to ngtcp2 for `conn` and reacts to the result.
/// - Success marks the connection ready for I/O.
/// - NGTCP2_ERR_DRAINING puts it into draining mode.
/// - NGTCP2_ERR_DROP_CONN removes it from the endpoint. `conn` may be destroyed after this
///   returns if the endpoint held the last reference.
/// Every non-zero result is logged.
///
/// \returns an io_result carrying ngtcp2's return code (0 on success).
io_result
Endpoint::read_packet(const Packet& p, Connection& conn)
{
  LogDebug("Reading packet from ", p.path);
  const auto read_rv =
      ngtcp2_conn_read_pkt(conn, p.path, &p.info, u8data(p.data), p.data.size(), get_timestamp());
  if (read_rv != 0)
    LogWarn("read pkt error: ", ngtcp2_strerror(read_rv));

  switch (read_rv)
  {
    case 0:
      conn.io_ready();
      break;
    case NGTCP2_ERR_DRAINING:
      start_draining(conn);
      break;
    case NGTCP2_ERR_DROP_CONN:
      delete_conn(conn.base_cid);
      break;
    default:
      break;
  }
  return {read_rv};
}
/// Sets the socket's outgoing IP_TOS byte to `ecn`. The syscall is skipped when `ecn` equals
/// the last value requested (ecn_curr). A setsockopt failure is only logged; ecn_curr is
/// updated regardless.
void
Endpoint::update_ecn(uint32_t ecn)
{
  assert(ecn <= std::numeric_limits<uint8_t>::max());
  if (ecn == ecn_curr)
    return;

  const int set_rv =
      setsockopt(socket_fd(), IPPROTO_IP, IP_TOS, &ecn, static_cast<socklen_t>(sizeof(ecn)));
  if (set_rv == -1)
    LogWarn("setsockopt failed to set IP_TOS: ", strerror(errno));
  // IPv6 version:
  // int tclass = this->ecn;
  // setsockopt(socket_fd(), IPPROTO_IPV6, IPV6_TCLASS, &tclass,
  // static_cast<socklen_t>(sizeof(tclass)));
  ecn_curr = ecn;
}
/// Sends `data` as one UDP datagram to `to`, first setting the socket's IP_TOS byte to `ecn`
/// (see update_ecn). sendmsg() is retried while it fails with EINTR.
///
/// \returns a default io_result on success, or one carrying sendmsg()'s errno on failure.
io_result
Endpoint::send_packet(const Address& to, bstring_view data, uint32_t ecn)
{
  iovec msg_iov;
  msg_iov.iov_base = const_cast<std::byte*>(data.data());
  msg_iov.iov_len = data.size();
  msghdr msg{};
  // NOTE(review): reinterpreting the Address object itself as a sockaddr assumes its storage
  // begins with the socket address. Confirm against Address's definition.
  msg.msg_name = &const_cast<sockaddr&>(reinterpret_cast<const sockaddr&>(to));
  msg.msg_namelen = sizeof(sockaddr_in);  // FIXME: IPv4-only for now
  msg.msg_iov = &msg_iov;
  msg.msg_iovlen = 1;
  auto fd = socket_fd();
  update_ecn(ecn);
  ssize_t nwrite = 0;
  do
  {
    nwrite = sendmsg(fd, &msg, 0);
  } while (nwrite == -1 && errno == EINTR);
  if (nwrite == -1)
  {
    // Capture errno before logging. The logger may perform I/O that overwrites errno, which
    // would make us return the wrong error code.
    const int err = errno;
    LogWarn("sendmsg failed: ", strerror(err));
    return {err};
  }
  LogDebug(
      "[",
      to.to_string(),
      ",ecn=0x",
      std::hex,
      +ecn_curr,
      std::dec,
      "]: sent ",
      nwrite,
      " bytes");
  return {};
}
/// Replies to a packet carrying an unsupported QUIC version. It sends a Version Negotiation
/// packet to `source` listing the versions we support, plus one reserved "greasing" version.
///
/// RFC 9000 §17.2.1 requires the Version Negotiation packet's Destination Connection ID to
/// echo the Source Connection ID of the packet that triggered it. Its Source Connection ID must
/// echo that packet's Destination Connection ID. The IDs decoded into `vi` are therefore passed
/// to ngtcp2 swapped.
void
Endpoint::send_version_negotiation(const version_info& vi, const Address& source)
{
  std::array<std::byte, NGTCP2_MAX_PKTLEN_IPV4> buf;
  std::array<uint32_t, NGTCP2_PROTO_VER_MAX - NGTCP2_PROTO_VER_MIN + 2> versions;
  std::iota(versions.begin() + 1, versions.end(), NGTCP2_PROTO_VER_MIN);
  // A reserved 0x?a?a?a?a version (RFC 9000 §15) makes peers exercise their handling of
  // unknown versions.
  versions[0] = 0x1a2a3a4au;

  CSRNG rng{};
  // std::uniform_int_distribution is undefined for 8-bit integer types, so draw an unsigned
  // int in [0, 255] and narrow it.
  const auto unused_random =
      static_cast<uint8_t>(std::uniform_int_distribution<unsigned int>{0, 255}(rng));
  auto nwrote = ngtcp2_pkt_write_version_negotiation(
      u8data(buf),
      buf.size(),
      unused_random,
      vi.scid,  // our DCID = the received packet's SCID
      vi.scid_len,
      vi.dcid,  // our SCID = the received packet's DCID
      vi.dcid_len,
      versions.data(),
      versions.size());
  if (nwrote < 0)
    LogWarn("Failed to construct version negotiation packet: ", ngtcp2_strerror(nwrote));
  if (nwrote <= 0)
    return;

  send_packet(source, bstring_view{buf.data(), static_cast<size_t>(nwrote)}, 0);
}
/// Sends a connection close for `conn` and marks it as closing.
///
/// On the first call, ngtcp2 writes an application close (if `application`) or a connection
/// close carrying `code`. The packet is written into conn.conn_buffer and its path into
/// conn.path, and conn.closing is set. If ngtcp2 cannot write the packet, the failure is
/// logged and nothing is sent. Every call, including repeats, then sends the stored close
/// packet. If that send fails, the connection is removed from the endpoint.
void
Endpoint::close_connection(Connection& conn, uint64_t code, bool application)
{
  LogDebug("Closing connection ", conn.base_cid);
  if (!conn.closing)
  {
    conn.conn_buffer.resize(max_pkt_size_v4);
    Path close_path;
    ngtcp2_pkt_info close_pi;
    const auto writer =
        application ? ngtcp2_conn_write_application_close : ngtcp2_conn_write_connection_close;
    const auto written = writer(
        conn,
        close_path,
        &close_pi,
        u8data(conn.conn_buffer),
        conn.conn_buffer.size(),
        code,
        get_timestamp());
    if (written <= 0)
    {
      const char* reason =
          written < 0 ? ngtcp2_strerror(written) : "unknown error: closing is 0 bytes??";
      LogWarn("Failed to write connection close packet: ", reason);
      return;
    }
    assert(written <= static_cast<long>(conn.conn_buffer.size()));
    // Keep only the bytes ngtcp2 produced; later calls resend exactly this packet.
    conn.conn_buffer.resize(written);
    conn.closing = true;
    // FIXME: ipv6
    assert(close_path.local.sockaddr_size() == sizeof(sockaddr_in));
    assert(close_path.remote.sockaddr_size() == sizeof(sockaddr_in));
    conn.path = close_path;
  }

  assert(conn.closing && !conn.conn_buffer.empty());
  auto sent = send_packet(conn.path.remote, conn.conn_buffer, 0);
  if (!sent)
  {
    LogWarn(
        "Failed to send packet: ",
        strerror(sent.error_code),
        "; removing connection ",
        conn.base_cid);
    delete_conn(conn.base_cid);
  }
}
/// Moves `conn` into draining mode, normally after a connection close has been received. The
/// connection stays registered, and packets for it are dropped, until three probe timeouts
/// (PTO) from now. After that, check_timeouts() forgets it. This does nothing if `conn` is
/// already draining.
void
Endpoint::start_draining(Connection& conn)
{
  if (!conn.draining)
  {
    LogDebug("Putting ", conn.base_cid, " into draining mode");
    conn.draining = true;
    // Recommended draining time is 3*Probe Timeout (ngtcp2 reports the PTO in nanoseconds)
    const auto pto = ngtcp2_conn_get_pto(conn);
    draining.emplace(conn.base_cid, get_time() + pto * 3 * 1ns);
  }
}
/// Periodic housekeeping, run by expiry_timer. It does three things:
/// - erases connections whose draining deadline has passed;
/// - prunes alias IDs if a primary connection was erased;
/// - starts draining any primary connection whose ngtcp2 idle expiry is in the past.
void
Endpoint::check_timeouts()
{
  auto now = get_time();
  uint64_t now_ts = get_timestamp(now);

  // Forget connections that have finished draining. `draining` is consumed from the front and
  // the loop stops at the first unexpired entry.
  // NOTE(review): an entry with a later deadline at the front delays removal of any expired
  // entries queued behind it.
  bool removed_primary = false;
  for (; !draining.empty() && draining.front().second < now; draining.pop())
  {
    auto found = conns.find(draining.front().first);
    if (found == conns.end())
      continue;  // no longer registered
    if (std::holds_alternative<primary_conn_ptr>(found->second))
      removed_primary = true;
    LogDebug("Deleting connection ", found->first);
    conns.erase(found);
  }

  // Aliases that pointed at an erased primary connection can now be expired.
  if (removed_primary)
    clean_alias_conns();

  // Begin draining primary connections that have been idle past their expiry.
  for (auto& entry : conns)
  {
    auto* primary = std::get_if<primary_conn_ptr>(&entry.second);
    if (primary == nullptr)
      continue;
    Connection& conn = **primary;
    auto idle_expiry = ngtcp2_conn_get_idle_expiry(conn);
    if (idle_expiry < now_ts && !conn.draining)
      start_draining(conn);
  }
}
/// Looks up the connection registered under `cid`.
///
/// \returns
/// - {connection, true} when `cid` is an alias; the pointer is null if the connection has
///   since been destroyed.
/// - {connection, false} when `cid` is a primary ID.
/// - {nullptr, false} when `cid` is unknown.
std::pair<std::shared_ptr<Connection>, bool>
Endpoint::get_conn(const ConnectionID& cid)
{
  auto found = conns.find(cid);
  if (found == conns.end())
    return {nullptr, false};

  auto& entry = found->second;
  if (auto* alias = std::get_if<alias_conn_ptr>(&entry))
    return {alias->lock(), true};
  return {var::get<primary_conn_ptr>(entry), false};
}
/// Removes `cid` from the connection table. If it was a primary ID, stale alias entries are
/// pruned afterwards.
///
/// Callers may pass a reference into the Connection being removed (e.g. conn.base_cid). `cid`
/// is therefore only used before the erase.
///
/// \returns true if `cid` was found and removed, false otherwise.
bool
Endpoint::delete_conn(const ConnectionID& cid)
{
  if (auto it = conns.find(cid); it != conns.end())
  {
    const bool primary = std::holds_alternative<primary_conn_ptr>(it->second);
    LogDebug("Deleting ", primary ? "primary" : "alias", " connection ", cid);
    conns.erase(it);
    if (primary)
      clean_alias_conns();
    return true;
  }
  LogDebug("Cannot delete connection ", cid, ": cid not found");
  return false;
}
/// Erases every alias entry whose weak reference to its connection has expired. Primary
/// entries are left untouched.
void
Endpoint::clean_alias_conns()
{
  auto it = conns.begin();
  while (it != conns.end())
  {
    const auto* alias = std::get_if<alias_conn_ptr>(&it->second);
    const bool dead = alias != nullptr && alias->expired();
    if (dead)
      it = conns.erase(it);
    else
      ++it;
  }
}
/// Generates random connection IDs of length `cid_length` until one is not already
/// registered. That ID is registered as an alias (weak reference) of `conn` and returned.
ConnectionID
Endpoint::add_connection_id(Connection& conn, size_t cid_length)
{
  ConnectionID cid;
  do
  {
    cid = ConnectionID::random(cid_length);
  } while (!conns.emplace(cid, conn.weak_from_this()).second);

  LogDebug("Created cid ", cid, " alias for ", conn.base_cid);
  return cid;
}
void
Endpoint::make_stateless_reset_token(const ConnectionID& cid, unsigned char* dest)
{
crypto_generichash_state state;
crypto_generichash_init(&state, nullptr, 0, NGTCP2_STATELESS_RESET_TOKENLEN);
crypto_generichash_update(&state, u8data(static_secret), static_secret.size());
crypto_generichash_update(&state, cid.data, cid.datalen);
crypto_generichash_final(&state, dest, NGTCP2_STATELESS_RESET_TOKENLEN);
}
} // namespace llarp::quic