diff --git a/src/init.cpp b/src/init.cpp index d2045fd84..a1835c903 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -179,6 +179,8 @@ bool AppInit2(int argc, char* argv[]) " -addnode= \t " + _("Add a node to connect to\n") + " -connect= \t\t " + _("Connect only to the specified node\n") + " -nolisten \t " + _("Don't accept connections from outside\n") + + " -banscore= \t " + _("Threshold for disconnecting misbehaving peers (default: 100)\n") + + " -bantime= \t " + _("Number of seconds to keep misbehaving peers from reconnecting (default: 86400)\n") + #ifdef USE_UPNP #if USE_UPNP " -noupnp \t " + _("Don't attempt to use UPnP to map the listening port\n") + diff --git a/src/main.cpp b/src/main.cpp index e732ddcf5..434c8e848 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -297,24 +297,24 @@ bool CTransaction::CheckTransaction() const { // Basic checks that don't depend on any context if (vin.empty()) - return error("CTransaction::CheckTransaction() : vin empty"); + return DoS(10, error("CTransaction::CheckTransaction() : vin empty")); if (vout.empty()) - return error("CTransaction::CheckTransaction() : vout empty"); + return DoS(10, error("CTransaction::CheckTransaction() : vout empty")); // Size limits if (::GetSerializeSize(*this, SER_NETWORK) > MAX_BLOCK_SIZE) - return error("CTransaction::CheckTransaction() : size limits failed"); + return DoS(100, error("CTransaction::CheckTransaction() : size limits failed")); // Check for negative or overflow output values int64 nValueOut = 0; BOOST_FOREACH(const CTxOut& txout, vout) { if (txout.nValue < 0) - return error("CTransaction::CheckTransaction() : txout.nValue negative"); + return DoS(100, error("CTransaction::CheckTransaction() : txout.nValue negative")); if (txout.nValue > MAX_MONEY) - return error("CTransaction::CheckTransaction() : txout.nValue too high"); + return DoS(100, error("CTransaction::CheckTransaction() : txout.nValue too high")); nValueOut += txout.nValue; if (!MoneyRange(nValueOut)) - return error("CTransaction::CheckTransaction() : txout total out of range"); + return DoS(100, error("CTransaction::CheckTransaction() : txout total out of range")); } // Check for duplicate inputs @@ -329,13 +329,13 @@ bool CTransaction::CheckTransaction() const if (IsCoinBase()) { if (vin[0].scriptSig.size() < 2 || vin[0].scriptSig.size() > 100) - return error("CTransaction::CheckTransaction() : coinbase script size"); + return DoS(100, error("CTransaction::CheckTransaction() : coinbase script size")); } else { BOOST_FOREACH(const CTxIn& txin, vin) if (txin.prevout.IsNull()) - return error("CTransaction::CheckTransaction() : prevout is null"); + return DoS(10, error("CTransaction::CheckTransaction() : prevout is null")); } return true; @@ -351,7 +351,7 @@ bool CTransaction::AcceptToMemoryPool(CTxDB& txdb, bool fCheckInputs, bool* pfMi // Coinbase is only valid in a block, not as a loose transaction if (IsCoinBase()) - return error("AcceptToMemoryPool() : coinbase as individual tx"); + return DoS(100, error("AcceptToMemoryPool() : coinbase as individual tx")); // To help v0.1.5 clients who would see it as a negative number if ((int64)nLockTime > INT_MAX) @@ -364,7 +364,7 @@ bool CTransaction::AcceptToMemoryPool(CTxDB& txdb, bool fCheckInputs, bool* pfMi // 34 bytes because a TxOut is: // 20-byte address + 8 byte bitcoin amount + 5 bytes of ops + 1 byte script length if (GetSigOpCount() > nSize / 34 || nSize < 100) - return error("AcceptToMemoryPool() : nonstandard transaction"); + return DoS(10, error("AcceptToMemoryPool() : transaction with out-of-bounds SigOpCount")); // Rather not work on nonstandard transactions (unless -testnet) if (!fTestNet && !IsStandard()) @@ -855,26 +855,28 @@ bool CTransaction::ConnectInputs(CTxDB& txdb, map& mapTestPoo } if (prevout.n >= txPrev.vout.size() || prevout.n >= txindex.vSpent.size()) - return error("ConnectInputs() : %s prevout.n out of range %d %d %d prev tx %s\n%s", GetHash().ToString().substr(0,10).c_str(), prevout.n, txPrev.vout.size(), txindex.vSpent.size(), prevout.hash.ToString().substr(0,10).c_str(), txPrev.ToString().c_str()); + return DoS(100, error("ConnectInputs() : %s prevout.n out of range %d %d %d prev tx %s\n%s", GetHash().ToString().substr(0,10).c_str(), prevout.n, txPrev.vout.size(), txindex.vSpent.size(), prevout.hash.ToString().substr(0,10).c_str(), txPrev.ToString().c_str())); // If prev is coinbase, check that it's matured if (txPrev.IsCoinBase()) for (CBlockIndex* pindex = pindexBlock; pindex && pindexBlock->nHeight - pindex->nHeight < COINBASE_MATURITY; pindex = pindex->pprev) if (pindex->nBlockPos == txindex.pos.nBlockPos && pindex->nFile == txindex.pos.nFile) - return error("ConnectInputs() : tried to spend coinbase at depth %d", pindexBlock->nHeight - pindex->nHeight); + return DoS(10, error("ConnectInputs() : tried to spend coinbase at depth %d", pindexBlock->nHeight - pindex->nHeight)); // Verify signature if (!VerifySignature(txPrev, *this, i)) - return error("ConnectInputs() : %s VerifySignature failed", GetHash().ToString().substr(0,10).c_str()); + return DoS(100,error("ConnectInputs() : %s VerifySignature failed", GetHash().ToString().substr(0,10).c_str())); - // Check for conflicts + // Check for conflicts (double-spend) + // This doesn't trigger the DoS code on purpose; if it did, it would make it easier + // for an attacker to attempt to split the network. if (!txindex.vSpent[prevout.n].IsNull()) return fMiner ? false : error("ConnectInputs() : %s prev tx already used at %s", GetHash().ToString().substr(0,10).c_str(), txindex.vSpent[prevout.n].ToString().c_str()); // Check for negative or overflow input values nValueIn += txPrev.vout[prevout.n].nValue; if (!MoneyRange(txPrev.vout[prevout.n].nValue) || !MoneyRange(nValueIn)) - return error("ConnectInputs() : txin values out of range"); + return DoS(100, error("ConnectInputs() : txin values out of range")); // Mark outpoints as spent txindex.vSpent[prevout.n] = posThisTx; @@ -887,17 +889,17 @@ bool CTransaction::ConnectInputs(CTxDB& txdb, map& mapTestPoo } if (nValueIn < GetValueOut()) - return error("ConnectInputs() : %s value in < value out", GetHash().ToString().substr(0,10).c_str()); + return DoS(100, error("ConnectInputs() : %s value in < value out", GetHash().ToString().substr(0,10).c_str())); // Tally transaction fees int64 nTxFee = nValueIn - GetValueOut(); if (nTxFee < 0) - return error("ConnectInputs() : %s nTxFee < 0", GetHash().ToString().substr(0,10).c_str()); + return DoS(100, error("ConnectInputs() : %s nTxFee < 0", GetHash().ToString().substr(0,10).c_str())); if (nTxFee < nMinFee) return false; nFees += nTxFee; if (!MoneyRange(nFees)) - return error("ConnectInputs() : nFees out of range"); + return DoS(100, error("ConnectInputs() : nFees out of range")); } if (fBlock) @@ -1240,11 +1242,11 @@ bool CBlock::CheckBlock() const // Size limits if (vtx.empty() || vtx.size() > MAX_BLOCK_SIZE || ::GetSerializeSize(*this, SER_NETWORK) > MAX_BLOCK_SIZE) - return error("CheckBlock() : size limits failed"); + return DoS(100, error("CheckBlock() : size limits failed")); // Check proof of work matches claimed amount if (!CheckProofOfWork(GetHash(), nBits)) - return error("CheckBlock() : proof of work failed"); + return DoS(50, error("CheckBlock() : proof of work failed")); // Check timestamp if (GetBlockTime() > GetAdjustedTime() + 2 * 60 * 60) @@ -1252,23 +1254,23 @@ bool CBlock::CheckBlock() const // First transaction must be coinbase, the rest must not be if (vtx.empty() || !vtx[0].IsCoinBase()) - return error("CheckBlock() : first tx is not coinbase"); + return DoS(100, error("CheckBlock() : first tx is not coinbase")); for (int i = 1; i < vtx.size(); i++) if (vtx[i].IsCoinBase()) - return error("CheckBlock() : more than one coinbase"); + return DoS(100, error("CheckBlock() : more than one coinbase")); // Check transactions BOOST_FOREACH(const CTransaction& tx, vtx) if (!tx.CheckTransaction()) - return error("CheckBlock() : CheckTransaction failed"); + return DoS(tx.nDoS, error("CheckBlock() : CheckTransaction failed")); // Check that it's not full of nonstandard transactions if (GetSigOpCount() > MAX_BLOCK_SIGOPS) - return error("CheckBlock() : too many nonstandard transactions"); + return DoS(100, error("CheckBlock() : out-of-bounds SigOpCount")); // Check merkleroot if (hashMerkleRoot != BuildMerkleTree()) - return error("CheckBlock() : hashMerkleRoot mismatch"); + return DoS(100, error("CheckBlock() : hashMerkleRoot mismatch")); return true; } @@ -1283,13 +1285,13 @@ bool CBlock::AcceptBlock() // Get prev block index map::iterator mi = mapBlockIndex.find(hashPrevBlock); if (mi == mapBlockIndex.end()) - return error("AcceptBlock() : prev block not found"); + return DoS(10, error("AcceptBlock() : prev block not found")); CBlockIndex* pindexPrev = (*mi).second; int nHeight = pindexPrev->nHeight+1; // Check proof of work if (nBits != GetNextWorkRequired(pindexPrev)) - return error("AcceptBlock() : incorrect proof of work"); + return DoS(100, error("AcceptBlock() : incorrect proof of work")); // Check timestamp against prev if (GetBlockTime() <= pindexPrev->GetMedianTimePast()) @@ -1298,7 +1300,7 @@ bool CBlock::AcceptBlock() // Check that all transactions are finalized BOOST_FOREACH(const CTransaction& tx, vtx) if (!tx.IsFinal(nHeight, GetBlockTime())) - return error("AcceptBlock() : contains a non-final transaction"); + return DoS(10, error("AcceptBlock() : contains a non-final transaction")); // Check that the block chain matches the known block chain up to a checkpoint if (!fTestNet) @@ -1311,7 +1313,7 @@ bool CBlock::AcceptBlock() (nHeight == 118000 && hash != uint256("0x000000000000774a7f8a7a12dc906ddb9e17e75d684f15e00f8767f9e8f36553")) || (nHeight == 134444 && hash != uint256("0x00000000000005b12ffd4cd315cd34ffd4a594f430ac814c91184a0d42d2b0fe")) || (nHeight == 140700 && hash != uint256("0x000000000000033b512028abb90e1626d8b346fd0ed598ac0a3c371138dce2bd"))) - return error("AcceptBlock() : rejected by checkpoint lockin at %d", nHeight); + return DoS(100, error("AcceptBlock() : rejected by checkpoint lockin at %d", nHeight)); // Write block to history file if (!CheckDiskSpace(::GetSerializeSize(*this, SER_DISK))) @@ -1769,7 +1771,10 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) { // Each connection can only send one version message if (pfrom->nVersion != 0) + { + pfrom->Misbehaving(1); return false; + } int64 nTime; CAddress addrMe; @@ -1857,6 +1862,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) else if (pfrom->nVersion == 0) { // Must have a version message before anything else + pfrom->Misbehaving(1); return false; } @@ -1878,7 +1884,10 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) if (pfrom->nVersion < 31402 && mapAddresses.size() > 1000) return true; if (vAddr.size() > 1000) + { + pfrom->Misbehaving(20); return error("message addr size() = %d", vAddr.size()); + } // Store the new addresses CAddrDB addrDB; @@ -1936,7 +1945,10 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) vector vInv; vRecv >> vInv; if (vInv.size() > 50000) + { + pfrom->Misbehaving(20); return error("message inv size() = %d", vInv.size()); + } CTxDB txdb("r"); BOOST_FOREACH(const CInv& inv, vInv) @@ -1965,7 +1977,10 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) vector vInv; vRecv >> vInv; if (vInv.size() > 50000) + { + pfrom->Misbehaving(20); return error("message getdata size() = %d", vInv.size()); + } BOOST_FOREACH(const CInv& inv, vInv) { @@ -2137,6 +2152,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) printf("storing orphan tx %s\n", inv.hash.ToString().substr(0,10).c_str()); AddOrphanTx(vMsg); } + if (tx.nDoS) pfrom->Misbehaving(tx.nDoS); } @@ -2153,6 +2169,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) if (ProcessBlock(pfrom, &block)) mapAlreadyAskedFor.erase(inv); + if (block.nDoS) pfrom->Misbehaving(block.nDoS); } diff --git a/src/main.h b/src/main.h index a8deb2b92..1106bb978 100644 --- a/src/main.h +++ b/src/main.h @@ -400,6 +400,9 @@ public: std::vector vout; unsigned int nLockTime; + // Denial-of-service detection: + mutable int nDoS; + bool DoS(int nDoSIn, bool fIn) const { nDoS += nDoSIn; return fIn; } CTransaction() { @@ -421,6 +424,7 @@ public: vin.clear(); vout.clear(); nLockTime = 0; + nDoS = 0; // Denial-of-service prevention } bool IsNull() const @@ -787,6 +791,9 @@ public: // memory only mutable std::vector vMerkleTree; + // Denial-of-service detection: + mutable int nDoS; + bool DoS(int nDoSIn, bool fIn) const { nDoS += nDoSIn; return fIn; } CBlock() { @@ -820,6 +827,7 @@ public: nNonce = 0; vtx.clear(); vMerkleTree.clear(); + nDoS = 0; } bool IsNull() const diff --git a/src/net.cpp b/src/net.cpp index 2e257a6ef..1792bf78a 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -726,6 +726,52 @@ void CNode::Cleanup() } +std::map CNode::setBanned; +CCriticalSection CNode::cs_setBanned; + +void CNode::ClearBanned() +{ + setBanned.clear(); +} + +bool CNode::IsBanned(unsigned int ip) +{ + bool fResult = false; + CRITICAL_BLOCK(cs_setBanned) + { + std::map::iterator i = setBanned.find(ip); + if (i != setBanned.end()) + { + int64 t = (*i).second; + if (GetTime() < t) + fResult = true; + } + } + return fResult; +} + +bool CNode::Misbehaving(int howmuch) +{ + if (addr.IsLocal()) + { + printf("Warning: local node %s misbehaving\n", addr.ToString().c_str()); + return false; + } + + nMisbehavior += howmuch; + if (nMisbehavior >= GetArg("-banscore", 100)) + { + int64 banTime = GetTime()+GetArg("-bantime", 60*60*24); // Default 24-hour ban + CRITICAL_BLOCK(cs_setBanned) + if (setBanned[addr.ip] < banTime) + setBanned[addr.ip] = banTime; + CloseSocketDisconnect(); + printf("Disconnected %s for misbehavior (score=%d)\n", addr.ToString().c_str(), nMisbehavior); + return true; + } + return false; +} + @@ -896,6 +942,11 @@ void ThreadSocketHandler2(void* parg) { closesocket(hSocket); } + else if (CNode::IsBanned(addr.ip)) + { + printf("connetion from %s dropped (banned)\n", addr.ToString().c_str()); + closesocket(hSocket); + } else { printf("accepted connection %s\n", addr.ToString().c_str()); @@ -1454,7 +1505,8 @@ bool OpenNetworkConnection(const CAddress& addrConnect) // if (fShutdown) return false; - if (addrConnect.ip == addrLocalHost.ip || !addrConnect.IsIPv4() || FindNode(addrConnect.ip)) + if (addrConnect.ip == addrLocalHost.ip || !addrConnect.IsIPv4() || + FindNode(addrConnect.ip) || CNode::IsBanned(addrConnect.ip)) return false; vnThreadsRunning[1]--; diff --git a/src/net.h b/src/net.h index 0026e402c..5b3568fca 100644 --- a/src/net.h +++ b/src/net.h @@ -124,6 +124,13 @@ public: bool fDisconnect; protected: int nRefCount; + + // Denial-of-service detection/prevention + // Key is ip address, value is banned-until-time + static std::map setBanned; + static CCriticalSection cs_setBanned; + int nMisbehavior; + public: int64 nReleaseTime; std::map mapRequests; @@ -148,7 +155,6 @@ public: // publish and subscription std::vector vfSubscribe; - CNode(SOCKET hSocketIn, CAddress addrIn, bool fInboundIn=false) { nServices = 0; @@ -185,6 +191,7 @@ public: nStartingHeight = -1; fGetAddr = false; vfSubscribe.assign(256, false); + nMisbehavior = 0; // Be shy and don't send version until we hear if (!fInbound) @@ -568,6 +575,25 @@ public: void CancelSubscribe(unsigned int nChannel); void CloseSocketDisconnect(); void Cleanup(); + + + // Denial-of-service detection/prevention + // The idea is to detect peers that are behaving + // badly and disconnect/ban them, but do it in a + // one-coding-mistake-won't-shatter-the-entire-network + // way. + // IMPORTANT: There should be nothing I can give a + // node that it will forward on that will make that + // node's peers drop it. If there is, an attacker + // can isolate a node and/or try to split the network. + // Dropping a node for sending stuff that is invalid + // now but might be valid in a later version is also + // dangerous, because it can cause a network split + // between nodes running old code and nodes running + // new code. + static void ClearBanned(); // needed for unit testing + static bool IsBanned(unsigned int ip); + bool Misbehaving(int howmuch); // 1 == a little, 100 == a lot }; diff --git a/src/test/DoS_tests.cpp b/src/test/DoS_tests.cpp new file mode 100644 index 000000000..e60bb742d --- /dev/null +++ b/src/test/DoS_tests.cpp @@ -0,0 +1,68 @@ +// +// Unit tests for denial-of-service detection/prevention code +// +#include +#include + +#include "../main.h" +#include "../net.h" +#include "../util.h" + +using namespace std; + +BOOST_AUTO_TEST_SUITE(DoS_tests) + +BOOST_AUTO_TEST_CASE(DoS_banning) +{ + CNode::ClearBanned(); + CAddress addr1(0xa0b0c001); + CNode dummyNode1(INVALID_SOCKET, addr1, true); + dummyNode1.Misbehaving(100); // Should get banned + BOOST_CHECK(CNode::IsBanned(addr1.ip)); + BOOST_CHECK(!CNode::IsBanned(addr1.ip|0x0000ff00)); // Different ip, not banned + + CAddress addr2(0xa0b0c002); + CNode dummyNode2(INVALID_SOCKET, addr2, true); + dummyNode2.Misbehaving(50); + BOOST_CHECK(!CNode::IsBanned(addr2.ip)); // 2 not banned yet... + BOOST_CHECK(CNode::IsBanned(addr1.ip)); // ... but 1 still should be + dummyNode2.Misbehaving(50); + BOOST_CHECK(CNode::IsBanned(addr2.ip)); +} + +BOOST_AUTO_TEST_CASE(DoS_banscore) +{ + CNode::ClearBanned(); + mapArgs["-banscore"] = "111"; // because 11 is my favorite number + CAddress addr1(0xa0b0c001); + CNode dummyNode1(INVALID_SOCKET, addr1, true); + dummyNode1.Misbehaving(100); + BOOST_CHECK(!CNode::IsBanned(addr1.ip)); + dummyNode1.Misbehaving(10); + BOOST_CHECK(!CNode::IsBanned(addr1.ip)); + dummyNode1.Misbehaving(1); + BOOST_CHECK(CNode::IsBanned(addr1.ip)); + mapArgs["-banscore"] = "100"; +} + +BOOST_AUTO_TEST_CASE(DoS_bantime) +{ + CNode::ClearBanned(); + int64 nStartTime = GetTime(); + SetMockTime(nStartTime); // Overrides future calls to GetTime() + + CAddress addr(0xa0b0c001); + CNode dummyNode(INVALID_SOCKET, addr, true); + + dummyNode.Misbehaving(100); + BOOST_CHECK(CNode::IsBanned(addr.ip)); + + SetMockTime(nStartTime+60*60); + BOOST_CHECK(CNode::IsBanned(addr.ip)); + + SetMockTime(nStartTime+60*60*24+1); + BOOST_CHECK(!CNode::IsBanned(addr.ip)); +} + + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/test/test_bitcoin.cpp b/src/test/test_bitcoin.cpp index 0230bb6ec..c6f6d94b1 100644 --- a/src/test/test_bitcoin.cpp +++ b/src/test/test_bitcoin.cpp @@ -8,7 +8,7 @@ #include "uint256_tests.cpp" #include "script_tests.cpp" #include "transaction_tests.cpp" - +#include "DoS_tests.cpp" CWallet* pwalletMain; diff --git a/src/util.cpp b/src/util.cpp index 03b3d73e6..80095fe77 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -813,11 +813,20 @@ void ShrinkDebugFile() // - Median of other nodes's clocks // - The user (asking the user to fix the system clock if the first two disagree) // +static int64 nMockTime = 0; // For unit testing + int64 GetTime() { + if (nMockTime) return nMockTime; + return time(NULL); } +void SetMockTime(int64 nMockTimeIn) +{ + nMockTime = nMockTimeIn; +} + static int64 nTimeOffset = 0; int64 GetAdjustedTime() diff --git a/src/util.h b/src/util.h index fabcaf930..dd5c41135 100644 --- a/src/util.h +++ b/src/util.h @@ -197,6 +197,7 @@ void ShrinkDebugFile(); int GetRandInt(int nMax); uint64 GetRand(uint64 nMax); int64 GetTime(); +void SetMockTime(int64 nMockTimeIn); int64 GetAdjustedTime(); void AddTimeData(unsigned int ip, int64 nTime); std::string FormatFullVersion();