Simplify dll loading via static function pointers

- Replaces RAII handling of DLLs with global function pointers.  (We
  don't unload the dll this way, but that seems unnecessary anyway).
- Simplifies code by just needing to call an init function, but not
  needing to pass around an object holding the function pointers.
- Adds a templated dll loader that takes the dll and a list of
  name/pointer pairs to load the dll and set the pointers in one shot.
pull/1969/head
Jason Rhinelander 2 years ago
parent 281fbe57f7
commit 9921dd6c77
No known key found for this signature in database
GPG Key ID: C4992CE7A88D4262

@ -87,15 +87,11 @@ namespace llarp::win32
}
}
std::shared_ptr<WintunContext> _wintun;
WinDivert_API m_WinDivert{};
public:
VPNPlatform(const VPNPlatform&) = delete;
VPNPlatform(VPNPlatform&&) = delete;
VPNPlatform(llarp::Context* ctx) : Platform{}, _ctx{ctx}, _wintun{WintunContext_new()}
VPNPlatform(llarp::Context* ctx) : Platform{}, _ctx{ctx}
{}
virtual ~VPNPlatform() = default;
@ -169,7 +165,7 @@ namespace llarp::win32
std::shared_ptr<NetworkInterface>
ObtainInterface(InterfaceInfo info, AbstractRouter* router) override
{
return WintunInterface_new(_wintun, std::move(info), router);
return wintun::make_interface(std::move(info), router);
}
std::shared_ptr<I_Packet_IO>
@ -180,7 +176,7 @@ namespace llarp::win32
throw std::invalid_argument{
"cannot create packet io on explicitly specified interface, not currently supported on "
"windows (yet)"};
return m_WinDivert.make_intercepter(
return WinDivert::make_intercepter(
"outbound and ( udp.DstPort == 53 or tcp.DstPort == 53 )",
[router = _ctx->router] { router->TriggerPump(); });
}

@ -8,15 +8,17 @@ namespace llarp::win32
{
auto cat = log::Cat("win32-dll");
}
DLL::DLL(std::string dll) : m_Handle{LoadLibraryA(dll.c_str())}
{
if (not m_Handle)
throw win32::error{fmt::format("failed to load '{}'", dll)};
log::info(cat, "loaded '{}'", dll);
}
DLL::~DLL()
namespace detail
{
FreeLibrary(m_Handle);
}
HMODULE
load_dll(const std::string& dll)
{
auto handle = LoadLibraryExA(dll.c_str(), NULL, LOAD_LIBRARY_SEARCH_APPLICATION_DIR);
if (not handle)
throw win32::error{fmt::format("failed to load '{}'", dll)};
log::info(cat, "loaded '{}'", dll);
return handle;
}
} // namespace detail
} // namespace llarp::win32

@ -5,26 +5,30 @@
namespace llarp::win32
{
class DLL
namespace detail
{
const HMODULE m_Handle;
HMODULE
load_dll(const std::string& dll);
protected:
/// given a name of a function pointer find it and put it into `func`
/// throws if the function does not exist in the DLL we openned.
template <typename Func_t>
template <typename Func, typename... More>
void
init(std::string name, Func_t*& func)
load_funcs(HMODULE handle, const std::string& name, Func*& f, More&&... more)
{
auto ptr = GetProcAddress(m_Handle, name.c_str());
if (not ptr)
if (auto ptr = GetProcAddress(handle, name.c_str()))
f = reinterpret_cast<Func*>(ptr);
else
throw win32::error{fmt::format("function '{}' not found", name)};
func = reinterpret_cast<Func_t*>(ptr);
if constexpr (sizeof...(More) > 0)
load_funcs(handle, std::forward<More>(more)...);
}
} // namespace detail
public:
DLL(std::string dll);
virtual ~DLL();
};
// Loads a DLL and extracts function pointers from it. Takes the dll name and pairs of
// name/function pointer arguments. Throws on failure.
template <typename Func, typename... More>
void
load_dll_functions(const std::string& dll, const std::string& fname, Func*& f, More&&... funcs)
{
detail::load_funcs(detail::load_dll(dll), fname, f, std::forward<More>(funcs)...);
}
} // namespace llarp::win32

@ -19,183 +19,186 @@ namespace llarp::win32
{
auto cat = L::Cat("windivert");
}
using WD_Open_Func_t = decltype(&::WinDivertOpen);
using WD_Close_Func_t = decltype(&::WinDivertClose);
using WD_Shutdown_Func_t = decltype(&::WinDivertShutdown);
using WD_Send_Func_t = decltype(&::WinDivertSend);
using WD_Recv_Func_t = decltype(&::WinDivertRecv);
using WD_IP4_Format_Func_t = decltype(&::WinDivertHelperFormatIPv4Address);
using WD_IP6_Format_Func_t = decltype(&::WinDivertHelperFormatIPv6Address);
struct WinDivertDLL : DLL
namespace wd
{
WD_Open_Func_t open;
WD_Close_Func_t close;
WD_Shutdown_Func_t shutdown;
WD_Send_Func_t send;
WD_Recv_Func_t recv;
WD_IP4_Format_Func_t format_ip4;
WD_IP6_Format_Func_t format_ip6;
WinDivertDLL() : DLL{"WinDivert.dll"}
namespace
{
init("WinDivertOpen", open);
init("WinDivertClose", close);
init("WinDivertShutdown", shutdown);
init("WinDivertSend", send);
init("WinDivertRecv", recv);
init("WinDivertHelperFormatIPv4Address", format_ip4);
init("WinDivertHelperFormatIPv6Address", format_ip6);
L::debug(cat, "loaded windivert functions");
}
virtual ~WinDivertDLL() = default;
};
decltype(::WinDivertOpen)* open = nullptr;
decltype(::WinDivertClose)* close = nullptr;
decltype(::WinDivertShutdown)* shutdown = nullptr;
decltype(::WinDivertSend)* send = nullptr;
decltype(::WinDivertRecv)* recv = nullptr;
decltype(::WinDivertHelperFormatIPv4Address)* format_ip4 = nullptr;
decltype(::WinDivertHelperFormatIPv6Address)* format_ip6 = nullptr;
struct WD_Packet
{
std::vector<byte_t> pkt;
WINDIVERT_ADDRESS addr;
};
void
Initialize()
{
if (wd::open)
return;
// clang-format off
load_dll_functions(
"WinDivert.dll",
"WinDivertOpen", open,
"WinDivertClose", close,
"WinDivertShutdown", shutdown,
"WinDivertSend", send,
"WinDivertRecv", recv,
"WinDivertHelperFormatIPv4Address", format_ip4,
"WinDivertHelperFormatIPv6Address", format_ip6);
// clang-format on
}
} // namespace
class WinDivert_IO : public llarp::vpn::I_Packet_IO
{
const std::shared_ptr<WinDivertDLL> m_WinDivert;
std::function<void(void)> m_Wake;
HANDLE m_Handle;
std::thread m_Runner;
thread::Queue<WD_Packet> m_RecvQueue;
// dns packet queue size
static constexpr size_t recv_queue_size = 64;
public:
WinDivert_IO(
std::shared_ptr<WinDivertDLL> api, std::string filter_spec, std::function<void(void)> wake)
: m_WinDivert{api}, m_Wake{wake}, m_RecvQueue{recv_queue_size}
struct Packet
{
L::info(cat, "load windivert with filterspec: '{}'", filter_spec);
m_Handle = m_WinDivert->open(filter_spec.c_str(), WINDIVERT_LAYER_NETWORK, 0, 0);
if (auto err = GetLastError())
throw win32::error{err, "cannot open windivert handle"};
}
std::vector<byte_t> pkt;
WINDIVERT_ADDRESS addr;
};
~WinDivert_IO()
class IO : public llarp::vpn::I_Packet_IO
{
m_WinDivert->close(m_Handle);
}
std::function<void(void)> m_Wake;
std::optional<WD_Packet>
recv_packet() const
{
WINDIVERT_ADDRESS addr{};
std::vector<byte_t> pkt;
pkt.resize(1500); // net::IPPacket::MaxSize
UINT sz{};
if (not m_WinDivert->recv(m_Handle, pkt.data(), pkt.size(), &sz, &addr))
HANDLE m_Handle;
std::thread m_Runner;
thread::Queue<Packet> m_RecvQueue;
// dns packet queue size
static constexpr size_t recv_queue_size = 64;
public:
IO(std::string filter_spec, std::function<void(void)> wake)
: m_Wake{wake}, m_RecvQueue{recv_queue_size}
{
auto err = GetLastError();
if (err and err != ERROR_BROKEN_PIPE)
throw win32::error{
err, fmt::format("failed to receive packet from windivert (code={})", err)};
else if (err)
SetLastError(0);
return std::nullopt;
wd::Initialize();
L::info(cat, "load windivert with filterspec: '{}'", filter_spec);
m_Handle = wd::open(filter_spec.c_str(), WINDIVERT_LAYER_NETWORK, 0, 0);
if (auto err = GetLastError())
throw win32::error{err, "cannot open windivert handle"};
}
L::info(cat, "got packet of size {}B", sz);
pkt.resize(sz);
return WD_Packet{std::move(pkt), std::move(addr)};
}
void
send_packet(const WD_Packet& w_pkt) const
{
const auto& pkt = w_pkt.pkt;
const auto* addr = &w_pkt.addr;
L::info(cat, "send dns packet of size {}B", pkt.size());
UINT sz{};
if (m_WinDivert->send(m_Handle, pkt.data(), pkt.size(), &sz, addr))
return;
throw win32::error{"windivert send failed"};
}
~IO()
{
wd::close(m_Handle);
}
virtual int
PollFD() const
{
return -1;
}
std::optional<Packet>
recv_packet() const
{
WINDIVERT_ADDRESS addr{};
std::vector<byte_t> pkt;
pkt.resize(1500); // net::IPPacket::MaxSize
UINT sz{};
if (not wd::recv(m_Handle, pkt.data(), pkt.size(), &sz, &addr))
{
auto err = GetLastError();
if (err and err != ERROR_BROKEN_PIPE)
throw win32::error{
err, fmt::format("failed to receive packet from windivert (code={})", err)};
else if (err)
SetLastError(0);
return std::nullopt;
}
L::info(cat, "got packet of size {}B", sz);
pkt.resize(sz);
return Packet{std::move(pkt), std::move(addr)};
}
virtual bool WritePacket(net::IPPacket) override
{
return false;
}
void
send_packet(const Packet& w_pkt) const
{
const auto& pkt = w_pkt.pkt;
const auto* addr = &w_pkt.addr;
L::info(cat, "send dns packet of size {}B", pkt.size());
UINT sz{};
if (wd::send(m_Handle, pkt.data(), pkt.size(), &sz, addr))
return;
throw win32::error{"windivert send failed"};
}
virtual net::IPPacket
ReadNextPacket() override
{
auto w_pkt = m_RecvQueue.tryPopFront();
if (not w_pkt)
return net::IPPacket{};
net::IPPacket pkt{std::move(w_pkt->pkt)};
pkt.reply = [this, addr = std::move(w_pkt->addr)](auto pkt) {
send_packet(WD_Packet{pkt.steal(), addr});
};
return pkt;
}
virtual int
PollFD() const
{
return -1;
}
virtual void
Start() override
{
L::info(cat, "starting windivert");
if (m_Runner.joinable())
throw std::runtime_error{"windivert thread is already running"};
virtual bool
WritePacket(net::IPPacket) override
{
return false;
}
auto read_loop = [this]() {
log::info(cat, "windivert read loop start");
while (true)
{
// in the read loop, read packets until they stop coming in
// each packet is sent off
if (auto maybe_pkt = recv_packet())
virtual net::IPPacket
ReadNextPacket() override
{
auto w_pkt = m_RecvQueue.tryPopFront();
if (not w_pkt)
return net::IPPacket{};
net::IPPacket pkt{std::move(w_pkt->pkt)};
pkt.reply = [this, addr = std::move(w_pkt->addr)](auto pkt) {
send_packet(Packet{pkt.steal(), addr});
};
return pkt;
}
virtual void
Start() override
{
L::info(cat, "starting windivert");
if (m_Runner.joinable())
throw std::runtime_error{"windivert thread is already running"};
auto read_loop = [this]() {
log::info(cat, "windivert read loop start");
while (true)
{
m_RecvQueue.pushBack(std::move(*maybe_pkt));
// wake up event loop
m_Wake();
// in the read loop, read packets until they stop coming in
// each packet is sent off
if (auto maybe_pkt = recv_packet())
{
m_RecvQueue.pushBack(std::move(*maybe_pkt));
// wake up event loop
m_Wake();
}
else // leave loop on read fail
break;
}
else // leave loop on read fail
break;
}
log::info(cat, "windivert read loop end");
};
log::info(cat, "windivert read loop end");
};
m_Runner = std::thread{std::move(read_loop)};
}
m_Runner = std::thread{std::move(read_loop)};
}
virtual void
Stop() override
{
L::info(cat, "stopping windivert");
m_WinDivert->shutdown(m_Handle, WINDIVERT_SHUTDOWN_BOTH);
m_Runner.join();
}
};
virtual void
Stop() override
{
L::info(cat, "stopping windivert");
wd::shutdown(m_Handle, WINDIVERT_SHUTDOWN_BOTH);
m_Runner.join();
}
};
WinDivert_API::WinDivert_API() : m_Impl{std::make_shared<WinDivertDLL>()}
{}
} // namespace wd
std::string
WinDivert_API::format_ip(uint32_t ip) const
namespace WinDivert
{
std::array<char, 128> buf{};
m_Impl->format_ip4(ip, buf.data(), buf.size());
return buf.data();
}
std::string
format_ip(uint32_t ip)
{
std::array<char, 128> buf;
wd::format_ip4(ip, buf.data(), buf.size());
return buf.data();
}
std::shared_ptr<llarp::vpn::I_Packet_IO>
make_intercepter(std::string filter_spec, std::function<void(void)> wake)
{
return std::make_shared<wd::IO>(filter_spec, wake);
}
} // namespace WinDivert
std::shared_ptr<llarp::vpn::I_Packet_IO>
WinDivert_API::make_intercepter(std::string filter_spec, std::function<void(void)> wake) const
{
return std::make_shared<WinDivert_IO>(m_Impl, filter_spec, wake);
}
} // namespace llarp::win32

@ -4,27 +4,18 @@
#include <windows.h>
#include <llarp/vpn/i_packet_io.hpp>
namespace llarp::win32
namespace llarp::win32::WinDivert
{
struct WinDivertDLL;
/// format an ipv4 in host order to string such that a windivert filter spec can understand it
std::string
format_ip(uint32_t ip);
class WinDivert_API
{
std::shared_ptr<WinDivertDLL> m_Impl;
/// create a packet intercepter that uses windivert.
/// filter_spec describes the kind of traffic we wish to intercept.
/// pass in a callable that wakes up the main event loop.
/// we hide all implementation details from other compilation units to prevent issues with
/// linkage that may arrise.
std::shared_ptr<llarp::vpn::I_Packet_IO>
make_intercepter(std::string filter_spec, std::function<void(void)> wakeup);
public:
WinDivert_API();
/// format an ipv4 in host order to string such that a windivert filter spec can understand it
std::string
format_ip(uint32_t ip) const;
/// create a packet intercepter that uses windivert.
/// filter_spec describes the kind of traffic we wish to intercept.
/// pass in a callable that wakes up the main event loop.
/// we hide all implementation details from other compilation units to prevent issues with
/// linkage that may arrise.
std::shared_ptr<llarp::vpn::I_Packet_IO>
make_intercepter(std::string filter_spec, std::function<void(void)> wakeup) const;
};
} // namespace llarp::win32
} // namespace llarp::win32::WinDivert

@ -22,78 +22,74 @@ namespace llarp::win32
{
auto logcat = log::Cat("wintun");
constexpr auto PoolName = "lokinet";
} // namespace
using Adapter_ptr = std::shared_ptr<_WINTUN_ADAPTER>;
WINTUN_CREATE_ADAPTER_FUNC* create_adapter = nullptr;
WINTUN_CLOSE_ADAPTER_FUNC* close_adapter = nullptr;
WINTUN_OPEN_ADAPTER_FUNC* open_adapter = nullptr;
WINTUN_GET_ADAPTER_LUID_FUNC* get_adapter_LUID = nullptr;
WINTUN_GET_RUNNING_DRIVER_VERSION_FUNC* get_version = nullptr;
WINTUN_DELETE_DRIVER_FUNC* delete_driver = nullptr;
WINTUN_SET_LOGGER_FUNC* set_logger = nullptr;
WINTUN_START_SESSION_FUNC* start_session = nullptr;
WINTUN_END_SESSION_FUNC* end_session = nullptr;
WINTUN_GET_READ_WAIT_EVENT_FUNC* get_adapter_handle = nullptr;
WINTUN_RECEIVE_PACKET_FUNC* read_packet = nullptr;
WINTUN_RELEASE_RECEIVE_PACKET_FUNC* release_read = nullptr;
WINTUN_ALLOCATE_SEND_PACKET_FUNC* alloc_write = nullptr;
WINTUN_SEND_PACKET_FUNC* send_packet = nullptr;
struct PacketWrapper
{
BYTE* data;
DWORD size;
WINTUN_SESSION_HANDLE session;
WINTUN_RELEASE_RECEIVE_PACKET_FUNC* release;
/// copy our data into an ip packet struct
net::IPPacket
copy() const
void
WintunInitialize()
{
net::IPPacket pkt{size};
std::copy_n(data, size, pkt.data());
return pkt;
}
if (create_adapter)
return;
~PacketWrapper()
{
release(session, data);
// clang-format off
load_dll_functions(
"wintun.dll",
"WintunCreateAdapter", create_adapter,
"WintunCloseAdapter", close_adapter,
"WintunOpenAdapter", open_adapter,
"WintunGetAdapterLUID", get_adapter_LUID,
"WintunGetRunningDriverVersion", get_version,
"WintunDeleteDriver", delete_driver,
"WintunSetLogger", set_logger,
"WintunStartSession", start_session,
"WintunEndSession", end_session,
"WintunGetReadWaitEvent", get_adapter_handle,
"WintunReceivePacket", read_packet,
"WintunReleaseReceivePacket", release_read,
"WintunAllocateSendPacket", alloc_write,
"WintunSendPacket", send_packet);
// clang-format on
}
};
class WintunDLL : public DLL
{
public:
WINTUN_CREATE_ADAPTER_FUNC* create_adapter;
WINTUN_OPEN_ADAPTER_FUNC* open_adapter;
WINTUN_CLOSE_ADAPTER_FUNC* close_adapter;
WINTUN_START_SESSION_FUNC* start_session;
WINTUN_END_SESSION_FUNC* end_session;
WINTUN_GET_ADAPTER_LUID_FUNC* get_adapter_LUID;
WINTUN_GET_READ_WAIT_EVENT_FUNC* get_adapter_handle;
WINTUN_RECEIVE_PACKET_FUNC* read_packet;
WINTUN_RELEASE_RECEIVE_PACKET_FUNC* release_read;
WINTUN_ALLOCATE_SEND_PACKET_FUNC* alloc_write;
WINTUN_SEND_PACKET_FUNC* write_packet;
using Adapter_ptr = std::shared_ptr<_WINTUN_ADAPTER>;
WINTUN_SET_LOGGER_FUNC* set_logger;
WINTUN_GET_RUNNING_DRIVER_VERSION_FUNC* get_version;
/// read out all the wintun function pointers from a library handle
WintunDLL() : DLL{"wintun.dll"}
struct PacketWrapper
{
init("WintunGetRunningDriverVersion", get_version);
init("WintunCreateAdapter", create_adapter);
init("WintunOpenAdapter", open_adapter);
init("WintunCloseAdapter", close_adapter);
init("WintunStartSession", start_session);
init("WintunEndSession", end_session);
init("WintunGetAdapterLUID", get_adapter_LUID);
init("WintunReceivePacket", read_packet);
init("WintunReleaseReceivePacket", release_read);
init("WintunSendPacket", write_packet);
init("WintunAllocateSendPacket", alloc_write);
init("WintunSetLogger", set_logger);
init("WintunGetReadWaitEvent", get_adapter_handle);
if (auto wintun_ver = get_version())
log::info(logcat, fmt::format("wintun version {0:x} loaded", wintun_ver));
else
throw win32::error{"Failed to load wintun"};
}
BYTE* data;
DWORD size;
WINTUN_SESSION_HANDLE session;
/// copy our data into an ip packet struct
net::IPPacket
copy() const
{
net::IPPacket pkt{size};
std::copy_n(data, size, pkt.data());
return pkt;
}
~PacketWrapper()
{
release_read(session, data);
}
};
/// autovivify a wintun adapter handle
[[nodiscard]] auto
make_adapter(std::string adapter_name, std::string tunnel_name) const
make_adapter(std::string adapter_name, std::string tunnel_name)
{
auto adapter_name_wide = to_wide(adapter_name);
if (auto _impl = open_adapter(adapter_name_wide.c_str()))
@ -112,303 +108,286 @@ namespace llarp::win32
log::info(logcat, "creating adapter: '{}' on pool '{}'", adapter_name, tunnel_name);
auto tunnel_name_wide = to_wide(tunnel_name);
if (auto _impl = create_adapter(adapter_name_wide.c_str(), tunnel_name_wide.c_str(), &guid))
{
if (auto v = get_version())
log::info(logcat, "created adapter (wintun v{}.{})", (v >> 16) & 0xff, v & 0xff);
else
log::warning(
logcat,
"failed to query wintun driver version: {}!",
error_to_string(GetLastError()));
return _impl;
throw win32::error{"failed to create wintun adapter"};
}
};
class WintunAdapter
{
WINTUN_CLOSE_ADAPTER_FUNC* _close_adapter;
WINTUN_GET_ADAPTER_LUID_FUNC* _get_adapter_LUID;
WINTUN_GET_READ_WAIT_EVENT_FUNC* _get_handle;
WINTUN_START_SESSION_FUNC* _start_session;
WINTUN_END_SESSION_FUNC* _end_session;
WINTUN_ADAPTER_HANDLE _handle;
}
[[nodiscard]] auto
get_adapter_LUID() const
{
NET_LUID _uid{};
_get_adapter_LUID(_handle, &_uid);
return _uid;
throw win32::error{"failed to create wintun adapter"};
}
public:
WintunAdapter(const WintunDLL& dll, std::string name)
: _close_adapter{dll.close_adapter}
, _get_adapter_LUID{dll.get_adapter_LUID}
, _get_handle{dll.get_adapter_handle}
, _start_session{dll.start_session}
, _end_session{dll.end_session}
class WintunAdapter
{
_handle = dll.make_adapter(std::move(name), PoolName);
if (_handle == nullptr)
throw std::runtime_error{"failed to create wintun adapter"};
}
WINTUN_ADAPTER_HANDLE _handle;
/// put adapter up
void
Up(const vpn::InterfaceInfo& info) const
{
const auto luid = get_adapter_LUID();
for (const auto& addr : info.addrs)
[[nodiscard]] auto
GetAdapterLUID() const
{
// TODO: implement ipv6
if (addr.fam != AF_INET)
continue;
MIB_UNICASTIPADDRESS_ROW AddressRow;
InitializeUnicastIpAddressEntry(&AddressRow);
AddressRow.InterfaceLuid = luid;
AddressRow.Address.Ipv4.sin_family = AF_INET;
AddressRow.Address.Ipv4.sin_addr.S_un.S_addr = ToNet(net::TruncateV6(addr.range.addr)).n;
AddressRow.OnLinkPrefixLength = addr.range.HostmaskBits();
AddressRow.DadState = IpDadStatePreferred;
if (auto err = CreateUnicastIpAddressEntry(&AddressRow); err != ERROR_SUCCESS)
throw error{err, fmt::format("cannot set address '{}'", addr.range)};
log::info(logcat, "added address: '{}'", addr.range);
NET_LUID _uid{};
get_adapter_LUID(_handle, &_uid);
return _uid;
}
}
/// put adapter down and close it
void
Down() const
{
_close_adapter(_handle);
}
public:
explicit WintunAdapter(std::string name)
{
_handle = make_adapter(std::move(name), PoolName);
if (_handle == nullptr)
throw std::runtime_error{"failed to create wintun adapter"};
}
/// auto vivify a wintun session handle and read handle off of our adapter
[[nodiscard]] std::pair<WINTUN_SESSION_HANDLE, HANDLE>
session() const
{
if (auto impl = _start_session(_handle, WINTUN_MAX_RING_CAPACITY))
/// put adapter up
void
Up(const vpn::InterfaceInfo& info) const
{
if (auto handle = _get_handle(impl))
return {impl, handle};
_end_session(impl);
const auto luid = GetAdapterLUID();
for (const auto& addr : info.addrs)
{
// TODO: implement ipv6
if (addr.fam != AF_INET)
continue;
MIB_UNICASTIPADDRESS_ROW AddressRow;
InitializeUnicastIpAddressEntry(&AddressRow);
AddressRow.InterfaceLuid = luid;
AddressRow.Address.Ipv4.sin_family = AF_INET;
AddressRow.Address.Ipv4.sin_addr.S_un.S_addr = ToNet(net::TruncateV6(addr.range.addr)).n;
AddressRow.OnLinkPrefixLength = addr.range.HostmaskBits();
AddressRow.DadState = IpDadStatePreferred;
if (auto err = CreateUnicastIpAddressEntry(&AddressRow); err != ERROR_SUCCESS)
throw win32::error{err, fmt::format("cannot set address '{}'", addr.range)};
LogDebug(fmt::format("added address: '{}'", addr.range));
}
}
return {nullptr, nullptr};
}
};
class WintunSession
{
WINTUN_END_SESSION_FUNC* _end_session;
WINTUN_RECEIVE_PACKET_FUNC* _recv_pkt;
WINTUN_RELEASE_RECEIVE_PACKET_FUNC* _release_pkt;
WINTUN_ALLOCATE_SEND_PACKET_FUNC* _alloc_write;
WINTUN_SEND_PACKET_FUNC* _write_pkt;
WINTUN_SESSION_HANDLE _impl;
HANDLE _handle;
public:
WintunSession(const WintunDLL& dll)
: _end_session{dll.end_session}
, _recv_pkt{dll.read_packet}
, _release_pkt{dll.release_read}
, _alloc_write{dll.alloc_write}
, _write_pkt{dll.write_packet}
, _impl{nullptr}
, _handle{nullptr}
{}
/// put adapter down and close it
void
Down() const
{
close_adapter(_handle);
}
void
Start(const std::shared_ptr<WintunAdapter>& adapter)
{
if (auto [impl, handle] = adapter->session(); impl and handle)
/// auto vivify a wintun session handle and read handle off of our adapter
[[nodiscard]] std::pair<WINTUN_SESSION_HANDLE, HANDLE>
session() const
{
_impl = impl;
_handle = handle;
return;
if (auto wintun_ver = get_version())
log::info(
logcat,
fmt::format(
"wintun version {}.{} loaded", (wintun_ver >> 16) & 0xff, wintun_ver & 0xff));
else
throw win32::error{"Failed to load wintun"};
if (auto impl = start_session(_handle, WINTUN_MAX_RING_CAPACITY))
{
if (auto handle = get_adapter_handle(impl))
return {impl, handle};
end_session(impl);
}
return {nullptr, nullptr};
}
throw error{GetLastError(), "could not create wintun session"};
}
};
void
Stop() const
class WintunSession
{
_end_session(_impl);
}
WINTUN_SESSION_HANDLE _impl;
HANDLE _handle;
void
WaitFor(std::chrono::milliseconds dur)
{
WaitForSingleObject(_handle, dur.count());
}
public:
WintunSession() : _impl{nullptr}, _handle{nullptr}
{}
/// read a unique pointer holding a packet read from wintun, returns the packet if we read one
/// and a bool, set to true if our adapter is now closed
[[nodiscard]] std::pair<std::unique_ptr<PacketWrapper>, bool>
ReadPacket() const
{
// typedef so the return statement fits on 1 line :^D
using Pkt_ptr = std::unique_ptr<PacketWrapper>;
DWORD sz;
if (auto* ptr = _recv_pkt(_impl, &sz))
return {Pkt_ptr{new PacketWrapper{ptr, sz, _impl, _release_pkt}}, false};
const auto err = GetLastError();
if (err == ERROR_NO_MORE_ITEMS or err == ERROR_HANDLE_EOF)
void
Start(const std::shared_ptr<WintunAdapter>& adapter)
{
SetLastError(0);
return {nullptr, err == ERROR_HANDLE_EOF};
if (auto [impl, handle] = adapter->session(); impl and handle)
{
_impl = impl;
_handle = handle;
return;
}
throw error{GetLastError(), "could not create wintun session"};
}
throw error{err, "failed to read packet"};
}
/// write an ip packet to the interface, return 2 bools, first is did we write the packet,
/// second if we are terminating
std::pair<bool, bool>
WritePacket(net::IPPacket pkt) const
{
if (auto* buf = _alloc_write(_impl, pkt.size()))
void
Stop() const
{
std::copy_n(pkt.data(), pkt.size(), buf);
_write_pkt(_impl, buf);
return {true, false};
end_session(_impl);
}
const auto err = GetLastError();
if (err == ERROR_BUFFER_OVERFLOW or err == ERROR_HANDLE_EOF)
void
WaitFor(std::chrono::milliseconds dur)
{
SetLastError(0);
return {err != ERROR_BUFFER_OVERFLOW, err == ERROR_HANDLE_EOF};
WaitForSingleObject(_handle, dur.count());
}
throw error{err, "failed to write packet"};
}
};
class WintunInterface : public vpn::NetworkInterface
{
AbstractRouter* const _router;
std::shared_ptr<WintunAdapter> _adapter;
std::shared_ptr<WintunSession> _session;
thread::Queue<net::IPPacket> _recv_queue;
thread::Queue<net::IPPacket> _send_queue;
std::thread _recv_thread;
std::thread _send_thread;
static inline constexpr size_t packet_queue_length = 1024;
public:
WintunInterface(const WintunDLL& dll, vpn::InterfaceInfo info, AbstractRouter* router)
: vpn::NetworkInterface{std::move(info)}
, _router{router}
, _adapter{std::make_shared<WintunAdapter>(dll, m_Info.ifname)}
, _session{std::make_shared<WintunSession>(dll)}
, _recv_queue{packet_queue_length}
, _send_queue{packet_queue_length}
{}
/// read a unique pointer holding a packet read from wintun, returns the packet if we read
/// one and a bool, set to true if our adapter is now closed
[[nodiscard]] std::pair<std::unique_ptr<PacketWrapper>, bool>
ReadPacket() const
{
// typedef so the return statement fits on 1 line :^D
using Pkt_ptr = std::unique_ptr<PacketWrapper>;
DWORD sz;
if (auto* ptr = read_packet(_impl, &sz))
return {Pkt_ptr{new PacketWrapper{ptr, sz, _impl}}, false};
const auto err = GetLastError();
if (err == ERROR_NO_MORE_ITEMS or err == ERROR_HANDLE_EOF)
{
SetLastError(0);
return {nullptr, err == ERROR_HANDLE_EOF};
}
throw error{err, "failed to read packet"};
}
void
Start() override
{
m_Info.index = 0;
// put the adapter and set addresses
_adapter->Up(m_Info);
// start up io session
_session->Start(_adapter);
// start read packet loop
_recv_thread = std::thread{[session = _session, this]() {
do
/// write an ip packet to the interface, return 2 bools, first is did we write the packet,
/// second if we are terminating
std::pair<bool, bool>
WritePacket(net::IPPacket pkt) const
{
if (auto* buf = alloc_write(_impl, pkt.size()))
{
// read all our packets this iteration
bool more{true};
do
{
auto [pkt, done] = session->ReadPacket();
// bail if we are closing
if (done)
return;
if (pkt)
_recv_queue.pushBack(pkt->copy());
else
more = false;
} while (more);
// wait for more packets
session->WaitFor(5s);
} while (true);
}};
// start write packet loop
_send_thread = std::thread{[this, session = _session]() {
do
std::copy_n(pkt.data(), pkt.size(), buf);
send_packet(_impl, buf);
return {true, false};
}
const auto err = GetLastError();
if (err == ERROR_BUFFER_OVERFLOW or err == ERROR_HANDLE_EOF)
{
if (auto maybe = _send_queue.popFrontWithTimeout(100ms))
{
auto [written, done] = session->WritePacket(std::move(*maybe));
if (done)
return;
}
} while (_send_queue.enabled());
}};
}
SetLastError(0);
return {err != ERROR_BUFFER_OVERFLOW, err == ERROR_HANDLE_EOF};
}
throw error{err, "failed to write packet"};
}
};
void
Stop() override
class WintunInterface : public vpn::NetworkInterface
{
// end writing packets
_send_queue.disable();
_send_thread.join();
// end reading packets
_session->Stop();
_recv_thread.join();
// close session
_session.reset();
// put adapter down
_adapter->Down();
_adapter.reset();
}
AbstractRouter* const _router;
std::shared_ptr<WintunAdapter> _adapter;
std::shared_ptr<WintunSession> _session;
thread::Queue<net::IPPacket> _recv_queue;
thread::Queue<net::IPPacket> _send_queue;
std::thread _recv_thread;
std::thread _send_thread;
static inline constexpr size_t packet_queue_length = 1024;
public:
WintunInterface(vpn::InterfaceInfo info, AbstractRouter* router)
: vpn::NetworkInterface{std::move(info)}
, _router{router}
, _adapter{std::make_shared<WintunAdapter>(m_Info.ifname)}
, _session{std::make_shared<WintunSession>()}
, _recv_queue{packet_queue_length}
, _send_queue{packet_queue_length}
{}
void
Start() override
{
m_Info.index = 0;
// put the adapter and set addresses
_adapter->Up(m_Info);
// start up io session
_session->Start(_adapter);
// start read packet loop
_recv_thread = std::thread{[session = _session, this]() {
do
{
// read all our packets this iteration
bool more{true};
do
{
auto [pkt, done] = session->ReadPacket();
// bail if we are closing
if (done)
return;
if (pkt)
_recv_queue.pushBack(pkt->copy());
else
more = false;
} while (more);
// wait for more packets
session->WaitFor(5s);
} while (true);
}};
// start write packet loop
_send_thread = std::thread{[this, session = _session]() {
do
{
if (auto maybe = _send_queue.popFrontWithTimeout(100ms))
{
auto [written, done] = session->WritePacket(std::move(*maybe));
if (done)
return;
}
} while (_send_queue.enabled());
}};
}
net::IPPacket
ReadNextPacket() override
{
net::IPPacket pkt{};
if (auto maybe_pkt = _recv_queue.tryPopFront())
pkt = std::move(*maybe_pkt);
return pkt;
}
void
Stop() override
{
// end writing packets
_send_queue.disable();
_send_thread.join();
// end reading packets
_session->Stop();
_recv_thread.join();
// close session
_session.reset();
// put adapter down
_adapter->Down();
_adapter.reset();
}
bool
WritePacket(net::IPPacket pkt) override
{
return _send_queue.tryPushBack(std::move(pkt)) == thread::QueueReturn::Success;
}
net::IPPacket
ReadNextPacket() override
{
net::IPPacket pkt{};
if (auto maybe_pkt = _recv_queue.tryPopFront())
pkt = std::move(*maybe_pkt);
return pkt;
}
int
PollFD() const override
{
return -1;
}
bool
WritePacket(net::IPPacket pkt) override
{
return _send_queue.tryPushBack(std::move(pkt)) == thread::QueueReturn::Success;
}
void
MaybeWakeUpperLayers() const override
{
_router->TriggerPump();
}
};
int
PollFD() const override
{
return -1;
}
struct WintunContext
{
WintunDLL dll{};
};
void
MaybeWakeUpperLayers() const override
{
_router->TriggerPump();
}
};
} // namespace
std::shared_ptr<WintunContext>
WintunContext_new()
namespace wintun
{
return std::make_shared<WintunContext>();
}
std::shared_ptr<vpn::NetworkInterface>
WintunInterface_new(
std::shared_ptr<llarp::win32::WintunContext> const& ctx,
const llarp::vpn::InterfaceInfo& info,
llarp::AbstractRouter* r)
{
return std::static_pointer_cast<vpn::NetworkInterface>(
std::make_shared<WintunInterface>(ctx->dll, info, r));
}
std::shared_ptr<vpn::NetworkInterface>
make_interface(const llarp::vpn::InterfaceInfo& info, llarp::AbstractRouter* r)
{
WintunInitialize();
return std::static_pointer_cast<vpn::NetworkInterface>(
std::make_shared<WintunInterface>(info, r));
}
} // namespace wintun
} // namespace llarp::win32

@ -2,8 +2,6 @@
#include <memory>
#include <windows.h>
namespace llarp
{
struct AbstractRouter;
@ -15,19 +13,11 @@ namespace llarp::vpn
class NetworkInterface;
} // namespace llarp::vpn
namespace llarp::win32
namespace llarp::win32::wintun
{
/// holds all wintun implementation, including function pointers we fetch out of the wintun
/// library forward declared to hide wintun from other compilation units
class WintunContext;
std::shared_ptr<WintunContext>
WintunContext_new();
/// makes a new vpn interface with a wintun context given info and a router pointer
std::shared_ptr<vpn::NetworkInterface>
WintunInterface_new(
const std::shared_ptr<WintunContext>&,
make_interface(
const vpn::InterfaceInfo& info,
AbstractRouter* router);

Loading…
Cancel
Save