clean up address/ip management code to use std::variant and std::optional

pull/1576/head
Jeff Becker 3 years ago
parent 1885b1cae9
commit fc9b09bdbc
No known key found for this signature in database
GPG Key ID: F357B3B42F6F9B05

@ -41,11 +41,16 @@ namespace llarp
void void
SendPacketToRemote(const llarp_buffer_t&) override{}; SendPacketToRemote(const llarp_buffer_t&) override{};
huint128_t huint128_t ObtainIPForAddr(std::variant<service::Address, RouterID>) override
ObtainIPForAddr(const AlignedBuffer<32>&, bool) override
{ {
return {0}; return {0};
} }
std::optional<std::variant<service::Address, RouterID>> ObtainAddrForIP(
huint128_t) const override
{
return std::nullopt;
}
}; };
} // namespace handlers } // namespace handlers
} // namespace llarp } // namespace llarp

@ -1,5 +1,6 @@
#include <algorithm> #include <algorithm>
#include <llarp/net/net.hpp> #include <llarp/net/net.hpp>
#include <variant>
// harmless on other platforms // harmless on other platforms
#define __USE_MINGW_ANSI_STDIO 1 #define __USE_MINGW_ANSI_STDIO 1
#include "tun.hpp" #include "tun.hpp"
@ -297,32 +298,6 @@ namespace llarp
return msg.questions[0].IsLocalhost(); return msg.questions[0].IsLocalhost();
} }
template <>
bool
TunEndpoint::FindAddrForIP(service::Address& addr, huint128_t ip)
{
auto itr = m_IPToAddr.find(ip);
if (itr != m_IPToAddr.end() and not m_SNodes[itr->second])
{
addr = service::Address(itr->second.as_array());
return true;
}
return false;
}
template <>
bool
TunEndpoint::FindAddrForIP(RouterID& addr, huint128_t ip)
{
auto itr = m_IPToAddr.find(ip);
if (itr != m_IPToAddr.end() and m_SNodes[itr->second])
{
addr = RouterID(itr->second.as_array());
return true;
}
return false;
}
static dns::Message& static dns::Message&
clear_dns_message(dns::Message& msg) clear_dns_message(dns::Message& msg)
{ {
@ -333,13 +308,25 @@ namespace llarp
return msg; return msg;
} }
std::optional<std::variant<service::Address, RouterID>>
TunEndpoint::ObtainAddrForIP(huint128_t ip) const
{
auto itr = m_IPToAddr.find(ip);
if (itr == m_IPToAddr.end())
return std::nullopt;
if (m_SNodes.at(itr->second))
return RouterID{itr->second.as_array()};
else
return service::Address{itr->second.as_array()};
}
bool bool
TunEndpoint::HandleHookedDNSMessage(dns::Message msg, std::function<void(dns::Message)> reply) TunEndpoint::HandleHookedDNSMessage(dns::Message msg, std::function<void(dns::Message)> reply)
{ {
auto ReplyToSNodeDNSWhenReady = [self = this, reply = reply]( auto ReplyToSNodeDNSWhenReady = [self = this, reply = reply](
RouterID snode, auto msg, bool isV6) -> bool { RouterID snode, auto msg, bool isV6) -> bool {
return self->EnsurePathToSNode(snode, [=](const RouterID&, exit::BaseSession_ptr s) { return self->EnsurePathToSNode(snode, [=](const RouterID&, exit::BaseSession_ptr s) {
self->SendDNSReply(snode, s, msg, reply, true, isV6); self->SendDNSReply(snode, s, msg, reply, isV6);
}); });
}; };
auto ReplyToLokiDNSWhenReady = [self = this, reply = reply]( auto ReplyToLokiDNSWhenReady = [self = this, reply = reply](
@ -349,7 +336,7 @@ namespace llarp
return self->EnsurePathToService( return self->EnsurePathToService(
addr, addr,
[=](const Address&, OutboundContext* ctx) { [=](const Address&, OutboundContext* ctx) {
self->SendDNSReply(addr, ctx, msg, reply, false, isV6); self->SendDNSReply(addr, ctx, msg, reply, isV6);
}, },
2s); 2s);
}; };
@ -666,17 +653,10 @@ namespace llarp
reply(msg); reply(msg);
return true; return true;
} }
RouterID snodeAddr;
if (FindAddrForIP(snodeAddr, ip)) if (auto maybe = ObtainAddrForIP(ip))
{
msg.AddAReply(snodeAddr.ToString());
reply(msg);
return true;
}
service::Address lokiAddr;
if (FindAddrForIP(lokiAddr, ip))
{ {
msg.AddAReply(lokiAddr.ToString()); std::visit([&msg](auto&& result) { msg.AddAReply(result.ToString()); }, *maybe);
reply(msg); reply(msg);
return true; return true;
} }
@ -1043,9 +1023,12 @@ namespace llarp
if (t != service::ProtocolType::TrafficV4 && t != service::ProtocolType::TrafficV6 if (t != service::ProtocolType::TrafficV4 && t != service::ProtocolType::TrafficV6
&& t != service::ProtocolType::Exit) && t != service::ProtocolType::Exit)
return false; return false;
AlignedBuffer<32> addr; std::variant<service::Address, RouterID> addr;
bool snode = false; if (auto maybe = GetEndpointWithConvoTag(tag))
if (!GetEndpointWithConvoTag(tag, addr, snode)) {
addr = *maybe;
}
else
return false; return false;
huint128_t src, dst; huint128_t src, dst;
@ -1056,7 +1039,7 @@ namespace llarp
if (m_state->m_ExitEnabled) if (m_state->m_ExitEnabled)
{ {
// exit side from exit // exit side from exit
src = ObtainIPForAddr(addr, snode); src = ObtainIPForAddr(addr);
if (t == service::ProtocolType::Exit) if (t == service::ProtocolType::Exit)
{ {
if (pkt.IsV4()) if (pkt.IsV4())
@ -1088,16 +1071,22 @@ namespace llarp
} }
// find what exit we think this should be for // find what exit we think this should be for
const auto mapped = m_ExitMap.FindAll(src); const auto mapped = m_ExitMap.FindAll(src);
if (mapped.count(service::Address{addr}) == 0 or IsBogon(src)) if (IsBogon(src))
{
// we got exit traffic from someone who we should not have gotten it from
return false; return false;
if (const auto ptr = std::get_if<service::Address>(&addr))
{
if (mapped.count(*ptr) == 0)
{
// we got exit traffic from someone who we should not have gotten it from
return false;
}
} }
} }
else else
{ {
// snapp traffic // snapp traffic
src = ObtainIPForAddr(addr, snode); src = ObtainIPForAddr(addr);
dst = m_OurIP; dst = m_OurIP;
} }
HandleWriteIPPacket(buf, src, dst, seqno); HandleWriteIPPacket(buf, src, dst, seqno);
@ -1136,10 +1125,20 @@ namespace llarp
} }
huint128_t huint128_t
TunEndpoint::ObtainIPForAddr(const AlignedBuffer<32>& ident, bool snode) TunEndpoint::ObtainIPForAddr(std::variant<service::Address, RouterID> addr)
{ {
llarp_time_t now = Now(); llarp_time_t now = Now();
huint128_t nextIP = {0}; huint128_t nextIP = {0};
AlignedBuffer<32> ident{};
bool snode = false;
std::visit([&ident](auto&& val) { ident = val.data(); }, addr);
if (std::get_if<RouterID>(&addr))
{
snode = true;
}
{ {
// previously allocated address // previously allocated address
auto itr = m_AddrToIP.find(ident); auto itr = m_AddrToIP.find(ident);

@ -13,6 +13,8 @@
#include <future> #include <future>
#include <queue> #include <queue>
#include <type_traits>
#include <variant>
namespace llarp namespace llarp
{ {
@ -121,23 +123,8 @@ namespace llarp
HasLocalIP(const huint128_t& ip) const; HasLocalIP(const huint128_t& ip) const;
/// get a key for ip address /// get a key for ip address
template <typename Addr_t> std::optional<std::variant<service::Address, RouterID>>
Addr_t ObtainAddrForIP(huint128_t ip) const override;
ObtainAddrForIP(huint128_t ip, bool isSNode)
{
Addr_t addr;
auto itr = m_IPToAddr.find(ip);
if (itr != m_IPToAddr.end() and m_SNodes[itr->second] == isSNode)
{
addr = Addr_t(itr->second);
}
// found
return addr;
}
template <typename Addr_t>
bool
FindAddrForIP(Addr_t& addr, huint128_t ip);
bool bool
HasAddress(const AlignedBuffer<32>& addr) const HasAddress(const AlignedBuffer<32>& addr) const
@ -147,7 +134,7 @@ namespace llarp
/// get ip address for key unconditionally /// get ip address for key unconditionally
huint128_t huint128_t
ObtainIPForAddr(const AlignedBuffer<32>& addr, bool serviceNode) override; ObtainIPForAddr(std::variant<service::Address, RouterID> addr) override;
/// flush network traffic /// flush network traffic
void void
@ -214,12 +201,11 @@ namespace llarp
Endpoint_t ctx, Endpoint_t ctx,
std::shared_ptr<dns::Message> query, std::shared_ptr<dns::Message> query,
std::function<void(dns::Message)> reply, std::function<void(dns::Message)> reply,
bool snode,
bool sendIPv6) bool sendIPv6)
{ {
if (ctx) if (ctx)
{ {
huint128_t ip = ObtainIPForAddr(addr, snode); huint128_t ip = ObtainIPForAddr(addr);
query->answers.clear(); query->answers.clear();
query->AddINReply(ip, sendIPv6); query->AddINReply(ip, sendIPv6);
} }

@ -31,6 +31,7 @@
#include <llarp/tooling/dht_event.hpp> #include <llarp/tooling/dht_event.hpp>
#include <llarp/quic/server.hpp> #include <llarp/quic/server.hpp>
#include <optional>
#include <utility> #include <utility>
#include <llarp/quic/server.hpp> #include <llarp/quic/server.hpp>
@ -210,29 +211,23 @@ namespace llarp
return routers.find(remote) != routers.end(); return routers.find(remote) != routers.end();
} }
bool std::optional<std::variant<Address, RouterID>>
Endpoint::GetEndpointWithConvoTag( Endpoint::GetEndpointWithConvoTag(ConvoTag tag) const
const ConvoTag tag, llarp::AlignedBuffer<32>& addr, bool& snode) const
{ {
auto itr = Sessions().find(tag); auto itr = Sessions().find(tag);
if (itr != Sessions().end()) if (itr != Sessions().end())
{ {
snode = false; return itr->second.remote.Addr();
addr = itr->second.remote.Addr();
return true;
} }
for (const auto& item : m_state->m_SNodeSessions) for (const auto& item : m_state->m_SNodeSessions)
{ {
if (item.second.second == tag) if (item.second.second == tag)
{ {
snode = true; return item.first;
addr = item.first;
return true;
} }
} }
return std::nullopt;
return false;
} }
bool bool
@ -1345,7 +1340,7 @@ namespace llarp
// some day :DDDDD // some day :DDDDD
tag.Randomize(); tag.Randomize();
const auto src = xhtonl(net::TruncateV6(GetIfAddr())); const auto src = xhtonl(net::TruncateV6(GetIfAddr()));
const auto dst = xhtonl(net::TruncateV6(ObtainIPForAddr(snode, true))); const auto dst = xhtonl(net::TruncateV6(ObtainIPForAddr(snode)));
auto session = std::make_shared<exit::SNodeSession>( auto session = std::make_shared<exit::SNodeSession>(
snode, snode,

@ -17,6 +17,7 @@
#include "lookup.hpp" #include "lookup.hpp"
#include <llarp/hook/ihook.hpp> #include <llarp/hook/ihook.hpp>
#include <llarp/util/compare_ptr.hpp> #include <llarp/util/compare_ptr.hpp>
#include <optional>
#include <unordered_map> #include <unordered_map>
#include "endpoint_types.hpp" #include "endpoint_types.hpp"
@ -178,8 +179,11 @@ namespace llarp
void void
SetAuthInfoForEndpoint(Address remote, AuthInfo info); SetAuthInfoForEndpoint(Address remote, AuthInfo info);
virtual huint128_t virtual huint128_t ObtainIPForAddr(std::variant<Address, RouterID>) = 0;
ObtainIPForAddr(const AlignedBuffer<32>& addr, bool serviceNode) = 0;
/// get a key for ip address
virtual std::optional<std::variant<service::Address, RouterID>>
ObtainAddrForIP(huint128_t ip) const = 0;
// virtual bool // virtual bool
// HasServiceAddress(const AlignedBuffer< 32 >& addr) const = 0; // HasServiceAddress(const AlignedBuffer< 32 >& addr) const = 0;
@ -273,13 +277,9 @@ namespace llarp
void void
BlacklistSNode(const RouterID snode) override; BlacklistSNode(const RouterID snode) override;
/// return true if we have a convotag as an exit session /// maybe get an endpoint variant given its convo tag
/// or as a hidden service session std::optional<std::variant<Address, RouterID>>
/// set addr and issnode GetEndpointWithConvoTag(ConvoTag t) const;
///
/// return false if we don't have either
bool
GetEndpointWithConvoTag(const ConvoTag t, AlignedBuffer<32>& addr, bool& issnode) const;
bool bool
HasConvoTag(const ConvoTag& t) const override; HasConvoTag(const ConvoTag& t) const override;

@ -30,16 +30,20 @@ namespace llarp
{ {
if (handlePacket) if (handlePacket)
{ {
AlignedBuffer<32> addr; service::Address addr{};
bool isSnode = false; if (auto maybe = GetEndpointWithConvoTag(tag))
if (not GetEndpointWithConvoTag(tag, addr, isSnode)) {
if (auto ptr = std::get_if<service::Address>(&*maybe))
addr = *ptr;
else
return false;
}
else
return false; return false;
if (isSnode)
return true;
std::vector<byte_t> pkt; std::vector<byte_t> pkt;
pkt.resize(pktbuf.sz); pkt.resize(pktbuf.sz);
std::copy_n(pktbuf.base, pktbuf.sz, pkt.data()); std::copy_n(pktbuf.base, pktbuf.sz, pkt.data());
handlePacket(service::Address(addr), std::move(pkt), proto); handlePacket(addr, std::move(pkt), proto);
} }
return true; return true;
} }
@ -56,12 +60,17 @@ namespace llarp
return false; return false;
} }
llarp::huint128_t llarp::huint128_t ObtainIPForAddr(std::variant<service::Address, RouterID>) override
ObtainIPForAddr(const llarp::AlignedBuffer<32>&, bool) override
{ {
return {0}; return {0};
} }
std::optional<std::variant<service::Address, RouterID>> ObtainAddrForIP(
huint128_t) const override
{
return std::nullopt;
}
std::string std::string
GetIfName() const override GetIfName() const override
{ {

Loading…
Cancel
Save