Browse Source

Deduplicate addrdb.cpp and use CHashWriter/Verifier

0.15
Pieter Wuille 8 years ago
parent
commit
cf68a488a4
  1. 212
      src/addrdb.cpp
  2. 2
      src/addrdb.h

212
src/addrdb.cpp

@ -15,25 +15,31 @@
#include "tinyformat.h" #include "tinyformat.h"
#include "util.h" #include "util.h"
namespace {
CBanDB::CBanDB() template <typename Stream, typename Data>
bool SerializeDB(Stream& stream, const Data& data)
{ {
pathBanlist = GetDataDir() / "banlist.dat"; // Write and commit header, data
try {
CHashWriter hasher(SER_DISK, CLIENT_VERSION);
stream << FLATDATA(Params().MessageStart()) << data;
hasher << FLATDATA(Params().MessageStart()) << data;
stream << hasher.GetHash();
} catch (const std::exception& e) {
return error("%s: Serialize or I/O error - %s", __func__, e.what());
} }
bool CBanDB::Write(const banmap_t& banSet) return true;
}
template <typename Data>
bool SerializeFileDB(const std::string& prefix, const fs::path& path, const Data& data)
{ {
// Generate random temporary filename // Generate random temporary filename
unsigned short randv = 0; unsigned short randv = 0;
GetRandBytes((unsigned char*)&randv, sizeof(randv)); GetRandBytes((unsigned char*)&randv, sizeof(randv));
std::string tmpfn = strprintf("banlist.dat.%04x", randv); std::string tmpfn = strprintf("%s.%04x", prefix, randv);
// serialize banlist, checksum data up to that point, then append csum
CDataStream ssBanlist(SER_DISK, CLIENT_VERSION);
ssBanlist << FLATDATA(Params().MessageStart());
ssBanlist << banSet;
uint256 hash = Hash(ssBanlist.begin(), ssBanlist.end());
ssBanlist << hash;
// open temp output file, and associate with CAutoFile // open temp output file, and associate with CAutoFile
fs::path pathTmp = GetDataDir() / tmpfn; fs::path pathTmp = GetDataDir() / tmpfn;
@ -42,177 +48,99 @@ bool CBanDB::Write(const banmap_t& banSet)
if (fileout.IsNull()) if (fileout.IsNull())
return error("%s: Failed to open file %s", __func__, pathTmp.string()); return error("%s: Failed to open file %s", __func__, pathTmp.string());
// Write and commit header, data // Serialize
try { if (!SerializeDB(fileout, data)) return false;
fileout << ssBanlist;
}
catch (const std::exception& e) {
return error("%s: Serialize or I/O error - %s", __func__, e.what());
}
FileCommit(fileout.Get()); FileCommit(fileout.Get());
fileout.fclose(); fileout.fclose();
// replace existing banlist.dat, if any, with new banlist.dat.XXXX // replace existing file, if any, with new file
if (!RenameOver(pathTmp, pathBanlist)) if (!RenameOver(pathTmp, path))
return error("%s: Rename-into-place failed", __func__); return error("%s: Rename-into-place failed", __func__);
return true; return true;
} }
bool CBanDB::Read(banmap_t& banSet) template <typename Stream, typename Data>
bool DeserializeDB(Stream& stream, Data& data, bool fCheckSum = true)
{ {
// open input file, and associate with CAutoFile
FILE *file = fsbridge::fopen(pathBanlist, "rb");
CAutoFile filein(file, SER_DISK, CLIENT_VERSION);
if (filein.IsNull())
return error("%s: Failed to open file %s", __func__, pathBanlist.string());
// use file size to size memory buffer
uint64_t fileSize = fs::file_size(pathBanlist);
uint64_t dataSize = 0;
// Don't try to resize to a negative number if file is small
if (fileSize >= sizeof(uint256))
dataSize = fileSize - sizeof(uint256);
std::vector<unsigned char> vchData;
vchData.resize(dataSize);
uint256 hashIn;
// read data and checksum from file
try {
filein.read((char *)&vchData[0], dataSize);
filein >> hashIn;
}
catch (const std::exception& e) {
return error("%s: Deserialize or I/O error - %s", __func__, e.what());
}
filein.fclose();
CDataStream ssBanlist(vchData, SER_DISK, CLIENT_VERSION);
// verify stored checksum matches input data
uint256 hashTmp = Hash(ssBanlist.begin(), ssBanlist.end());
if (hashIn != hashTmp)
return error("%s: Checksum mismatch, data corrupted", __func__);
unsigned char pchMsgTmp[4];
try { try {
CHashVerifier<Stream> verifier(&stream);
// de-serialize file header (network specific magic number) and .. // de-serialize file header (network specific magic number) and ..
ssBanlist >> FLATDATA(pchMsgTmp); unsigned char pchMsgTmp[4];
verifier >> FLATDATA(pchMsgTmp);
// ... verify the network matches ours // ... verify the network matches ours
if (memcmp(pchMsgTmp, Params().MessageStart(), sizeof(pchMsgTmp))) if (memcmp(pchMsgTmp, Params().MessageStart(), sizeof(pchMsgTmp)))
return error("%s: Invalid network magic number", __func__); return error("%s: Invalid network magic number", __func__);
// de-serialize ban data // de-serialize data
ssBanlist >> banSet; verifier >> data;
}
catch (const std::exception& e) {
return error("%s: Deserialize or I/O error - %s", __func__, e.what());
}
return true; // verify checksum
if (fCheckSum) {
uint256 hashTmp;
stream >> hashTmp;
if (hashTmp != verifier.GetHash()) {
return error("%s: Checksum mismatch, data corrupted", __func__);
} }
CAddrDB::CAddrDB()
{
pathAddr = GetDataDir() / "peers.dat";
} }
bool CAddrDB::Write(const CAddrMan& addr)
{
// Generate random temporary filename
unsigned short randv = 0;
GetRandBytes((unsigned char*)&randv, sizeof(randv));
std::string tmpfn = strprintf("peers.dat.%04x", randv);
// serialize addresses, checksum data up to that point, then append csum
CDataStream ssPeers(SER_DISK, CLIENT_VERSION);
ssPeers << FLATDATA(Params().MessageStart());
ssPeers << addr;
uint256 hash = Hash(ssPeers.begin(), ssPeers.end());
ssPeers << hash;
// open temp output file, and associate with CAutoFile
fs::path pathTmp = GetDataDir() / tmpfn;
FILE *file = fsbridge::fopen(pathTmp, "wb");
CAutoFile fileout(file, SER_DISK, CLIENT_VERSION);
if (fileout.IsNull())
return error("%s: Failed to open file %s", __func__, pathTmp.string());
// Write and commit header, data
try {
fileout << ssPeers;
} }
catch (const std::exception& e) { catch (const std::exception& e) {
return error("%s: Serialize or I/O error - %s", __func__, e.what()); return error("%s: Deserialize or I/O error - %s", __func__, e.what());
} }
FileCommit(fileout.Get());
fileout.fclose();
// replace existing peers.dat, if any, with new peers.dat.XXXX
if (!RenameOver(pathTmp, pathAddr))
return error("%s: Rename-into-place failed", __func__);
return true; return true;
} }
bool CAddrDB::Read(CAddrMan& addr) template <typename Data>
bool DeserializeFileDB(const fs::path& path, Data& data)
{ {
// open input file, and associate with CAutoFile // open input file, and associate with CAutoFile
FILE *file = fsbridge::fopen(pathAddr, "rb"); FILE *file = fsbridge::fopen(path, "rb");
CAutoFile filein(file, SER_DISK, CLIENT_VERSION); CAutoFile filein(file, SER_DISK, CLIENT_VERSION);
if (filein.IsNull()) if (filein.IsNull())
return error("%s: Failed to open file %s", __func__, pathAddr.string()); return error("%s: Failed to open file %s", __func__, path.string());
// use file size to size memory buffer return DeserializeDB(filein, data);
uint64_t fileSize = fs::file_size(pathAddr);
uint64_t dataSize = 0;
// Don't try to resize to a negative number if file is small
if (fileSize >= sizeof(uint256))
dataSize = fileSize - sizeof(uint256);
std::vector<unsigned char> vchData;
vchData.resize(dataSize);
uint256 hashIn;
// read data and checksum from file
try {
filein.read((char *)&vchData[0], dataSize);
filein >> hashIn;
} }
catch (const std::exception& e) {
return error("%s: Deserialize or I/O error - %s", __func__, e.what());
} }
filein.fclose();
CDataStream ssPeers(vchData, SER_DISK, CLIENT_VERSION); CBanDB::CBanDB()
{
pathBanlist = GetDataDir() / "banlist.dat";
}
// verify stored checksum matches input data bool CBanDB::Write(const banmap_t& banSet)
uint256 hashTmp = Hash(ssPeers.begin(), ssPeers.end()); {
if (hashIn != hashTmp) return SerializeFileDB("banlist", pathBanlist, banSet);
return error("%s: Checksum mismatch, data corrupted", __func__); }
return Read(addr, ssPeers); bool CBanDB::Read(banmap_t& banSet)
{
return DeserializeFileDB(pathBanlist, banSet);
} }
bool CAddrDB::Read(CAddrMan& addr, CDataStream& ssPeers) CAddrDB::CAddrDB()
{ {
unsigned char pchMsgTmp[4]; pathAddr = GetDataDir() / "peers.dat";
try { }
// de-serialize file header (network specific magic number) and ..
ssPeers >> FLATDATA(pchMsgTmp);
// ... verify the network matches ours bool CAddrDB::Write(const CAddrMan& addr)
if (memcmp(pchMsgTmp, Params().MessageStart(), sizeof(pchMsgTmp))) {
return error("%s: Invalid network magic number", __func__); return SerializeFileDB("peers", pathAddr, addr);
}
// de-serialize address data into one CAddrMan object bool CAddrDB::Read(CAddrMan& addr)
ssPeers >> addr; {
return DeserializeFileDB(pathAddr, addr);
} }
catch (const std::exception& e) {
// de-serialization has failed, ensure addrman is left in a clean state bool CAddrDB::Read(CAddrMan& addr, CDataStream& ssPeers)
{
bool ret = DeserializeDB(ssPeers, addr, false);
if (!ret) {
// Ensure addrman is left in a clean state
addr.Clear(); addr.Clear();
return error("%s: Deserialize or I/O error - %s", __func__, e.what());
} }
return ret;
return true;
} }

2
src/addrdb.h

@ -85,7 +85,7 @@ public:
CAddrDB(); CAddrDB();
bool Write(const CAddrMan& addr); bool Write(const CAddrMan& addr);
bool Read(CAddrMan& addr); bool Read(CAddrMan& addr);
bool Read(CAddrMan& addr, CDataStream& ssPeers); static bool Read(CAddrMan& addr, CDataStream& ssPeers);
}; };
/** Access to the banlist database (banlist.dat) */ /** Access to the banlist database (banlist.dat) */

Loading…
Cancel
Save