Browse Source

Merge #9609: net: fix remaining net assertions

08bb6f4 net: log an error rather than asserting if send version is misused (Cory Fields)
7a8c251 net: Disallow sending messages until the version handshake is complete (Cory Fields)
12752af net: don't run callbacks on nodes that haven't completed the version handshake (Cory Fields)
2046617 net: deserialize the entire version message locally (Cory Fields)
80ff034 Dont deserialize nVersion into CNode, should fix #9212 (Matt Corallo)
0.14
Wladimir J. van der Laan 8 years ago
parent
commit
496691741d
No known key found for this signature in database
GPG Key ID: 74810B012346C9A6
  1. 34
      src/net.cpp
  2. 100
      src/net.h
  3. 75
      src/net_processing.cpp
  4. 4
      src/test/DoS_tests.cpp

34
src/net.cpp

@ -689,6 +689,33 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete
return true; return true;
} }
void CNode::SetSendVersion(int nVersionIn)
{
// Send version may only be changed in the version message, and
// only one version message is allowed per session. We can therefore
// treat this value as const and even atomic as long as it's only used
// once a version message has been successfully processed. Any attempt to
// set this twice is an error.
if (nSendVersion != 0) {
error("Send version already set for node: %i. Refusing to change from %i to %i", id, nSendVersion, nVersionIn);
} else {
nSendVersion = nVersionIn;
}
}
int CNode::GetSendVersion() const
{
// The send version should always be explicitly set to
// INIT_PROTO_VERSION rather than using this value until SetSendVersion
// has been called.
if (nSendVersion == 0) {
error("Requesting unset send version for node: %i. Using %i", id, INIT_PROTO_VERSION);
return INIT_PROTO_VERSION;
}
return nSendVersion;
}
int CNetMessage::readHeader(const char *pch, unsigned int nBytes) int CNetMessage::readHeader(const char *pch, unsigned int nBytes)
{ {
// copy data to temporary parsing buffer // copy data to temporary parsing buffer
@ -2630,6 +2657,11 @@ void CNode::AskFor(const CInv& inv)
mapAskFor.insert(std::make_pair(nRequestTime, inv)); mapAskFor.insert(std::make_pair(nRequestTime, inv));
} }
bool CConnman::NodeFullyConnected(const CNode* pnode)
{
return pnode && pnode->fSuccessfullyConnected && !pnode->fDisconnect;
}
void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg)
{ {
size_t nMessageSize = msg.data.size(); size_t nMessageSize = msg.data.size();
@ -2680,7 +2712,7 @@ bool CConnman::ForNode(NodeId id, std::function<bool(CNode* pnode)> func)
break; break;
} }
} }
return found != nullptr && func(found); return found != nullptr && NodeFullyConnected(found) && func(found);
} }
int64_t PoissonNextSend(int64_t nNow, int average_interval_seconds) { int64_t PoissonNextSend(int64_t nNow, int average_interval_seconds) {

100
src/net.h

@ -161,76 +161,34 @@ public:
void PushMessage(CNode* pnode, CSerializedNetMsg&& msg); void PushMessage(CNode* pnode, CSerializedNetMsg&& msg);
template<typename Callable>
bool ForEachNodeContinueIf(Callable&& func)
{
LOCK(cs_vNodes);
for (auto&& node : vNodes)
if(!func(node))
return false;
return true;
};
template<typename Callable>
bool ForEachNodeContinueIf(Callable&& func) const
{
LOCK(cs_vNodes);
for (const auto& node : vNodes)
if(!func(node))
return false;
return true;
};
template<typename Callable, typename CallableAfter>
bool ForEachNodeContinueIfThen(Callable&& pre, CallableAfter&& post)
{
bool ret = true;
LOCK(cs_vNodes);
for (auto&& node : vNodes)
if(!pre(node)) {
ret = false;
break;
}
post();
return ret;
};
template<typename Callable, typename CallableAfter>
bool ForEachNodeContinueIfThen(Callable&& pre, CallableAfter&& post) const
{
bool ret = true;
LOCK(cs_vNodes);
for (const auto& node : vNodes)
if(!pre(node)) {
ret = false;
break;
}
post();
return ret;
};
template<typename Callable> template<typename Callable>
void ForEachNode(Callable&& func) void ForEachNode(Callable&& func)
{ {
LOCK(cs_vNodes); LOCK(cs_vNodes);
for (auto&& node : vNodes) for (auto&& node : vNodes) {
func(node); if (NodeFullyConnected(node))
func(node);
}
}; };
template<typename Callable> template<typename Callable>
void ForEachNode(Callable&& func) const void ForEachNode(Callable&& func) const
{ {
LOCK(cs_vNodes); LOCK(cs_vNodes);
for (const auto& node : vNodes) for (auto&& node : vNodes) {
func(node); if (NodeFullyConnected(node))
func(node);
}
}; };
template<typename Callable, typename CallableAfter> template<typename Callable, typename CallableAfter>
void ForEachNodeThen(Callable&& pre, CallableAfter&& post) void ForEachNodeThen(Callable&& pre, CallableAfter&& post)
{ {
LOCK(cs_vNodes); LOCK(cs_vNodes);
for (auto&& node : vNodes) for (auto&& node : vNodes) {
pre(node); if (NodeFullyConnected(node))
pre(node);
}
post(); post();
}; };
@ -238,8 +196,10 @@ public:
void ForEachNodeThen(Callable&& pre, CallableAfter&& post) const void ForEachNodeThen(Callable&& pre, CallableAfter&& post) const
{ {
LOCK(cs_vNodes); LOCK(cs_vNodes);
for (const auto& node : vNodes) for (auto&& node : vNodes) {
pre(node); if (NodeFullyConnected(node))
pre(node);
}
post(); post();
}; };
@ -372,6 +332,9 @@ private:
void RecordBytesRecv(uint64_t bytes); void RecordBytesRecv(uint64_t bytes);
void RecordBytesSent(uint64_t bytes); void RecordBytesSent(uint64_t bytes);
// Whether the node should be passed out in ForEach* callbacks
static bool NodeFullyConnected(const CNode* pnode);
// Network usage totals // Network usage totals
CCriticalSection cs_totalBytesRecv; CCriticalSection cs_totalBytesRecv;
CCriticalSection cs_totalBytesSent; CCriticalSection cs_totalBytesSent;
@ -627,7 +590,7 @@ public:
const CAddress addr; const CAddress addr;
std::string addrName; std::string addrName;
CService addrLocal; CService addrLocal;
int nVersion; std::atomic<int> nVersion;
// strSubVer is whatever byte array we read from the wire. However, this field is intended // strSubVer is whatever byte array we read from the wire. However, this field is intended
// to be printed out, displayed to humans in various forms and so on. So we sanitize it and // to be printed out, displayed to humans in various forms and so on. So we sanitize it and
// store the sanitized version in cleanSubVer. The original should be used when dealing with // store the sanitized version in cleanSubVer. The original should be used when dealing with
@ -639,7 +602,7 @@ public:
bool fAddnode; bool fAddnode;
bool fClient; bool fClient;
const bool fInbound; const bool fInbound;
bool fSuccessfullyConnected; std::atomic_bool fSuccessfullyConnected;
std::atomic_bool fDisconnect; std::atomic_bool fDisconnect;
// We use fRelayTxes for two purposes - // We use fRelayTxes for two purposes -
// a) it allows us to not relay tx invs before receiving the peer's version message // a) it allows us to not relay tx invs before receiving the peer's version message
@ -760,25 +723,8 @@ public:
{ {
return nRecvVersion; return nRecvVersion;
} }
void SetSendVersion(int nVersionIn) void SetSendVersion(int nVersionIn);
{ int GetSendVersion() const;
// Send version may only be changed in the version message, and
// only one version message is allowed per session. We can therefore
// treat this value as const and even atomic as long as it's only used
// once the handshake is complete. Any attempt to set this twice is an
// error.
assert(nSendVersion == 0);
nSendVersion = nVersionIn;
}
int GetSendVersion() const
{
// The send version should always be explicitly set to
// INIT_PROTO_VERSION rather than using this value until the handshake
// is complete.
assert(nSendVersion != 0);
return nSendVersion;
}
CNode* AddRef() CNode* AddRef()
{ {

75
src/net_processing.cpp

@ -1199,50 +1199,51 @@ bool static ProcessMessage(CNode* pfrom, std::string strCommand, CDataStream& vR
CAddress addrFrom; CAddress addrFrom;
uint64_t nNonce = 1; uint64_t nNonce = 1;
uint64_t nServiceInt; uint64_t nServiceInt;
vRecv >> pfrom->nVersion >> nServiceInt >> nTime >> addrMe; ServiceFlags nServices;
pfrom->nServices = ServiceFlags(nServiceInt); int nVersion;
int nSendVersion;
std::string strSubVer;
int nStartingHeight = -1;
bool fRelay = true;
vRecv >> nVersion >> nServiceInt >> nTime >> addrMe;
nSendVersion = std::min(nVersion, PROTOCOL_VERSION);
nServices = ServiceFlags(nServiceInt);
if (!pfrom->fInbound) if (!pfrom->fInbound)
{ {
connman.SetServices(pfrom->addr, pfrom->nServices); connman.SetServices(pfrom->addr, nServices);
} }
if (pfrom->nServicesExpected & ~pfrom->nServices) if (pfrom->nServicesExpected & ~nServices)
{ {
LogPrint("net", "peer=%d does not offer the expected services (%08x offered, %08x expected); disconnecting\n", pfrom->id, pfrom->nServices, pfrom->nServicesExpected); LogPrint("net", "peer=%d does not offer the expected services (%08x offered, %08x expected); disconnecting\n", pfrom->id, nServices, pfrom->nServicesExpected);
connman.PushMessage(pfrom, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::REJECT, strCommand, REJECT_NONSTANDARD, connman.PushMessage(pfrom, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::REJECT, strCommand, REJECT_NONSTANDARD,
strprintf("Expected to offer services %08x", pfrom->nServicesExpected))); strprintf("Expected to offer services %08x", pfrom->nServicesExpected)));
pfrom->fDisconnect = true; pfrom->fDisconnect = true;
return false; return false;
} }
if (pfrom->nVersion < MIN_PEER_PROTO_VERSION) if (nVersion < MIN_PEER_PROTO_VERSION)
{ {
// disconnect from peers older than this proto version // disconnect from peers older than this proto version
LogPrintf("peer=%d using obsolete version %i; disconnecting\n", pfrom->id, pfrom->nVersion); LogPrintf("peer=%d using obsolete version %i; disconnecting\n", pfrom->id, nVersion);
connman.PushMessage(pfrom, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::REJECT, strCommand, REJECT_OBSOLETE, connman.PushMessage(pfrom, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::REJECT, strCommand, REJECT_OBSOLETE,
strprintf("Version must be %d or greater", MIN_PEER_PROTO_VERSION))); strprintf("Version must be %d or greater", MIN_PEER_PROTO_VERSION)));
pfrom->fDisconnect = true; pfrom->fDisconnect = true;
return false; return false;
} }
if (pfrom->nVersion == 10300) if (nVersion == 10300)
pfrom->nVersion = 300; nVersion = 300;
if (!vRecv.empty()) if (!vRecv.empty())
vRecv >> addrFrom >> nNonce; vRecv >> addrFrom >> nNonce;
if (!vRecv.empty()) { if (!vRecv.empty()) {
vRecv >> LIMITED_STRING(pfrom->strSubVer, MAX_SUBVERSION_LENGTH); vRecv >> LIMITED_STRING(strSubVer, MAX_SUBVERSION_LENGTH);
pfrom->cleanSubVer = SanitizeString(pfrom->strSubVer);
} }
if (!vRecv.empty()) { if (!vRecv.empty()) {
vRecv >> pfrom->nStartingHeight; vRecv >> nStartingHeight;
} }
{ if (!vRecv.empty())
LOCK(pfrom->cs_filter); vRecv >> fRelay;
if (!vRecv.empty())
vRecv >> pfrom->fRelayTxes; // set to true after we get the first filter* message
else
pfrom->fRelayTxes = true;
}
// Disconnect if we connected to ourself // Disconnect if we connected to ourself
if (pfrom->fInbound && !connman.CheckIncomingNonce(nNonce)) if (pfrom->fInbound && !connman.CheckIncomingNonce(nNonce))
{ {
@ -1251,7 +1252,6 @@ bool static ProcessMessage(CNode* pfrom, std::string strCommand, CDataStream& vR
return true; return true;
} }
pfrom->addrLocal = addrMe;
if (pfrom->fInbound && addrMe.IsRoutable()) if (pfrom->fInbound && addrMe.IsRoutable())
{ {
SeenLocal(addrMe); SeenLocal(addrMe);
@ -1261,9 +1261,24 @@ bool static ProcessMessage(CNode* pfrom, std::string strCommand, CDataStream& vR
if (pfrom->fInbound) if (pfrom->fInbound)
PushNodeVersion(pfrom, connman, GetAdjustedTime()); PushNodeVersion(pfrom, connman, GetAdjustedTime());
pfrom->fClient = !(pfrom->nServices & NODE_NETWORK); connman.PushMessage(pfrom, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::VERACK));
if((pfrom->nServices & NODE_WITNESS)) pfrom->nServices = nServices;
pfrom->addrLocal = addrMe;
pfrom->strSubVer = strSubVer;
pfrom->cleanSubVer = SanitizeString(strSubVer);
pfrom->nStartingHeight = nStartingHeight;
pfrom->fClient = !(nServices & NODE_NETWORK);
{
LOCK(pfrom->cs_filter);
pfrom->fRelayTxes = fRelay; // set to true after we get the first filter* message
}
// Change version
pfrom->SetSendVersion(nSendVersion);
pfrom->nVersion = nVersion;
if((nServices & NODE_WITNESS))
{ {
LOCK(cs_main); LOCK(cs_main);
State(pfrom->GetId())->fHaveWitness = true; State(pfrom->GetId())->fHaveWitness = true;
@ -1275,11 +1290,6 @@ bool static ProcessMessage(CNode* pfrom, std::string strCommand, CDataStream& vR
UpdatePreferredDownload(pfrom, State(pfrom->GetId())); UpdatePreferredDownload(pfrom, State(pfrom->GetId()));
} }
// Change version
connman.PushMessage(pfrom, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::VERACK));
int nSendVersion = std::min(pfrom->nVersion, PROTOCOL_VERSION);
pfrom->SetSendVersion(nSendVersion);
if (!pfrom->fInbound) if (!pfrom->fInbound)
{ {
// Advertise our address // Advertise our address
@ -1307,8 +1317,6 @@ bool static ProcessMessage(CNode* pfrom, std::string strCommand, CDataStream& vR
connman.MarkAddressGood(pfrom->addr); connman.MarkAddressGood(pfrom->addr);
} }
pfrom->fSuccessfullyConnected = true;
std::string remoteAddr; std::string remoteAddr;
if (fLogIPs) if (fLogIPs)
remoteAddr = ", peeraddr=" + pfrom->addr.ToString(); remoteAddr = ", peeraddr=" + pfrom->addr.ToString();
@ -1350,7 +1358,7 @@ bool static ProcessMessage(CNode* pfrom, std::string strCommand, CDataStream& vR
if (strCommand == NetMsgType::VERACK) if (strCommand == NetMsgType::VERACK)
{ {
pfrom->SetRecvVersion(std::min(pfrom->nVersion, PROTOCOL_VERSION)); pfrom->SetRecvVersion(std::min(pfrom->nVersion.load(), PROTOCOL_VERSION));
if (!pfrom->fInbound) { if (!pfrom->fInbound) {
// Mark this node as currently connected, so we update its timestamp later. // Mark this node as currently connected, so we update its timestamp later.
@ -1378,6 +1386,7 @@ bool static ProcessMessage(CNode* pfrom, std::string strCommand, CDataStream& vR
nCMPCTBLOCKVersion = 1; nCMPCTBLOCKVersion = 1;
connman.PushMessage(pfrom, msgMaker.Make(NetMsgType::SENDCMPCT, fAnnounceUsingCMPCTBLOCK, nCMPCTBLOCKVersion)); connman.PushMessage(pfrom, msgMaker.Make(NetMsgType::SENDCMPCT, fAnnounceUsingCMPCTBLOCK, nCMPCTBLOCKVersion));
} }
pfrom->fSuccessfullyConnected = true;
} }
@ -2716,8 +2725,8 @@ bool SendMessages(CNode* pto, CConnman& connman, std::atomic<bool>& interruptMsg
{ {
const Consensus::Params& consensusParams = Params().GetConsensus(); const Consensus::Params& consensusParams = Params().GetConsensus();
{ {
// Don't send anything until we get its version message // Don't send anything until the version handshake is complete
if (pto->nVersion == 0 || pto->fDisconnect) if (!pto->fSuccessfullyConnected || pto->fDisconnect)
return true; return true;
// If we get here, the outgoing message serialization version is set and can't change. // If we get here, the outgoing message serialization version is set and can't change.

4
src/test/DoS_tests.cpp

@ -55,6 +55,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning)
dummyNode1.SetSendVersion(PROTOCOL_VERSION); dummyNode1.SetSendVersion(PROTOCOL_VERSION);
GetNodeSignals().InitializeNode(&dummyNode1, *connman); GetNodeSignals().InitializeNode(&dummyNode1, *connman);
dummyNode1.nVersion = 1; dummyNode1.nVersion = 1;
dummyNode1.fSuccessfullyConnected = true;
Misbehaving(dummyNode1.GetId(), 100); // Should get banned Misbehaving(dummyNode1.GetId(), 100); // Should get banned
SendMessages(&dummyNode1, *connman, interruptDummy); SendMessages(&dummyNode1, *connman, interruptDummy);
BOOST_CHECK(connman->IsBanned(addr1)); BOOST_CHECK(connman->IsBanned(addr1));
@ -65,6 +66,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning)
dummyNode2.SetSendVersion(PROTOCOL_VERSION); dummyNode2.SetSendVersion(PROTOCOL_VERSION);
GetNodeSignals().InitializeNode(&dummyNode2, *connman); GetNodeSignals().InitializeNode(&dummyNode2, *connman);
dummyNode2.nVersion = 1; dummyNode2.nVersion = 1;
dummyNode2.fSuccessfullyConnected = true;
Misbehaving(dummyNode2.GetId(), 50); Misbehaving(dummyNode2.GetId(), 50);
SendMessages(&dummyNode2, *connman, interruptDummy); SendMessages(&dummyNode2, *connman, interruptDummy);
BOOST_CHECK(!connman->IsBanned(addr2)); // 2 not banned yet... BOOST_CHECK(!connman->IsBanned(addr2)); // 2 not banned yet...
@ -85,6 +87,7 @@ BOOST_AUTO_TEST_CASE(DoS_banscore)
dummyNode1.SetSendVersion(PROTOCOL_VERSION); dummyNode1.SetSendVersion(PROTOCOL_VERSION);
GetNodeSignals().InitializeNode(&dummyNode1, *connman); GetNodeSignals().InitializeNode(&dummyNode1, *connman);
dummyNode1.nVersion = 1; dummyNode1.nVersion = 1;
dummyNode1.fSuccessfullyConnected = true;
Misbehaving(dummyNode1.GetId(), 100); Misbehaving(dummyNode1.GetId(), 100);
SendMessages(&dummyNode1, *connman, interruptDummy); SendMessages(&dummyNode1, *connman, interruptDummy);
BOOST_CHECK(!connman->IsBanned(addr1)); BOOST_CHECK(!connman->IsBanned(addr1));
@ -110,6 +113,7 @@ BOOST_AUTO_TEST_CASE(DoS_bantime)
dummyNode.SetSendVersion(PROTOCOL_VERSION); dummyNode.SetSendVersion(PROTOCOL_VERSION);
GetNodeSignals().InitializeNode(&dummyNode, *connman); GetNodeSignals().InitializeNode(&dummyNode, *connman);
dummyNode.nVersion = 1; dummyNode.nVersion = 1;
dummyNode.fSuccessfullyConnected = true;
Misbehaving(dummyNode.GetId(), 100); Misbehaving(dummyNode.GetId(), 100);
SendMessages(&dummyNode, *connman, interruptDummy); SendMessages(&dummyNode, *connman, interruptDummy);

Loading…
Cancel
Save