diff --git a/libi2pd/KadDHT.cpp b/libi2pd/KadDHT.cpp index c3905b7a..48486675 100644 --- a/libi2pd/KadDHT.cpp +++ b/libi2pd/KadDHT.cpp @@ -14,7 +14,7 @@ namespace i2p namespace data { DHTNode::DHTNode (): - zero (nullptr), one (nullptr), hash (nullptr) + zero (nullptr), one (nullptr) { } @@ -22,17 +22,16 @@ namespace data { if (zero) delete zero; if (one) delete one; - if (hash) delete hash; } - void DHTNode::MoveHashUp (bool fromOne) + void DHTNode::MoveRouterUp (bool fromOne) { DHTNode *& side = fromOne ? one : zero; if (side) { - if (hash) delete hash; // shouldn't happen - hash = side->hash; - side->hash = nullptr; + if (router) router = nullptr; // shouldn't happen + router = side->router; + side->router = nullptr; delete side; side = nullptr; } @@ -49,38 +48,46 @@ namespace data delete m_Root; } - DHTNode * DHTTable::Insert (const IdentHash& h) + void DHTTable::Clear () { - return Insert (new IdentHash (h), m_Root, 0); + m_Size = 0; + delete m_Root; + m_Root = new DHTNode; + } + + void DHTTable::Insert (const std::shared_ptr& r) + { + if (!r) return; + return Insert (r, m_Root, 0); } - DHTNode * DHTTable::Insert (IdentHash * h, DHTNode * root, int level) + void DHTTable::Insert (const std::shared_ptr& r, DHTNode * root, int level) { - if (root->hash) + if (root->router) { - if (*(root->hash) == *h) + if (root->router->GetIdentHash () == r->GetIdentHash ()) { - delete h; - return root; + root->router = r; // replace + return; } - auto h2 = root->hash; - root->hash = nullptr; m_Size--; + auto r2 = root->router; + root->router = nullptr; m_Size--; int bit1, bit2; do { - bit1 = h->GetBit (level); - bit2 = h2->GetBit (level); + bit1 = r->GetIdentHash ().GetBit (level); + bit2 = r2->GetIdentHash ().GetBit (level); if (bit1 == bit2) { if (bit1) { - if (root->one) return nullptr; // someting wrong + if (root->one) return; // someting wrong root->one = new DHTNode; root = root->one; } else { - if (root->zero) return nullptr; // someting wrong + if (root->zero) return; // someting wrong root->zero = new DHTNode; root = root->zero; } @@ -95,37 +102,36 @@ namespace data root->one = new DHTNode; if (bit1) { - Insert (h2, root->zero, level + 1); - return Insert (h, root->one, level + 1); + Insert (r2, root->zero, level + 1); + Insert (r, root->one, level + 1); } else { - Insert (h2, root->one, level + 1); - return Insert (h, root->zero, level + 1); + Insert (r2, root->one, level + 1); + Insert (r, root->zero, level + 1); } } else { if (!root->zero && !root->one) { - root->hash = h; m_Size++; - return root; + root->router = r; m_Size++; + return; } - int bit = h->GetBit (level); + int bit = r->GetIdentHash ().GetBit (level); if (bit) { if (!root->one) root->one = new DHTNode; - return Insert (h, root->one, level + 1); + Insert (r, root->one, level + 1); } else { if (!root->zero) root->zero = new DHTNode; - return Insert (h, root->zero, level + 1); + Insert (r, root->zero, level + 1); } } - return nullptr; } bool DHTTable::Remove (const IdentHash& h) @@ -137,9 +143,9 @@ namespace data { if (root) { - if (root->hash && *(root->hash) == h) + if (root->router && root->router->GetIdentHash () == h) { - delete root->hash; root->hash = nullptr; + root->router = nullptr; m_Size--; return true; } @@ -152,11 +158,11 @@ namespace data { delete root->one; root->one = nullptr; - if (root->zero && root->zero->hash) - root->MoveHashUp (false); + if (root->zero && root->zero->router) + root->MoveRouterUp (false); } - else if (root->one->hash && !root->zero) - root->MoveHashUp (true); + else if (root->one->router && !root->zero) + root->MoveRouterUp (true); return true; } } @@ -168,11 +174,11 @@ namespace data { delete root->zero; root->zero = nullptr; - if (root->one && root->one->hash) - root->MoveHashUp (true); + if (root->one && root->one->router) + root->MoveRouterUp (true); } - else if (root->zero->hash && !root->one) - root->MoveHashUp (false); + else if (root->zero->router && !root->one) + root->MoveRouterUp (false); return true; } } @@ -180,48 +186,95 @@ namespace data return false; } - IdentHash * DHTTable::FindClosest (const IdentHash& h) + std::shared_ptr DHTTable::FindClosest (const IdentHash& h, const Filter& filter) { - return FindClosest (h, m_Root, 0); + if (filter) m_Filter = filter; + auto r = FindClosest (h, m_Root, 0); + m_Filter = nullptr; + return r; } - IdentHash * DHTTable::FindClosest (const IdentHash& h, DHTNode * root, int level) + std::shared_ptr DHTTable::FindClosest (const IdentHash& h, DHTNode * root, int level) { - if (root->hash) return root->hash; + bool split = false; + do + { + if (root->router) + return (!m_Filter || m_Filter (root->router)) ? root->router : nullptr; + split = root->zero && root->one; + if (!split) + { + if (root->zero) root = root->zero; + else if (root->one) root = root->one; + else return nullptr; + level++; + } + } + while (!split); int bit = h.GetBit (level); if (bit) { if (root->one) - return FindClosest (h, root->one, level + 1); + { + auto r = FindClosest (h, root->one, level + 1); + if (r) return r; + } if (root->zero) - return FindClosest (h, root->zero, level + 1); + { + auto r = FindClosest (h, root->zero, level + 1); + if (r) return r; + } } else { if (root->zero) - return FindClosest (h, root->zero, level + 1); + { + auto r = FindClosest (h, root->zero, level + 1); + if (r) return r; + } if (root->one) - return FindClosest (h, root->one, level + 1); + { + auto r = FindClosest (h, root->one, level + 1); + if (r) return r; + } } return nullptr; } - std::vector DHTTable::FindClosest (const IdentHash& h, size_t num) + std::vector > DHTTable::FindClosest (const IdentHash& h, size_t num, const Filter& filter) { - std::vector vec; + std::vector > vec; if (num > 0) + { + if (filter) m_Filter = filter; FindClosest (h, num, m_Root, 0, vec); + m_Filter = nullptr; + } return vec; } - void DHTTable::FindClosest (const IdentHash& h, size_t num, DHTNode * root, int level, std::vector& hashes) + void DHTTable::FindClosest (const IdentHash& h, size_t num, DHTNode * root, int level, std::vector >& hashes) { if (hashes.size () >= num) return; - if (root->hash) + bool split = false; + do { - hashes.push_back (root->hash); - return; - } + if (root->router) + { + if (!m_Filter || m_Filter (root->router)) + hashes.push_back (root->router); + return; + } + split = root->zero && root->one; + if (!split) + { + if (root->zero) root = root->zero; + else if (root->one) root = root->one; + else return; + level++; + } + } + while (!split); int bit = h.GetBit (level); if (bit) { @@ -238,6 +291,54 @@ namespace data FindClosest (h, num, root->one, level + 1, hashes); } } + + void DHTTable::Cleanup (Filter filter) + { + if (filter) + { + m_Filter = filter; + Cleanup (m_Root); + m_Filter = nullptr; + } + else + Clear (); + } + + void DHTTable::Cleanup (DHTNode * root) + { + if (!root) return; + if (root->router) + { + if (!m_Filter || !m_Filter (root->router)) + { + m_Size--; + root->router = nullptr; + } + return; + } + if (root->zero) + { + Cleanup (root->zero); + if (root->zero->IsEmpty ()) + { + delete root->zero; + root->zero = nullptr; + } + } + if (root->one) + { + Cleanup (root->one); + if (root->one->IsEmpty ()) + { + delete root->one; + root->one = nullptr; + if (root->zero && root->zero->router) + root->MoveRouterUp (false); + } + else if (root->one->router && !root->zero) + root->MoveRouterUp (true); + } + } void DHTTable::Print (std::stringstream& s) { @@ -248,10 +349,10 @@ namespace data { if (!root) return; s << std::string (level, '-'); - if (root->hash) + if (root->router) { if (!root->zero && !root->one) - s << '>' << GetIdentHashAbbreviation (*(root->hash)); + s << '>' << GetIdentHashAbbreviation (root->router->GetIdentHash ()); else s << "error"; } diff --git a/libi2pd/KadDHT.h b/libi2pd/KadDHT.h index eb12aae7..c280a1de 100644 --- a/libi2pd/KadDHT.h +++ b/libi2pd/KadDHT.h @@ -13,7 +13,8 @@ #include #include #include -#include "Identity.h" +#include +#include "RouterInfo.h" // Kademlia DHT (XOR distance) @@ -24,42 +25,48 @@ namespace data struct DHTNode { DHTNode * zero, * one; - IdentHash * hash; + std::shared_ptr router; DHTNode (); ~DHTNode (); - bool IsEmpty () const { return !zero && !one && !hash; }; - void MoveHashUp (bool fromOne); + bool IsEmpty () const { return !zero && !one && !router; }; + void MoveRouterUp (bool fromOne); }; class DHTTable { + typedef std::function&)> Filter; public: DHTTable (); ~DHTTable (); - DHTNode * Insert (const IdentHash& h); + void Insert (const std::shared_ptr& r); bool Remove (const IdentHash& h); - IdentHash * FindClosest (const IdentHash& h); - std::vector FindClosest (const IdentHash& h, size_t num); + std::shared_ptr FindClosest (const IdentHash& h, const Filter& filter = nullptr); + std::vector > FindClosest (const IdentHash& h, size_t num, const Filter& filter = nullptr); void Print (std::stringstream& s); size_t GetSize () const { return m_Size; }; + void Clear (); + void Cleanup (Filter filter); private: - DHTNode * Insert (IdentHash * h, DHTNode * root, int level); // recursive + void Insert (const std::shared_ptr& r, DHTNode * root, int level); // recursive bool Remove (const IdentHash& h, DHTNode * root, int level); - IdentHash * FindClosest (const IdentHash& h, DHTNode * root, int level); - void FindClosest (const IdentHash& h, size_t num, DHTNode * root, int level, std::vector& hashes); + std::shared_ptr FindClosest (const IdentHash& h, DHTNode * root, int level); + void FindClosest (const IdentHash& h, size_t num, DHTNode * root, int level, std::vector >& hashes); + void Cleanup (DHTNode * root); void Print (std::stringstream& s, DHTNode * root, int level); private: DHTNode * m_Root; size_t m_Size; + // transient + Filter m_Filter; }; } } diff --git a/libi2pd/NetDb.cpp b/libi2pd/NetDb.cpp index 5bc9c47e..685189ed 100644 --- a/libi2pd/NetDb.cpp +++ b/libi2pd/NetDb.cpp @@ -1355,15 +1355,12 @@ namespace data } std::shared_ptr NetDb::GetClosestFloodfill (const IdentHash& destination, - const std::set& excluded, bool closeThanUsOnly) const + const std::set& excluded) const { std::shared_ptr r; XORMetric minMetric; IdentHash destKey = CreateRoutingKey (destination); - if (closeThanUsOnly) - minMetric = destKey ^ i2p::context.GetIdentHash (); - else - minMetric.SetMax (); + minMetric.SetMax (); std::unique_lock l(m_FloodfillsMutex); for (const auto& it: m_Floodfills) { diff --git a/libi2pd/NetDb.hpp b/libi2pd/NetDb.hpp index f0315582..192d2644 100644 --- a/libi2pd/NetDb.hpp +++ b/libi2pd/NetDb.hpp @@ -93,7 +93,7 @@ namespace data std::shared_ptr GetHighBandwidthRandomRouter (std::shared_ptr compatibleWith, bool reverse) const; std::shared_ptr GetRandomSSU2PeerTestRouter (bool v4, const std::set& excluded) const; std::shared_ptr GetRandomSSU2Introducer (bool v4, const std::set& excluded) const; - std::shared_ptr GetClosestFloodfill (const IdentHash& destination, const std::set& excluded, bool closeThanUsOnly = false) const; + std::shared_ptr GetClosestFloodfill (const IdentHash& destination, const std::set& excluded) const; std::vector GetClosestFloodfills (const IdentHash& destination, size_t num, std::set& excluded, bool closeThanUsOnly = false) const; std::shared_ptr GetClosestNonFloodfill (const IdentHash& destination, const std::set& excluded) const;