From 902768099cb92b39f5ff509cd91fdd8970759b8a Mon Sep 17 00:00:00 2001 From: Cory Fields Date: Wed, 26 Oct 2016 18:08:11 -0400 Subject: [PATCH] net: handle version push in InitializeNode --- src/main.cpp | 34 ++++++++++++++++++++++++++++++---- src/net.cpp | 34 +++++----------------------------- src/net.h | 9 +++++---- src/test/DoS_tests.cpp | 8 ++++---- 4 files changed, 44 insertions(+), 41 deletions(-) diff --git a/src/main.cpp b/src/main.cpp index 2d02f59d2..590f19526 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -18,6 +18,7 @@ #include "init.h" #include "merkleblock.h" #include "net.h" +#include "netbase.h" #include "policy/fees.h" #include "policy/policy.h" #include "pow.h" @@ -354,11 +355,36 @@ void UpdatePreferredDownload(CNode* node, CNodeState* state) nPreferredDownload += state->fPreferredDownload; } -void InitializeNode(NodeId nodeid, const CNode *pnode) { +void PushNodeVersion(CNode *pnode, CConnman& connman, int64_t nTime) +{ + ServiceFlags nLocalNodeServices = pnode->GetLocalServices(); + uint64_t nonce = pnode->GetLocalNonce(); + int nNodeStartingHeight = pnode->GetMyStartingHeight(); + NodeId nodeid = pnode->GetId(); + CAddress addr = pnode->addr; + + CAddress addrYou = (addr.IsRoutable() && !IsProxy(addr) ? addr : CAddress(CService(), addr.nServices)); + CAddress addrMe = CAddress(CService(), nLocalNodeServices); + + connman.PushMessageWithVersion(pnode, INIT_PROTO_VERSION, NetMsgType::VERSION, PROTOCOL_VERSION, (uint64_t)nLocalNodeServices, nTime, addrYou, addrMe, + nonce, strSubVersion, nNodeStartingHeight, ::fRelayTxes); + + if (fLogIPs) + LogPrint("net", "send version message: version %d, blocks=%d, us=%s, them=%s, peer=%d\n", PROTOCOL_VERSION, nNodeStartingHeight, addrMe.ToString(), addrYou.ToString(), nodeid); + else + LogPrint("net", "send version message: version %d, blocks=%d, us=%s, peer=%d\n", PROTOCOL_VERSION, nNodeStartingHeight, addrMe.ToString(), nodeid); +} + +void InitializeNode(CNode *pnode, CConnman& connman) { CAddress addr = pnode->addr; std::string addrName = pnode->addrName; - LOCK(cs_main); - mapNodeState.emplace_hint(mapNodeState.end(), std::piecewise_construct, std::forward_as_tuple(nodeid), std::forward_as_tuple(addr, std::move(addrName))); + NodeId nodeid = pnode->GetId(); + { + LOCK(cs_main); + mapNodeState.emplace_hint(mapNodeState.end(), std::piecewise_construct, std::forward_as_tuple(nodeid), std::forward_as_tuple(addr, std::move(addrName))); + } + if(!pnode->fInbound) + PushNodeVersion(pnode, connman, GetTime()); } void FinalizeNode(NodeId nodeid, bool& fUpdateConnectionTime) { @@ -5118,7 +5144,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, // Be shy and don't send version until we hear if (pfrom->fInbound) - connman.PushVersion(pfrom, GetAdjustedTime()); + PushNodeVersion(pfrom, connman, GetAdjustedTime()); pfrom->fClient = !(pfrom->nServices & NODE_NETWORK); diff --git a/src/net.cpp b/src/net.cpp index 16d9527ee..15cf7ce8b 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -393,21 +393,15 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo NodeId id = GetNewNodeId(); uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize(); CNode* pnode = new CNode(id, nLocalServices, GetBestHeight(), hSocket, addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, pszDest ? pszDest : "", false); - - - PushVersion(pnode, GetTime()); - - GetNodeSignals().InitializeNode(pnode->GetId(), pnode); + pnode->nServicesExpected = ServiceFlags(addrConnect.nServices & nRelevantServices); + pnode->nTimeConnected = GetTime(); pnode->AddRef(); - + GetNodeSignals().InitializeNode(pnode, *this); { LOCK(cs_vNodes); vNodes.push_back(pnode); } - pnode->nServicesExpected = ServiceFlags(addrConnect.nServices & nRelevantServices); - pnode->nTimeConnected = GetTime(); - return pnode; } else if (!proxyConnectionFailed) { // If connecting to the node failed, and failure is not caused by a problem connecting to @@ -418,24 +412,6 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo return NULL; } -void CConnman::PushVersion(CNode* pnode, int64_t nTime) -{ - ServiceFlags nLocalNodeServices = pnode->GetLocalServices(); - CAddress addrYou = (pnode->addr.IsRoutable() && !IsProxy(pnode->addr) ? pnode->addr : CAddress(CService(), pnode->addr.nServices)); - CAddress addrMe = CAddress(CService(), nLocalNodeServices); - uint64_t nonce = pnode->GetLocalNonce(); - int nNodeStartingHeight = pnode->nMyStartingHeight; - NodeId id = pnode->GetId(); - - PushMessageWithVersion(pnode, INIT_PROTO_VERSION, NetMsgType::VERSION, PROTOCOL_VERSION, (uint64_t)nLocalNodeServices, nTime, addrYou, addrMe, - nonce, strSubVersion, nNodeStartingHeight, ::fRelayTxes); - - if (fLogIPs) - LogPrint("net", "send version message: version %d, blocks=%d, us=%s, them=%s, peer=%d\n", PROTOCOL_VERSION, nNodeStartingHeight, addrMe.ToString(), addrYou.ToString(), id); - else - LogPrint("net", "send version message: version %d, blocks=%d, us=%s, peer=%d\n", PROTOCOL_VERSION, nNodeStartingHeight, addrMe.ToString(), id); -} - void CConnman::DumpBanlist() { SweepBanned(); // clean unused entries (if bantime has expired) @@ -1036,9 +1012,9 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize(); CNode* pnode = new CNode(id, nLocalServices, GetBestHeight(), hSocket, addr, CalculateKeyedNetGroup(addr), nonce, "", true); - GetNodeSignals().InitializeNode(pnode->GetId(), pnode); pnode->AddRef(); pnode->fWhitelisted = whitelisted; + GetNodeSignals().InitializeNode(pnode, *this); LogPrint("net", "connection from %s accepted\n", addr.ToString()); @@ -2130,7 +2106,7 @@ bool CConnman::Start(boost::thread_group& threadGroup, CScheduler& scheduler, st uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize(); pnodeLocalHost = new CNode(id, nLocalServices, GetBestHeight(), INVALID_SOCKET, CAddress(CService(local, 0), nLocalServices), 0, nonce); - GetNodeSignals().InitializeNode(pnodeLocalHost->GetId(), pnodeLocalHost); + GetNodeSignals().InitializeNode(pnodeLocalHost, *this); } // diff --git a/src/net.h b/src/net.h index 99ca5bc0d..22b80fc50 100644 --- a/src/net.h +++ b/src/net.h @@ -163,9 +163,6 @@ public: PushMessageWithVersionAndFlag(pnode, 0, 0, sCommand, std::forward(args)...); } - void PushVersion(CNode* pnode, int64_t nTime); - - template bool ForEachNodeContinueIf(Callable&& func) { @@ -462,7 +459,7 @@ struct CNodeSignals { boost::signals2::signal ProcessMessages; boost::signals2::signal SendMessages; - boost::signals2::signal InitializeNode; + boost::signals2::signal InitializeNode; boost::signals2::signal FinalizeNode; }; @@ -722,6 +719,10 @@ public: return nLocalHostNonce; } + int GetMyStartingHeight() const { + return nMyStartingHeight; + } + int GetRefCount() { assert(nRefCount >= 0); diff --git a/src/test/DoS_tests.cpp b/src/test/DoS_tests.cpp index 47979a6f8..6eed63608 100644 --- a/src/test/DoS_tests.cpp +++ b/src/test/DoS_tests.cpp @@ -50,7 +50,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning) CAddress addr1(ip(0xa0b0c001), NODE_NONE); CNode dummyNode1(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr1, 0, 0, "", true); dummyNode1.SetSendVersion(PROTOCOL_VERSION); - GetNodeSignals().InitializeNode(dummyNode1.GetId(), &dummyNode1); + GetNodeSignals().InitializeNode(&dummyNode1, *connman); dummyNode1.nVersion = 1; Misbehaving(dummyNode1.GetId(), 100); // Should get banned SendMessages(&dummyNode1, *connman); @@ -60,7 +60,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning) CAddress addr2(ip(0xa0b0c002), NODE_NONE); CNode dummyNode2(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr2, 1, 1, "", true); dummyNode2.SetSendVersion(PROTOCOL_VERSION); - GetNodeSignals().InitializeNode(dummyNode2.GetId(), &dummyNode2); + GetNodeSignals().InitializeNode(&dummyNode2, *connman); dummyNode2.nVersion = 1; Misbehaving(dummyNode2.GetId(), 50); SendMessages(&dummyNode2, *connman); @@ -78,7 +78,7 @@ BOOST_AUTO_TEST_CASE(DoS_banscore) CAddress addr1(ip(0xa0b0c001), NODE_NONE); CNode dummyNode1(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr1, 3, 1, "", true); dummyNode1.SetSendVersion(PROTOCOL_VERSION); - GetNodeSignals().InitializeNode(dummyNode1.GetId(), &dummyNode1); + GetNodeSignals().InitializeNode(&dummyNode1, *connman); dummyNode1.nVersion = 1; Misbehaving(dummyNode1.GetId(), 100); SendMessages(&dummyNode1, *connman); @@ -101,7 +101,7 @@ BOOST_AUTO_TEST_CASE(DoS_bantime) CAddress addr(ip(0xa0b0c001), NODE_NONE); CNode dummyNode(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr, 4, 4, "", true); dummyNode.SetSendVersion(PROTOCOL_VERSION); - GetNodeSignals().InitializeNode(dummyNode.GetId(), &dummyNode); + GetNodeSignals().InitializeNode(&dummyNode, *connman); dummyNode.nVersion = 1; Misbehaving(dummyNode.GetId(), 100);