diff --git a/src/main.cpp b/src/main.cpp index af598d487..43ccb6374 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -5025,7 +5025,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, } // Disconnect if we connected to ourself - if (nNonce == nLocalHostNonce && nNonce > 1) + if (pfrom->fInbound && !connman.CheckIncomingNonce(nNonce)) { LogPrintf("connected to self at %s, disconnecting\n", pfrom->addr.ToString()); pfrom->fDisconnect = true; diff --git a/src/net.cpp b/src/net.cpp index 8bc8ecc43..71b4b0168 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -83,7 +83,6 @@ CCriticalSection cs_mapLocalHost; std::map mapLocalHost; static bool vfLimited[NET_MAX] = {}; static CNode* pnodeLocalHost = NULL; -uint64_t nLocalHostNonce = 0; int nMaxConnections = DEFAULT_MAX_PEER_CONNECTIONS; std::string strSubVersion; @@ -346,6 +345,16 @@ CNode* CConnman::FindNode(const CService& addr) return NULL; } +bool CConnman::CheckIncomingNonce(uint64_t nonce) +{ + LOCK(cs_vNodes); + BOOST_FOREACH(CNode* pnode, vNodes) { + if (!pnode->fSuccessfullyConnected && !pnode->fInbound && pnode->GetLocalNonce() == nonce) + return false; + } + return true; +} + CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCountFailure) { if (pszDest == NULL) { @@ -465,7 +474,6 @@ void CNode::PushVersion() int64_t nTime = (fInbound ? GetAdjustedTime() : GetTime()); CAddress addrYou = (addr.IsRoutable() && !IsProxy(addr) ? addr : CAddress(CService(), addr.nServices)); CAddress addrMe = GetLocalAddress(&addr); - GetRandBytes((unsigned char*)&nLocalHostNonce, sizeof(nLocalHostNonce)); if (fLogIPs) LogPrint("net", "send version message: version %d, blocks=%d, us=%s, them=%s, peer=%d\n", PROTOCOL_VERSION, nBestHeight, addrMe.ToString(), addrYou.ToString(), id); else @@ -2535,6 +2543,8 @@ CNode::CNode(NodeId idIn, SOCKET hSocketIn, const CAddress& addrIn, const std::s nextSendTimeFeeFilter = 0; id = idIn; + GetRandBytes((unsigned char*)&nLocalHostNonce, sizeof(nLocalHostNonce)); + BOOST_FOREACH(const std::string &msg, getAllNetMessageTypes()) mapRecvBytesPerMsgCmd[msg] = 0; mapRecvBytesPerMsgCmd[NET_MESSAGE_COMMAND_OTHER] = 0; diff --git a/src/net.h b/src/net.h index 36043b0b8..32668045c 100644 --- a/src/net.h +++ b/src/net.h @@ -114,6 +114,7 @@ public: void Stop(); bool BindListenPort(const CService &bindAddr, std::string& strError, bool fWhitelisted = false); bool OpenNetworkConnection(const CAddress& addrConnect, bool fCountFailure, CSemaphoreGrant *grantOutbound = NULL, const char *strDest = NULL, bool fOneShot = false, bool fFeeler = false); + bool CheckIncomingNonce(uint64_t nonce); bool ForNode(NodeId id, std::function func); bool ForEachNode(std::function func); @@ -297,7 +298,6 @@ extern bool fListen; extern ServiceFlags nLocalServices; extern ServiceFlags nRelevantServices; extern bool fRelayTxes; -extern uint64_t nLocalHostNonce; /** Maximum number of connections to simultaneously allow (aka connection slots) */ extern int nMaxConnections; @@ -523,12 +523,17 @@ private: static uint64_t CalculateKeyedNetGroup(const CAddress& ad); + uint64_t nLocalHostNonce; public: NodeId GetId() const { return id; } + uint64_t GetLocalNonce() const { + return nLocalHostNonce; + } + int GetRefCount() { assert(nRefCount >= 0);