diff --git a/TunnelConfig.h b/TunnelConfig.h index 429fa679..1e833f78 100644 --- a/TunnelConfig.h +++ b/TunnelConfig.h @@ -51,7 +51,7 @@ namespace tunnel nextTunnelID = rnd.GenerateWord32 (); } - void SetReplyHop (TunnelHopConfig * replyFirstHop) + void SetReplyHop (const TunnelHopConfig * replyFirstHop) { nextRouter = replyFirstHop->router; nextTunnelID = replyFirstHop->tunnelID; @@ -89,7 +89,7 @@ namespace tunnel TunnelConfig (std::vector peers, - TunnelConfig * replyTunnelConfig = 0) // replyTunnelConfig=0 means inbound + const TunnelConfig * replyTunnelConfig = nullptr) // replyTunnelConfig=0 means inbound { TunnelHopConfig * prev = nullptr; for (auto it: peers) @@ -192,6 +192,27 @@ namespace tunnel } return newConfig; } + + TunnelConfig * Clone (const TunnelConfig * replyTunnelConfig = nullptr) const + { + TunnelConfig * newConfig = new TunnelConfig (); + TunnelHopConfig * hop = m_FirstHop, * prev = nullptr; + while (hop) + { + TunnelHopConfig * newHop = new TunnelHopConfig (hop->router); + newHop->SetPrev (prev); + newHop->SetNextRouter (hop->nextRouter); + newHop->isGateway = hop->isGateway; + newHop->isEndpoint = hop->isEndpoint; + prev = newHop; + if (!newConfig->m_FirstHop) newConfig->m_FirstHop = newHop; + hop = hop->next; + } + newConfig->m_LastHop = prev; + if (replyTunnelConfig && newConfig->m_LastHop) + newConfig->m_LastHop->SetReplyHop (replyTunnelConfig->GetFirstHop ()); + return newConfig; + } private: diff --git a/TunnelPool.cpp b/TunnelPool.cpp index e370cfe7..b0e12e5e 100644 --- a/TunnelPool.cpp +++ b/TunnelPool.cpp @@ -36,7 +36,7 @@ namespace tunnel m_InboundTunnels.erase (expiredTunnel); for (auto it: m_Tests) if (it.second.second == expiredTunnel) it.second.second = nullptr; - + RecreateInboundTunnel (expiredTunnel); } } @@ -53,6 +53,7 @@ namespace tunnel m_OutboundTunnels.erase (expiredTunnel); for (auto it: m_Tests) if (it.second.first == expiredTunnel) it.second.first = nullptr; + RecreateOutboundTunnel (expiredTunnel); } } @@ -204,6 +205,16 @@ namespace tunnel tunnel->SetTunnelPool (this); } + void TunnelPool::RecreateInboundTunnel (InboundTunnel * tunnel) + { + OutboundTunnel * outboundTunnel = GetNextOutboundTunnel (); + if (!outboundTunnel) + outboundTunnel = tunnels.GetNextOutboundTunnel (); + LogPrint ("Re-creating destination inbound tunnel..."); + auto * newTunnel = tunnels.CreateTunnel (tunnel->GetTunnelConfig ()->Clone (), outboundTunnel); + newTunnel->SetTunnelPool (this); + } + void TunnelPool::CreateOutboundTunnel () { InboundTunnel * inboundTunnel = m_InboundTunnels.size () > 0 ? @@ -226,5 +237,16 @@ namespace tunnel tunnel->SetTunnelPool (this); } } + + void TunnelPool::RecreateOutboundTunnel (OutboundTunnel * tunnel) + { + InboundTunnel * inboundTunnel = GetNextInboundTunnel (); + if (!inboundTunnel) + inboundTunnel = tunnels.GetNextInboundTunnel (); + LogPrint ("Re-creating destination outbound tunnel..."); + auto * newTunnel = tunnels.CreateTunnel ( + tunnel->GetTunnelConfig ()->Clone (inboundTunnel->GetTunnelConfig ())); + newTunnel->SetTunnelPool (this); + } } } diff --git a/TunnelPool.h b/TunnelPool.h index 21b8f6c5..15682ace 100644 --- a/TunnelPool.h +++ b/TunnelPool.h @@ -48,6 +48,8 @@ namespace tunnel void CreateInboundTunnel (); void CreateOutboundTunnel (); + void RecreateInboundTunnel (InboundTunnel * tunnel); + void RecreateOutboundTunnel (OutboundTunnel * tunnel); template typename TTunnels::value_type GetNextTunnel (TTunnels& tunnels);