diff --git a/src/txmempool.cpp b/src/txmempool.cpp index 378123ce0febf..2bac419f84c70 100644 --- a/src/txmempool.cpp +++ b/src/txmempool.cpp @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -898,6 +899,19 @@ CTxMemPool::setEntries CTxMemPool::GetIterSet(const std::set& hashes) c return ret; } +std::vector CTxMemPool::GetIterVec(const std::vector& txids) const +{ + AssertLockHeld(cs); + std::vector ret; + ret.reserve(txids.size()); + for (const auto& txid : txids) { + const auto it{GetIter(txid)}; + if (!it) return {}; + ret.push_back(*it); + } + return ret; +} + bool CTxMemPool::HasNoInputsOf(const CTransaction &tx) const { for (unsigned int i = 0; i < tx.vin.size(); i++) @@ -1127,7 +1141,6 @@ void CTxMemPool::SetLoadTried(bool load_tried) m_load_tried = load_tried; } - std::string RemovalReasonToString(const MemPoolRemovalReason& r) noexcept { switch (r) { @@ -1140,3 +1153,30 @@ std::string RemovalReasonToString(const MemPoolRemovalReason& r) noexcept } assert(false); } + +std::vector CTxMemPool::GatherClusters(const std::vector& txids) const +{ + AssertLockHeld(cs); + std::vector clustered_txs{GetIterVec(txids)}; + // Use epoch: visiting an entry means we have added it to the clustered_txs vector. It does not + // necessarily mean the entry has been processed. + WITH_FRESH_EPOCH(m_epoch); + for (const auto& it : clustered_txs) { + visited(it); + } + // i = index of where the list of entries to process starts + for (size_t i{0}; i < clustered_txs.size(); ++i) { + // DoS protection: if there are 500 or more entries to process, just quit. + if (clustered_txs.size() > 500) return {}; + const txiter& tx_iter = clustered_txs.at(i); + for (const auto& entries : {tx_iter->GetMemPoolParentsConst(), tx_iter->GetMemPoolChildrenConst()}) { + for (const CTxMemPoolEntry& entry : entries) { + const auto entry_it = mapTx.iterator_to(entry); + if (!visited(entry_it)) { + clustered_txs.push_back(entry_it); + } + } + } + } + return clustered_txs; +} diff --git a/src/txmempool.h b/src/txmempool.h index 2c3cb7e9dbd4d..769b7f69eac34 100644 --- a/src/txmempool.h +++ b/src/txmempool.h @@ -522,9 +522,16 @@ class CTxMemPool /** Returns an iterator to the given hash, if found */ std::optional GetIter(const uint256& txid) const EXCLUSIVE_LOCKS_REQUIRED(cs); - /** Translate a set of hashes into a set of pool iterators to avoid repeated lookups */ + /** Translate a set of hashes into a set of pool iterators to avoid repeated lookups. + * Does not require that all of the hashes correspond to actual transactions in the mempool, + * only returns the ones that exist. */ setEntries GetIterSet(const std::set& hashes) const EXCLUSIVE_LOCKS_REQUIRED(cs); + /** Translate a list of hashes into a list of mempool iterators to avoid repeated lookups. + * The nth element in txids becomes the nth element in the returned vector. If any of the txids + * don't actually exist in the mempool, returns an empty vector. */ + std::vector GetIterVec(const std::vector& txids) const EXCLUSIVE_LOCKS_REQUIRED(cs); + /** Remove a set of transactions from the mempool. * If a transaction is in this set, then all in-mempool descendants must * also be in the set, unless this transaction is being removed for being @@ -585,6 +592,12 @@ class CTxMemPool const Limits& limits, bool fSearchForParents = true) const EXCLUSIVE_LOCKS_REQUIRED(cs); + /** Collect the entire cluster of connected transactions for each transaction in txids. + * All txids must correspond to transaction entries in the mempool, otherwise this returns an + * empty vector. This call will also exit early and return an empty vector if it collects 500 or + * more transactions as a DoS protection. */ + std::vector GatherClusters(const std::vector& txids) const EXCLUSIVE_LOCKS_REQUIRED(cs); + /** Calculate all in-mempool ancestors of a set of transactions not already in the mempool and * check ancestor and descendant limits. Heuristics are used to estimate the ancestor and * descendant count of all entries if the package were to be added to the mempool. The limits