diff --git a/src/init.cpp b/src/init.cpp index dbc2c413..158418d7 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -179,6 +179,8 @@ bool AppInit2(int argc, char* argv[]) " -addnode= \t " + _("Add a node to connect to\n") + " -connect= \t\t " + _("Connect only to the specified node\n") + " -nolisten \t " + _("Don't accept connections from outside\n") + + " -banscore= \t " + _("Threshold for disconnecting misbehaving peers (default: 100)\n") + + " -bantime= \t " + _("Number of seconds to keep misbehaving peers from reconnecting (default: 86400)\n") + #ifdef USE_UPNP #if USE_UPNP " -noupnp \t " + _("Don't attempt to use UPnP to map the listening port\n") + diff --git a/src/net.cpp b/src/net.cpp index 2e257a6e..1792bf78 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -726,6 +726,52 @@ void CNode::Cleanup() } +std::map CNode::setBanned; +CCriticalSection CNode::cs_setBanned; + +void CNode::ClearBanned() +{ + setBanned.clear(); +} + +bool CNode::IsBanned(unsigned int ip) +{ + bool fResult = false; + CRITICAL_BLOCK(cs_setBanned) + { + std::map::iterator i = setBanned.find(ip); + if (i != setBanned.end()) + { + int64 t = (*i).second; + if (GetTime() < t) + fResult = true; + } + } + return fResult; +} + +bool CNode::Misbehaving(int howmuch) +{ + if (addr.IsLocal()) + { + printf("Warning: local node %s misbehaving\n", addr.ToString().c_str()); + return false; + } + + nMisbehavior += howmuch; + if (nMisbehavior >= GetArg("-banscore", 100)) + { + int64 banTime = GetTime()+GetArg("-bantime", 60*60*24); // Default 24-hour ban + CRITICAL_BLOCK(cs_setBanned) + if (setBanned[addr.ip] < banTime) + setBanned[addr.ip] = banTime; + CloseSocketDisconnect(); + printf("Disconnected %s for misbehavior (score=%d)\n", addr.ToString().c_str(), nMisbehavior); + return true; + } + return false; +} + @@ -896,6 +942,11 @@ void ThreadSocketHandler2(void* parg) { closesocket(hSocket); } + else if (CNode::IsBanned(addr.ip)) + { + printf("connetion from %s dropped (banned)\n", addr.ToString().c_str()); + closesocket(hSocket); + } else { printf("accepted connection %s\n", addr.ToString().c_str()); @@ -1454,7 +1505,8 @@ bool OpenNetworkConnection(const CAddress& addrConnect) // if (fShutdown) return false; - if (addrConnect.ip == addrLocalHost.ip || !addrConnect.IsIPv4() || FindNode(addrConnect.ip)) + if (addrConnect.ip == addrLocalHost.ip || !addrConnect.IsIPv4() || + FindNode(addrConnect.ip) || CNode::IsBanned(addrConnect.ip)) return false; vnThreadsRunning[1]--; diff --git a/src/net.h b/src/net.h index 0026e402..5b3568fc 100644 --- a/src/net.h +++ b/src/net.h @@ -124,6 +124,13 @@ public: bool fDisconnect; protected: int nRefCount; + + // Denial-of-service detection/prevention + // Key is ip address, value is banned-until-time + static std::map setBanned; + static CCriticalSection cs_setBanned; + int nMisbehavior; + public: int64 nReleaseTime; std::map mapRequests; @@ -148,7 +155,6 @@ public: // publish and subscription std::vector vfSubscribe; - CNode(SOCKET hSocketIn, CAddress addrIn, bool fInboundIn=false) { nServices = 0; @@ -185,6 +191,7 @@ public: nStartingHeight = -1; fGetAddr = false; vfSubscribe.assign(256, false); + nMisbehavior = 0; // Be shy and don't send version until we hear if (!fInbound) @@ -568,6 +575,25 @@ public: void CancelSubscribe(unsigned int nChannel); void CloseSocketDisconnect(); void Cleanup(); + + + // Denial-of-service detection/prevention + // The idea is to detect peers that are behaving + // badly and disconnect/ban them, but do it in a + // one-coding-mistake-won't-shatter-the-entire-network + // way. + // IMPORTANT: There should be nothing I can give a + // node that it will forward on that will make that + // node's peers drop it. If there is, an attacker + // can isolate a node and/or try to split the network. + // Dropping a node for sending stuff that is invalid + // now but might be valid in a later version is also + // dangerous, because it can cause a network split + // between nodes running old code and nodes running + // new code. + static void ClearBanned(); // needed for unit testing + static bool IsBanned(unsigned int ip); + bool Misbehaving(int howmuch); // 1 == a little, 100 == a lot }; diff --git a/src/test/DoS_tests.cpp b/src/test/DoS_tests.cpp new file mode 100644 index 00000000..e60bb742 --- /dev/null +++ b/src/test/DoS_tests.cpp @@ -0,0 +1,68 @@ +// +// Unit tests for denial-of-service detection/prevention code +// +#include +#include + +#include "../main.h" +#include "../net.h" +#include "../util.h" + +using namespace std; + +BOOST_AUTO_TEST_SUITE(DoS_tests) + +BOOST_AUTO_TEST_CASE(DoS_banning) +{ + CNode::ClearBanned(); + CAddress addr1(0xa0b0c001); + CNode dummyNode1(INVALID_SOCKET, addr1, true); + dummyNode1.Misbehaving(100); // Should get banned + BOOST_CHECK(CNode::IsBanned(addr1.ip)); + BOOST_CHECK(!CNode::IsBanned(addr1.ip|0x0000ff00)); // Different ip, not banned + + CAddress addr2(0xa0b0c002); + CNode dummyNode2(INVALID_SOCKET, addr2, true); + dummyNode2.Misbehaving(50); + BOOST_CHECK(!CNode::IsBanned(addr2.ip)); // 2 not banned yet... + BOOST_CHECK(CNode::IsBanned(addr1.ip)); // ... but 1 still should be + dummyNode2.Misbehaving(50); + BOOST_CHECK(CNode::IsBanned(addr2.ip)); +} + +BOOST_AUTO_TEST_CASE(DoS_banscore) +{ + CNode::ClearBanned(); + mapArgs["-banscore"] = "111"; // because 11 is my favorite number + CAddress addr1(0xa0b0c001); + CNode dummyNode1(INVALID_SOCKET, addr1, true); + dummyNode1.Misbehaving(100); + BOOST_CHECK(!CNode::IsBanned(addr1.ip)); + dummyNode1.Misbehaving(10); + BOOST_CHECK(!CNode::IsBanned(addr1.ip)); + dummyNode1.Misbehaving(1); + BOOST_CHECK(CNode::IsBanned(addr1.ip)); + mapArgs["-banscore"] = "100"; +} + +BOOST_AUTO_TEST_CASE(DoS_bantime) +{ + CNode::ClearBanned(); + int64 nStartTime = GetTime(); + SetMockTime(nStartTime); // Overrides future calls to GetTime() + + CAddress addr(0xa0b0c001); + CNode dummyNode(INVALID_SOCKET, addr, true); + + dummyNode.Misbehaving(100); + BOOST_CHECK(CNode::IsBanned(addr.ip)); + + SetMockTime(nStartTime+60*60); + BOOST_CHECK(CNode::IsBanned(addr.ip)); + + SetMockTime(nStartTime+60*60*24+1); + BOOST_CHECK(!CNode::IsBanned(addr.ip)); +} + + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/test/test_bitcoin.cpp b/src/test/test_bitcoin.cpp index 0230bb6e..c6f6d94b 100644 --- a/src/test/test_bitcoin.cpp +++ b/src/test/test_bitcoin.cpp @@ -8,7 +8,7 @@ #include "uint256_tests.cpp" #include "script_tests.cpp" #include "transaction_tests.cpp" - +#include "DoS_tests.cpp" CWallet* pwalletMain;