From 72a39609ed2044df6922fe4016601fd96da3af8c Mon Sep 17 00:00:00 2001 From: orignal Date: Sat, 16 Nov 2024 20:56:35 -0500 Subject: [PATCH] moved all transit tunnels code to TransitTunnels class --- libi2pd/TransitTunnel.cpp | 64 +++++++++++++++++++++++++++++++--- libi2pd/TransitTunnel.h | 20 +++++++++-- libi2pd/Tunnel.cpp | 72 ++++++++++++--------------------------- libi2pd/Tunnel.h | 11 +++--- 4 files changed, 102 insertions(+), 65 deletions(-) diff --git a/libi2pd/TransitTunnel.cpp b/libi2pd/TransitTunnel.cpp index 62192e92..edf96c31 100644 --- a/libi2pd/TransitTunnel.cpp +++ b/libi2pd/TransitTunnel.cpp @@ -122,7 +122,16 @@ namespace tunnel } } - void TransitTunnelBuildMsgHandler::HandleShortTransitTunnelBuildMsg (std::shared_ptr&& msg) + void TransitTunnels::Start () + { + } + + void TransitTunnels::Stop () + { + m_TransitTunnels.clear (); + } + + void TransitTunnels::HandleShortTransitTunnelBuildMsg (std::shared_ptr&& msg) { if (!msg) return; uint8_t * buf = msg->GetPayload(); @@ -194,7 +203,7 @@ namespace tunnel layerKey, ivKey, clearText[SHORT_REQUEST_RECORD_FLAG_OFFSET] & TUNNEL_BUILD_RECORD_GATEWAY_FLAG, clearText[SHORT_REQUEST_RECORD_FLAG_OFFSET] & TUNNEL_BUILD_RECORD_ENDPOINT_FLAG); - if (!i2p::tunnel::tunnels.AddTransitTunnel (transitTunnel)) + if (!AddTransitTunnel (transitTunnel)) retCode = 30; } @@ -275,7 +284,7 @@ namespace tunnel } } - bool TransitTunnelBuildMsgHandler::HandleBuildRequestRecords (int num, uint8_t * records, uint8_t * clearText) + bool TransitTunnels::HandleBuildRequestRecords (int num, uint8_t * records, uint8_t * clearText) { for (int i = 0; i < num; i++) { @@ -324,7 +333,7 @@ namespace tunnel clearText + ECIES_BUILD_REQUEST_RECORD_IV_KEY_OFFSET, clearText[ECIES_BUILD_REQUEST_RECORD_FLAG_OFFSET] & TUNNEL_BUILD_RECORD_GATEWAY_FLAG, clearText[ECIES_BUILD_REQUEST_RECORD_FLAG_OFFSET] & TUNNEL_BUILD_RECORD_ENDPOINT_FLAG); - if (!i2p::tunnel::tunnels.AddTransitTunnel (transitTunnel)) + if (!AddTransitTunnel (transitTunnel)) retCode = 30; } else @@ -362,7 +371,7 @@ namespace tunnel return false; } - void TransitTunnelBuildMsgHandler::HandleVariableTransitTunnelBuildMsg (std::shared_ptr&& msg) + void TransitTunnels::HandleVariableTransitTunnelBuildMsg (std::shared_ptr&& msg) { if (!msg) return; uint8_t * buf = msg->GetPayload(); @@ -396,5 +405,50 @@ namespace tunnel bufbe32toh (clearText + ECIES_BUILD_REQUEST_RECORD_SEND_MSG_ID_OFFSET))); } } + + bool TransitTunnels::AddTransitTunnel (std::shared_ptr tunnel) + { + if (tunnels.AddTunnel (tunnel)) + m_TransitTunnels.push_back (tunnel); + else + { + LogPrint (eLogError, "TransitTunnel: Tunnel with id ", tunnel->GetTunnelID (), " already exists"); + return false; + } + return true; + } + + void TransitTunnels::ManageTransitTunnels (uint64_t ts) + { + for (auto it = m_TransitTunnels.begin (); it != m_TransitTunnels.end ();) + { + auto tunnel = *it; + if (ts > tunnel->GetCreationTime () + TUNNEL_EXPIRATION_TIMEOUT || + ts + TUNNEL_EXPIRATION_TIMEOUT < tunnel->GetCreationTime ()) + { + LogPrint (eLogDebug, "TransitTunnel: Transit tunnel with id ", tunnel->GetTunnelID (), " expired"); + tunnels.RemoveTunnel (tunnel->GetTunnelID ()); + it = m_TransitTunnels.erase (it); + } + else + { + tunnel->Cleanup (); + it++; + } + } + } + + int TransitTunnels::GetTransitTunnelsExpirationTimeout () + { + int timeout = 0; + uint32_t ts = i2p::util::GetSecondsSinceEpoch (); + // TODO: possible race condition with I2PControl + for (const auto& it : m_TransitTunnels) + { + int t = it->GetCreationTime () + TUNNEL_EXPIRATION_TIMEOUT - ts; + if (t > timeout) timeout = t; + } + return timeout; + } } } diff --git a/libi2pd/TransitTunnel.h b/libi2pd/TransitTunnel.h index fb5589dc..fc8676d0 100644 --- a/libi2pd/TransitTunnel.h +++ b/libi2pd/TransitTunnel.h @@ -109,19 +109,33 @@ namespace tunnel const i2p::crypto::AESKey& layerKey, const i2p::crypto::AESKey& ivKey, bool isGateway, bool isEndpoint); - class TransitTunnelBuildMsgHandler + class TransitTunnels { public: - void Start () {}; - void Stop () {}; + void Start (); + void Stop (); + void ManageTransitTunnels (uint64_t ts); + + size_t GetNumTransitTunnels () const { return m_TransitTunnels.size (); } + int GetTransitTunnelsExpirationTimeout (); void HandleShortTransitTunnelBuildMsg (std::shared_ptr&& msg); void HandleVariableTransitTunnelBuildMsg (std::shared_ptr&& msg); private: + bool AddTransitTunnel (std::shared_ptr tunnel); bool HandleBuildRequestRecords (int num, uint8_t * records, uint8_t * clearText); + + private: + + std::list > m_TransitTunnels; + + public: + + // for HTTP only + auto& GetTransitTunnels () const { return m_TransitTunnels; }; }; } } diff --git a/libi2pd/Tunnel.cpp b/libi2pd/Tunnel.cpp index c743691e..d809e48a 100644 --- a/libi2pd/Tunnel.cpp +++ b/libi2pd/Tunnel.cpp @@ -379,6 +379,17 @@ namespace tunnel return nullptr; } + bool Tunnels::AddTunnel (std::shared_ptr tunnel) + { + if (!tunnel) return false; + return m_Tunnels.emplace (tunnel->GetTunnelID (), tunnel).second; + } + + void Tunnels::RemoveTunnel (uint32_t tunnelID) + { + m_Tunnels.erase (tunnelID); + } + std::shared_ptr Tunnels::GetPendingInboundTunnel (uint32_t replyMsgID) { return GetPendingTunnel (replyMsgID, m_PendingInboundTunnels); @@ -466,28 +477,16 @@ namespace tunnel } } - bool Tunnels::AddTransitTunnel (std::shared_ptr tunnel) - { - if (m_Tunnels.emplace (tunnel->GetTunnelID (), tunnel).second) - m_TransitTunnels.push_back (tunnel); - else - { - LogPrint (eLogError, "Tunnel: Tunnel with id ", tunnel->GetTunnelID (), " already exists"); - return false; - } - return true; - } - void Tunnels::Start () { m_IsRunning = true; m_Thread = new std::thread (std::bind (&Tunnels::Run, this)); - m_TransitTunnelBuildMsgHandler.Start (); + m_TransitTunnels.Start (); } void Tunnels::Stop () { - m_TransitTunnelBuildMsgHandler.Stop (); + m_TransitTunnels.Stop (); m_IsRunning = false; m_Queue.WakeUp (); if (m_Thread) @@ -656,7 +655,7 @@ namespace tunnel return; } else - m_TransitTunnelBuildMsgHandler.HandleShortTransitTunnelBuildMsg (std::move (msg)); + m_TransitTunnels.HandleShortTransitTunnelBuildMsg (std::move (msg)); } void Tunnels::HandleVariableTunnelBuildMsg (std::shared_ptr msg) @@ -679,7 +678,7 @@ namespace tunnel } } else - m_TransitTunnelBuildMsgHandler.HandleVariableTransitTunnelBuildMsg (std::move (msg)); + m_TransitTunnels.HandleVariableTransitTunnelBuildMsg (std::move (msg)); } void Tunnels::HandleTunnelBuildReplyMsg (std::shared_ptr msg, bool isShort) @@ -711,7 +710,7 @@ namespace tunnel ManagePendingTunnels (ts); ManageInboundTunnels (ts); ManageOutboundTunnels (ts); - ManageTransitTunnels (ts); + m_TransitTunnels.ManageTransitTunnels (ts); } void Tunnels::ManagePendingTunnels (uint64_t ts) @@ -838,7 +837,7 @@ namespace tunnel auto pool = tunnel->GetTunnelPool (); if (pool) pool->TunnelExpired (tunnel); - m_Tunnels.erase (tunnel->GetTunnelID ()); + RemoveTunnel (tunnel->GetTunnelID ()); it = m_InboundTunnels.erase (it); } else @@ -900,26 +899,6 @@ namespace tunnel } } - void Tunnels::ManageTransitTunnels (uint64_t ts) - { - for (auto it = m_TransitTunnels.begin (); it != m_TransitTunnels.end ();) - { - auto tunnel = *it; - if (ts > tunnel->GetCreationTime () + TUNNEL_EXPIRATION_TIMEOUT || - ts + TUNNEL_EXPIRATION_TIMEOUT < tunnel->GetCreationTime ()) - { - LogPrint (eLogDebug, "Tunnel: Transit tunnel with id ", tunnel->GetTunnelID (), " expired"); - m_Tunnels.erase (tunnel->GetTunnelID ()); - it = m_TransitTunnels.erase (it); - } - else - { - tunnel->Cleanup (); - it++; - } - } - } - void Tunnels::ManageTunnelPools (uint64_t ts) { std::unique_lock l(m_PoolsMutex); @@ -993,7 +972,7 @@ namespace tunnel void Tunnels::AddInboundTunnel (std::shared_ptr newTunnel) { - if (m_Tunnels.emplace (newTunnel->GetTunnelID (), newTunnel).second) + if (AddTunnel (newTunnel)) { m_InboundTunnels.push_back (newTunnel); auto pool = newTunnel->GetTunnelPool (); @@ -1023,7 +1002,7 @@ namespace tunnel inboundTunnel->SetTunnelPool (pool); inboundTunnel->SetState (eTunnelStateEstablished); m_InboundTunnels.push_back (inboundTunnel); - m_Tunnels[inboundTunnel->GetTunnelID ()] = inboundTunnel; + AddTunnel (inboundTunnel); return inboundTunnel; } @@ -1057,21 +1036,12 @@ namespace tunnel int Tunnels::GetTransitTunnelsExpirationTimeout () { - int timeout = 0; - uint32_t ts = i2p::util::GetSecondsSinceEpoch (); - // TODO: possible race condition with I2PControl - for (const auto& it : m_TransitTunnels) - { - int t = it->GetCreationTime () + TUNNEL_EXPIRATION_TIMEOUT - ts; - if (t > timeout) timeout = t; - } - return timeout; + return m_TransitTunnels.GetTransitTunnelsExpirationTimeout (); } size_t Tunnels::CountTransitTunnels() const { - // TODO: locking - return m_TransitTunnels.size(); + return m_TransitTunnels.GetNumTransitTunnels (); } size_t Tunnels::CountInboundTunnels() const diff --git a/libi2pd/Tunnel.h b/libi2pd/Tunnel.h index f4d94ba7..fcccd236 100644 --- a/libi2pd/Tunnel.h +++ b/libi2pd/Tunnel.h @@ -223,8 +223,9 @@ namespace tunnel std::shared_ptr GetNextOutboundTunnel (); std::shared_ptr GetExploratoryPool () const { return m_ExploratoryPool; }; std::shared_ptr GetTunnel (uint32_t tunnelID); + bool AddTunnel (std::shared_ptr tunnel); + void RemoveTunnel (uint32_t tunnelID); int GetTransitTunnelsExpirationTimeout (); - bool AddTransitTunnel (std::shared_ptr tunnel); void AddOutboundTunnel (std::shared_ptr newTunnel); void AddInboundTunnel (std::shared_ptr newTunnel); std::shared_ptr CreateInboundTunnel (std::shared_ptr config, std::shared_ptr pool, std::shared_ptr outboundTunnel); @@ -243,7 +244,7 @@ namespace tunnel void SetMaxNumTransitTunnels (uint32_t maxNumTransitTunnels); uint32_t GetMaxNumTransitTunnels () const { return m_MaxNumTransitTunnels; }; - int GetCongestionLevel() const { return m_MaxNumTransitTunnels ? CONGESTION_LEVEL_FULL * m_TransitTunnels.size() / m_MaxNumTransitTunnels : CONGESTION_LEVEL_FULL; } + int GetCongestionLevel() const { return m_MaxNumTransitTunnels ? CONGESTION_LEVEL_FULL * m_TransitTunnels.GetNumTransitTunnels () / m_MaxNumTransitTunnels : CONGESTION_LEVEL_FULL; } std::mt19937& GetRng () { return m_Rng; }; @@ -265,7 +266,6 @@ namespace tunnel void ManageTunnels (uint64_t ts); void ManageOutboundTunnels (uint64_t ts); void ManageInboundTunnels (uint64_t ts); - void ManageTransitTunnels (uint64_t ts); void ManagePendingTunnels (uint64_t ts); template void ManagePendingTunnels (PendingTunnels& pendingTunnels, uint64_t ts); @@ -300,7 +300,6 @@ namespace tunnel std::map > m_PendingOutboundTunnels; // by replyMsgID std::list > m_InboundTunnels; std::list > m_OutboundTunnels; - std::list > m_TransitTunnels; std::unordered_map > m_Tunnels; // tunnelID->tunnel known by this id std::mutex m_PoolsMutex; std::list> m_Pools; @@ -314,14 +313,14 @@ namespace tunnel double m_TunnelCreationSuccessRate; int m_TunnelCreationAttemptsNum; std::mt19937 m_Rng; - TransitTunnelBuildMsgHandler m_TransitTunnelBuildMsgHandler; + TransitTunnels m_TransitTunnels; public: // for HTTP only const decltype(m_OutboundTunnels)& GetOutboundTunnels () const { return m_OutboundTunnels; }; const decltype(m_InboundTunnels)& GetInboundTunnels () const { return m_InboundTunnels; }; - const decltype(m_TransitTunnels)& GetTransitTunnels () const { return m_TransitTunnels; }; + auto& GetTransitTunnels () const { return m_TransitTunnels.GetTransitTunnels (); }; size_t CountTransitTunnels() const; size_t CountInboundTunnels() const;