From 79deab238a0946123837fe946b093c6bcdd7386e Mon Sep 17 00:00:00 2001 From: pasta Date: Mon, 31 Jan 2022 19:32:59 +0700 Subject: [PATCH] refactor: make GetRand a template, remove GetRandInt --- src/addrdb.cpp | 3 +-- src/blockencodings.cpp | 2 +- src/common/bloom.cpp | 2 +- src/init.cpp | 2 +- src/net.cpp | 2 +- src/net_processing.cpp | 8 ++++---- src/netaddress.h | 4 ++-- src/random.cpp | 7 +------ src/random.h | 13 +++++++++++-- src/test/random_tests.cpp | 8 ++++---- src/util/bytevectorhash.cpp | 6 +++--- src/util/hasher.cpp | 6 +++--- src/wallet/spend.cpp | 6 +++--- 13 files changed, 36 insertions(+), 33 deletions(-) diff --git a/src/addrdb.cpp b/src/addrdb.cpp index dad215e9689567..297bae96631e4e 100644 --- a/src/addrdb.cpp +++ b/src/addrdb.cpp @@ -47,8 +47,7 @@ template bool SerializeFileDB(const std::string& prefix, const fs::path& path, const Data& data, int version) { // Generate random temporary filename - uint16_t randv = 0; - GetRandBytes({(unsigned char*)&randv, sizeof(randv)}); + uint16_t randv = GetRand(); std::string tmpfn = strprintf("%s.%04x", prefix, randv); // open temp output file, and associate with CAutoFile diff --git a/src/blockencodings.cpp b/src/blockencodings.cpp index aa111b5939fa8b..2a7bf9397c8dc3 100644 --- a/src/blockencodings.cpp +++ b/src/blockencodings.cpp @@ -17,7 +17,7 @@ #include CBlockHeaderAndShortTxIDs::CBlockHeaderAndShortTxIDs(const CBlock& block, bool fUseWTXID) : - nonce(GetRand(std::numeric_limits::max())), + nonce(GetRand()), shorttxids(block.vtx.size() - 1), prefilledtxn(1), header(block) { FillShortTxIDSelector(); //TODO: Use our mempool prior to block acceptance to predictively fill more than just the coinbase diff --git a/src/common/bloom.cpp b/src/common/bloom.cpp index c744d05a0e1744..bab3708e5e2770 100644 --- a/src/common/bloom.cpp +++ b/src/common/bloom.cpp @@ -239,7 +239,7 @@ bool CRollingBloomFilter::contains(Span vKey) const void CRollingBloomFilter::reset() { - nTweak = GetRand(std::numeric_limits::max()); + nTweak = GetRand(); nEntriesThisGeneration = 0; nGeneration = 1; std::fill(data.begin(), data.end(), 0); diff --git a/src/init.cpp b/src/init.cpp index 015e17596c3829..d9740ba0262455 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -1274,7 +1274,7 @@ bool AppInitMain(NodeContext& node, interfaces::BlockAndHeaderTipInfo* tip_info) assert(!node.banman); node.banman = std::make_unique(gArgs.GetDataDirNet() / "banlist", &uiInterface, args.GetIntArg("-bantime", DEFAULT_MISBEHAVING_BANTIME)); assert(!node.connman); - node.connman = std::make_unique(GetRand(std::numeric_limits::max()), GetRand(std::numeric_limits::max()), *node.addrman, args.GetBoolArg("-networkactive", true)); + node.connman = std::make_unique(GetRand(), GetRand(), *node.addrman, args.GetBoolArg("-networkactive", true)); assert(!node.fee_estimator); // Don't initialize fee estimation with old data if we don't relay transactions, diff --git a/src/net.cpp b/src/net.cpp index be56d1e2d2572d..5effc2125e73e3 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -2115,7 +2115,7 @@ void CConnman::ThreadOpenConnections(const std::vector connect) if (fFeeler) { // Add small amount of random noise before connection to avoid synchronization. - int randsleep = GetRandInt(FEELER_SLEEP_WINDOW * 1000); + int randsleep = GetRand(FEELER_SLEEP_WINDOW * 1000); if (!interruptNet.sleep_for(std::chrono::milliseconds(randsleep))) return; LogPrint(BCLog::NET, "Making feeler connection to %s\n", addrConnect.ToString()); diff --git a/src/net_processing.cpp b/src/net_processing.cpp index a8894e6e87ff8a..e0c5377a4a84ef 100644 --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -4414,10 +4414,10 @@ void PeerManagerImpl::MaybeSendPing(CNode& node_to, Peer& peer, std::chrono::mic } if (pingSend) { - uint64_t nonce = 0; - while (nonce == 0) { - GetRandBytes({(unsigned char*)&nonce, sizeof(nonce)}); - } + uint64_t nonce; + do { + nonce = GetRand(); + } while (nonce == 0); peer.m_ping_queued = false; peer.m_ping_start = now; if (node_to.GetCommonVersion() > BIP0031_VERSION) { diff --git a/src/netaddress.h b/src/netaddress.h index b06b6c65b699ca..b77edfa3893d65 100644 --- a/src/netaddress.h +++ b/src/netaddress.h @@ -572,8 +572,8 @@ class CServiceHash } private: - const uint64_t m_salt_k0 = GetRand(std::numeric_limits::max()); - const uint64_t m_salt_k1 = GetRand(std::numeric_limits::max()); + const uint64_t m_salt_k0 = GetRand(); + const uint64_t m_salt_k1 = GetRand(); }; #endif // BITCOIN_NETADDRESS_H diff --git a/src/random.cpp b/src/random.cpp index fe978a6e9d1d67..17467e147e6be8 100644 --- a/src/random.cpp +++ b/src/random.cpp @@ -590,16 +590,11 @@ void RandAddEvent(const uint32_t event_info) noexcept { GetRNGState().AddEvent(e bool g_mock_deterministic_tests{false}; -uint64_t GetRand(uint64_t nMax) noexcept +uint64_t GetRandInternal(uint64_t nMax) noexcept { return FastRandomContext(g_mock_deterministic_tests).randrange(nMax); } -int GetRandInt(int nMax) noexcept -{ - return GetRand(nMax); -} - uint256 GetRandHash() noexcept { uint256 hash; diff --git a/src/random.h b/src/random.h index 7018ddfd9301af..e7d4d200c694c0 100644 --- a/src/random.h +++ b/src/random.h @@ -68,7 +68,17 @@ */ void GetRandBytes(Span bytes) noexcept; /** Generate a uniform random integer in the range [0..range). Precondition: range > 0 */ -uint64_t GetRand(uint64_t nMax) noexcept; +uint64_t GetRandInternal(uint64_t nMax) noexcept; +/** Generate a uniform random integer of type T in the range [0..nMax) + * nMax defaults to std::numeric_limits::max() + * Precondition: nMax > 0, T is an integral type, no larger than uint64_t + */ +template +T GetRand(T nMax=std::numeric_limits::max()) noexcept { + static_assert(std::is_integral(), "T must be integral"); + static_assert(std::numeric_limits::max() <= std::numeric_limits::max(), "GetRand only supports up to uint64_t"); + return T(GetRandInternal(nMax)); +} /** Generate a uniform random duration in the range [0..max). Precondition: max.count() > 0 */ template D GetRandomDuration(typename std::common_type::type max) noexcept @@ -94,7 +104,6 @@ constexpr auto GetRandMillis = GetRandomDuration; * */ std::chrono::microseconds GetExponentialRand(std::chrono::microseconds now, std::chrono::seconds average_interval); -int GetRandInt(int nMax) noexcept; uint256 GetRandHash() noexcept; /** diff --git a/src/test/random_tests.cpp b/src/test/random_tests.cpp index 978a7bee4dc389..eba7b5159221d9 100644 --- a/src/test/random_tests.cpp +++ b/src/test/random_tests.cpp @@ -26,8 +26,8 @@ BOOST_AUTO_TEST_CASE(fastrandom_tests) FastRandomContext ctx2(true); for (int i = 10; i > 0; --i) { - BOOST_CHECK_EQUAL(GetRand(std::numeric_limits::max()), uint64_t{10393729187455219830U}); - BOOST_CHECK_EQUAL(GetRandInt(std::numeric_limits::max()), int{769702006}); + BOOST_CHECK_EQUAL(GetRand(), uint64_t{10393729187455219830U}); + BOOST_CHECK_EQUAL(GetRand(), int{769702006}); BOOST_CHECK_EQUAL(GetRandMicros(std::chrono::hours{1}).count(), 2917185654); BOOST_CHECK_EQUAL(GetRandMillis(std::chrono::hours{1}).count(), 2144374); } @@ -47,8 +47,8 @@ BOOST_AUTO_TEST_CASE(fastrandom_tests) // Check that a nondeterministic ones are not g_mock_deterministic_tests = false; for (int i = 10; i > 0; --i) { - BOOST_CHECK(GetRand(std::numeric_limits::max()) != uint64_t{10393729187455219830U}); - BOOST_CHECK(GetRandInt(std::numeric_limits::max()) != int{769702006}); + BOOST_CHECK(GetRand() != uint64_t{10393729187455219830U}); + BOOST_CHECK(GetRand() != int{769702006}); BOOST_CHECK(GetRandMicros(std::chrono::hours{1}) != std::chrono::microseconds{2917185654}); BOOST_CHECK(GetRandMillis(std::chrono::hours{1}) != std::chrono::milliseconds{2144374}); } diff --git a/src/util/bytevectorhash.cpp b/src/util/bytevectorhash.cpp index bc060a44c95fec..9054db4759a8b4 100644 --- a/src/util/bytevectorhash.cpp +++ b/src/util/bytevectorhash.cpp @@ -6,10 +6,10 @@ #include #include -ByteVectorHash::ByteVectorHash() +ByteVectorHash::ByteVectorHash() : + m_k0(GetRand()), + m_k1(GetRand()) { - GetRandBytes({reinterpret_cast(&m_k0), sizeof(m_k0)}); - GetRandBytes({reinterpret_cast(&m_k1), sizeof(m_k1)}); } size_t ByteVectorHash::operator()(const std::vector& input) const diff --git a/src/util/hasher.cpp b/src/util/hasher.cpp index 5900daf0500f1b..c21941eb88fab5 100644 --- a/src/util/hasher.cpp +++ b/src/util/hasher.cpp @@ -7,11 +7,11 @@ #include -SaltedTxidHasher::SaltedTxidHasher() : k0(GetRand(std::numeric_limits::max())), k1(GetRand(std::numeric_limits::max())) {} +SaltedTxidHasher::SaltedTxidHasher() : k0(GetRand()), k1(GetRand()) {} -SaltedOutpointHasher::SaltedOutpointHasher() : k0(GetRand(std::numeric_limits::max())), k1(GetRand(std::numeric_limits::max())) {} +SaltedOutpointHasher::SaltedOutpointHasher() : k0(GetRand()), k1(GetRand()) {} -SaltedSipHasher::SaltedSipHasher() : m_k0(GetRand(std::numeric_limits::max())), m_k1(GetRand(std::numeric_limits::max())) {} +SaltedSipHasher::SaltedSipHasher() : m_k0(GetRand()), m_k1(GetRand()) {} size_t SaltedSipHasher::operator()(const Span& script) const { diff --git a/src/wallet/spend.cpp b/src/wallet/spend.cpp index 3d8ae2da69d58c..961a2a4c0183a4 100644 --- a/src/wallet/spend.cpp +++ b/src/wallet/spend.cpp @@ -616,8 +616,8 @@ static uint32_t GetLocktimeForNewTransaction(interfaces::Chain& chain, const uin // that transactions that are delayed after signing for whatever reason, // e.g. high-latency mix networks and some CoinJoin implementations, have // better privacy. - if (GetRandInt(10) == 0) - locktime = std::max(0, (int)locktime - GetRandInt(100)); + if (GetRand(10) == 0) + locktime = std::max(0, (int)locktime - GetRand(100)); } else { // If our chain is lagging behind, we can't discourage fee sniping nor help // the privacy of high-latency transactions. To avoid leaking a potentially @@ -774,7 +774,7 @@ static bool CreateTransactionInternal( if (nChangePosInOut == -1) { // Insert change txn at random position: - nChangePosInOut = GetRandInt(txNew.vout.size()+1); + nChangePosInOut = GetRand(txNew.vout.size()+1); } else if ((unsigned int)nChangePosInOut > txNew.vout.size()) {