From 6df5c68f8ff7a83bad134bc4d24bb6938a8b7d44 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sat, 6 Jul 2024 02:05:58 -0400 Subject: [PATCH 01/12] [GraphBolt] Refine S3-FIFO cache policy. --- graphbolt/src/cache_policy.cc | 27 ++++++------- graphbolt/src/cache_policy.h | 73 ++++++++++++++++++----------------- 2 files changed, 50 insertions(+), 50 deletions(-) diff --git a/graphbolt/src/cache_policy.cc b/graphbolt/src/cache_policy.cc index 0c9908378931..b08c4488ae84 100644 --- a/graphbolt/src/cache_policy.cc +++ b/graphbolt/src/cache_policy.cc @@ -23,11 +23,16 @@ namespace graphbolt { namespace storage { S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity) - : small_queue_(capacity / 10), - main_queue_(capacity - capacity / 10), - ghost_queue_time_(0), + : small_queue_(capacity), + main_queue_(capacity), + ghost_queue_(capacity - capacity / 10), capacity_(capacity), - cache_usage_(0) {} + cache_usage_(0), + small_queue_size_target_(capacity / 10) { + TORCH_CHECK(small_queue_size_target_ > 0, "Capacity is not large enough."); + ghost_set_.reserve(ghost_queue_.Capacity()); + key_to_cache_key_.reserve(capacity); +} std::tuple S3FifoCachePolicy::Query(torch::Tensor keys) { @@ -76,23 +81,15 @@ torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) { cache_key.Increment(); positions_ptr[i] = cache_key.getPos(); } else { - const auto in_ghost_queue = InGhostQueue(key); - auto& queue = in_ghost_queue ? main_queue_ : small_queue_; - int64_t pos; - if (queue.IsFull()) { - // When the queue is full, we need to make a space by evicting. - // Inside ghost queue means insertion into M, otherwise S. - pos = (in_ghost_queue ? EvictMainQueue() : EvictSmallQueue()); - } else { // If the cache is not full yet, get an unused empty slot. - pos = cache_usage_++; - } + const auto pos = Evict(); TORCH_CHECK(0 <= pos && pos < capacity_, "Position out of bounds!"); + const auto in_ghost_queue = ghost_set_.erase(key); + auto& queue = in_ghost_queue ? 
main_queue_ : small_queue_; key_to_cache_key_[key] = queue.Push(CacheKey(key, pos)); positions_ptr[i] = pos; } } })); - TrimGhostQueue(); return positions; } diff --git a/graphbolt/src/cache_policy.h b/graphbolt/src/cache_policy.h index c07f89ce97ac..8e1d596219e9 100644 --- a/graphbolt/src/cache_policy.h +++ b/graphbolt/src/cache_policy.h @@ -53,6 +53,12 @@ struct CircularQueue { return diff == 0 || diff == capacity_; } + auto Size() const { + auto diff = tail_ - head_; + if (diff < 0) diff += capacity_; + return diff; + } + friend std::ostream& operator<<( std::ostream& os, const CircularQueue& queue) { for (auto i = queue.head_; i != queue.tail_; queue.PostIncrement(i)) { @@ -63,7 +69,7 @@ struct CircularQueue { bool IsEmpty() const { return tail_ == head_; } - int64_t Capacity() const { return capacity_ - 1; } + auto Capacity() const { return capacity_ - 1; } private: int64_t PostIncrement(int64_t& i) const { @@ -96,6 +102,11 @@ struct CacheKey { void Decrement() { freq_ = std::max(0, static_cast(freq_ - 1)); } + CacheKey& ResetFreq() { + freq_ = 0; + return *this; + } + friend std::ostream& operator<<(std::ostream& os, const CacheKey& key_ref) { return os << '(' << key_ref.key_ << ", " << key_ref.freq_ << ", " << key_ref.position_in_cache_ << ")"; @@ -174,7 +185,6 @@ class S3FifoCachePolicy : public BaseCachePolicy { return os << "small_queue_: " << policy.small_queue_ << "\n" << "main_queue_: " << policy.main_queue_ << "\n" << "cache_usage_: " << policy.cache_usage_ << "\n" - << "ghost_queue_time_: " << policy.ghost_queue_time_ << "\n" << "capacity_: " << policy.capacity_ << "\n"; } @@ -194,47 +204,40 @@ class S3FifoCachePolicy : public BaseCachePolicy { } int64_t EvictSmallQueue() { - auto evicted = small_queue_.Pop(); - const auto is_M_full = main_queue_.IsFull(); - auto it = key_to_cache_key_.find(evicted.getKey()); - if (evicted.getFreq() <= 0 && is_M_full) { - key_to_cache_key_.erase(it); - // No overflow is expected for any GNN workload. - ghost_map_[evicted.getKey()] = ghost_queue_time_++; - return evicted.getPos(); - } else { - const auto pos = is_M_full ? EvictMainQueue() : cache_usage_++; - it->second = main_queue_.Push(evicted); - return pos; + for (auto size = small_queue_.Size(); size > small_queue_size_target_; + size--) { + auto evicted = small_queue_.Pop(); + auto it = key_to_cache_key_.find(evicted.getKey()); + if (evicted.getFreq() <= 0) { + key_to_cache_key_.erase(it); + const auto evicted_key = evicted.getKey(); + if (ghost_queue_.IsFull()) { + ghost_set_.erase(ghost_queue_.Pop()); + } + ghost_set_.insert(evicted_key); + ghost_queue_.Push(evicted_key); + return evicted.getPos(); + } else { + it->second = main_queue_.Push(evicted.ResetFreq()); + } } + return -1; } - // Is inside the ghost queue. - bool InGhostQueue(int64_t key) const { - auto it = ghost_map_.find(key); - return it != ghost_map_.end() && - ghost_queue_time_ - it->second <= main_queue_.Capacity(); - } - - void TrimGhostQueue() { - if (static_cast(ghost_map_.size()) >= 2 * main_queue_.Capacity()) { - // Here, we ensure that the ghost_map_ does not grow too much. - phmap::priv::erase_if(ghost_map_, [&](const auto& key_value) { - const auto timestamp = key_value.second; - return ghost_queue_time_ - timestamp > main_queue_.Capacity(); - }); - } + int64_t Evict() { + // If the cache has space, get an unused slot otherwise perform eviction. + if (cache_usage_ < capacity_) return cache_usage_++; + const auto pos = EvictSmallQueue(); + return pos >= 0 ? 
pos : EvictMainQueue(); } CircularQueue small_queue_, main_queue_; - phmap::flat_hash_map key_to_cache_key_; - phmap::flat_hash_map ghost_map_; - // Keeps track of the number of insertions into the ghost queue so far. If an - // item's time of insertion is older than main_queue_.Capacity(), then it is - // not really considered in the ghost queue. - int64_t ghost_queue_time_; + CircularQueue ghost_queue_; int64_t capacity_; int64_t cache_usage_; + int64_t small_queue_size_target_; + phmap::flat_hash_set ghost_set_; + phmap::flat_hash_map key_to_cache_key_; }; } // namespace storage From f30814b2dc208bc26c18cf1e3a7a0112932f3fbd Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sat, 6 Jul 2024 02:42:29 -0400 Subject: [PATCH 02/12] fix the tests and simplify query API. --- python/dgl/graphbolt/impl/feature_cache.py | 11 ++++++----- .../pytorch/graphbolt/impl/test_feature_cache.py | 10 ++++------ 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/python/dgl/graphbolt/impl/feature_cache.py b/python/dgl/graphbolt/impl/feature_cache.py index 8922a8a27ebe..9fbd5e32f7e9 100644 --- a/python/dgl/graphbolt/impl/feature_cache.py +++ b/python/dgl/graphbolt/impl/feature_cache.py @@ -30,26 +30,27 @@ def __init__(self, cache_shape, dtype, num_parts=1, policy="s3-fifo"): self.total_miss = 0 self.total_queries = 0 - def query(self, keys, pin_memory=False): + def query(self, keys): """Queries the cache. Parameters ---------- keys : Tensor The keys to query the cache with. - pin_memory : bool, optional - Whether the output values tensor should be pinned. Default is False. Returns ------- tuple(Tensor, Tensor, Tensor) A tuple containing (values, missing_indices, missing_keys) where values[missing_indices] corresponds to cache misses that should be - filled by quering another source with missing_keys. + filled by quering another source with missing_keys. If keys is + pinned, then the returned values tensor is pinned as well. 
""" self.total_queries += keys.shape[0] positions, index, missing_keys = self._policy.query(keys) - values = self._cache.query(positions, index, keys.shape[0], pin_memory) + values = self._cache.query( + positions, index, keys.shape[0], keys.is_pinned() + ) self.total_miss += missing_keys.shape[0] missing_index = index[positions.size(0) :] return values, missing_index, missing_keys diff --git a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py index bd5d376449d9..8e28e7dba062 100644 --- a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py +++ b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py @@ -31,7 +31,7 @@ def test_feature_cache(dtype, feature_size, num_parts): ) keys = torch.tensor([0, 1]) - values, missing_index, missing_keys = cache.query(keys, False) + values, missing_index, missing_keys = cache.query(keys) assert torch.equal( missing_keys.flip([0]) if num_parts == 1 else missing_keys.sort()[0], keys, @@ -44,8 +44,8 @@ def test_feature_cache(dtype, feature_size, num_parts): pin_memory = F._default_context_str == "gpu" - keys = torch.arange(1, 33) - values, missing_index, missing_keys = cache.query(keys, pin_memory) + keys = torch.arange(1, 33, pin_memory=pin_memory) + values, missing_index, missing_keys = cache.query(keys) assert torch.equal( missing_keys.flip([0]) if num_parts == 1 else missing_keys.sort()[0], torch.arange(2, 33), @@ -58,9 +58,7 @@ def test_feature_cache(dtype, feature_size, num_parts): assert torch.equal(values, a[keys]) values, missing_index, missing_keys = cache.query(keys) - assert torch.equal( - missing_keys.flip([0]), torch.tensor([5] if num_parts == 1 else []) - ) + assert torch.equal(missing_keys.flip([0]), torch.tensor([])) missing_values = a[missing_keys] cache.replace(missing_keys, missing_values) From 955b55a7362ca871c5bcdd3980fd62535b1c3d94 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Mon, 8 Jul 2024 04:38:37 -0400 Subject: [PATCH 03/12] Add SIEVE cache policy and refactor. 
--- graphbolt/src/cache_policy.cc | 79 +++++++----- graphbolt/src/cache_policy.h | 113 ++++++++++++++++++ graphbolt/src/partitioned_cache_policy.cc | 2 + graphbolt/src/python_binding.cc | 3 + python/dgl/graphbolt/impl/feature_cache.py | 9 +- .../graphbolt/impl/test_feature_cache.py | 5 +- 6 files changed, 175 insertions(+), 36 deletions(-) diff --git a/graphbolt/src/cache_policy.cc b/graphbolt/src/cache_policy.cc index b08c4488ae84..c9ac7f66b0c4 100644 --- a/graphbolt/src/cache_policy.cc +++ b/graphbolt/src/cache_policy.cc @@ -22,20 +22,9 @@ namespace graphbolt { namespace storage { -S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity) - : small_queue_(capacity), - main_queue_(capacity), - ghost_queue_(capacity - capacity / 10), - capacity_(capacity), - cache_usage_(0), - small_queue_size_target_(capacity / 10) { - TORCH_CHECK(small_queue_size_target_ > 0, "Capacity is not large enough."); - ghost_set_.reserve(ghost_queue_.Capacity()); - key_to_cache_key_.reserve(capacity); -} - +template std::tuple -S3FifoCachePolicy::Query(torch::Tensor keys) { +BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) { auto positions = torch::empty_like(keys, keys.options().dtype(torch::kInt64)); auto indices = torch::empty_like(keys, keys.options().dtype(torch::kInt64)); auto missing_keys = torch::empty_like(keys); @@ -49,11 +38,9 @@ S3FifoCachePolicy::Query(torch::Tensor keys) { auto missing_keys_ptr = missing_keys.data_ptr(); for (int64_t i = 0; i < keys.size(0); i++) { const auto key = keys_ptr[i]; - auto it = key_to_cache_key_.find(key); - if (it != key_to_cache_key_.end()) { - auto& cache_key = *it->second; - cache_key.Increment(); - positions_ptr[found_cnt] = cache_key.getPos(); + auto pos = policy.Read(key); + if (pos.has_value()) { + positions_ptr[found_cnt] = *pos; indices_ptr[found_cnt++] = i; } else { indices_ptr[--missing_cnt] = i; @@ -66,7 +53,9 @@ S3FifoCachePolicy::Query(torch::Tensor keys) { missing_keys.slice(0, found_cnt)}; } -torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) { +template +torch::Tensor BaseCachePolicy::ReplaceImpl( + CachePolicy& policy, torch::Tensor keys) { auto positions = torch::empty_like(keys, keys.options().dtype(torch::kInt64)); AT_DISPATCH_INDEX_TYPES( keys.scalar_type(), "S3FifoCachePolicy::Replace", ([&] { @@ -74,24 +63,52 @@ torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) { auto positions_ptr = positions.data_ptr(); for (int64_t i = 0; i < keys.size(0); i++) { const auto key = keys_ptr[i]; - auto it = key_to_cache_key_.find(key); - if (it != - key_to_cache_key_.end()) { // Already in the cache, inc freq. - auto& cache_key = *it->second; - cache_key.Increment(); - positions_ptr[i] = cache_key.getPos(); + const auto pos = policy.Read(key); + if (pos.has_value()) { // Already in the cache, inc freq. + positions_ptr[i] = *pos; } else { - const auto pos = Evict(); - TORCH_CHECK(0 <= pos && pos < capacity_, "Position out of bounds!"); - const auto in_ghost_queue = ghost_set_.erase(key); - auto& queue = in_ghost_queue ? 
main_queue_ : small_queue_; - key_to_cache_key_[key] = queue.Push(CacheKey(key, pos)); - positions_ptr[i] = pos; + positions_ptr[i] = policy.Insert(key); } } })); return positions; } +S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity) + : small_queue_(capacity), + main_queue_(capacity), + ghost_queue_(capacity - capacity / 10), + capacity_(capacity), + cache_usage_(0), + small_queue_size_target_(capacity / 10) { + TORCH_CHECK(small_queue_size_target_ > 0, "Capacity is not large enough."); + ghost_set_.reserve(ghost_queue_.Capacity()); + key_to_cache_key_.reserve(capacity); +} + +std::tuple +S3FifoCachePolicy::Query(torch::Tensor keys) { + return QueryImpl(*this, keys); +} + +torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) { + return ReplaceImpl(*this, keys); +} + +SieveCachePolicy::SieveCachePolicy(int64_t capacity) + : hand_(queue_.end()), capacity_(capacity), cache_usage_(0) { + TORCH_CHECK(capacity > 0, "Capacity needs to be positive."); + key_to_cache_key_.reserve(capacity); +} + +std::tuple SieveCachePolicy::Query( + torch::Tensor keys) { + return QueryImpl(*this, keys); +} + +torch::Tensor SieveCachePolicy::Replace(torch::Tensor keys) { + return ReplaceImpl(*this, keys); +} + } // namespace storage } // namespace graphbolt diff --git a/graphbolt/src/cache_policy.h b/graphbolt/src/cache_policy.h index 8e1d596219e9..5b632943e5a3 100644 --- a/graphbolt/src/cache_policy.h +++ b/graphbolt/src/cache_policy.h @@ -102,6 +102,11 @@ struct CacheKey { void Decrement() { freq_ = std::max(0, static_cast(freq_ - 1)); } + CacheKey& SetFreq() { + freq_ = 1; + return *this; + } + CacheKey& ResetFreq() { freq_ = 0; return *this; @@ -141,6 +146,13 @@ class BaseCachePolicy { * entries in the cache. */ virtual torch::Tensor Replace(torch::Tensor keys) = 0; + + template + static std::tuple QueryImpl( + CachePolicy& policy, torch::Tensor keys); + + template + static torch::Tensor ReplaceImpl(CachePolicy& policy, torch::Tensor keys); }; /** @@ -188,6 +200,24 @@ class S3FifoCachePolicy : public BaseCachePolicy { << "capacity_: " << policy.capacity_ << "\n"; } + std::optional Read(int64_t key) { + auto it = key_to_cache_key_.find(key); + if (it != key_to_cache_key_.end()) { + auto& cache_key = *it->second; + cache_key.Increment(); + return cache_key.getPos(); + } + return std::nullopt; + } + + int64_t Insert(int64_t key) { + const auto pos = Evict(); + const auto in_ghost_queue = ghost_set_.erase(key); + auto& queue = in_ghost_queue ? main_queue_ : small_queue_; + key_to_cache_key_[key] = queue.Push(CacheKey(key, pos)); + return pos; + } + private: int64_t EvictMainQueue() { while (true) { @@ -240,6 +270,89 @@ class S3FifoCachePolicy : public BaseCachePolicy { phmap::flat_hash_map key_to_cache_key_; }; +/** + * @brief SIEVE is a simple, scalable FIFObased algorithm with a single static + * queue. https://www.usenix.org/system/files/nsdi24-zhang-yazhuo.pdf + **/ +class SieveCachePolicy : public BaseCachePolicy { + public: + /** + * @brief Constructor for the S3FifoCachePolicy class. + * + * @param capacity The capacity of the cache in terms of # elements. + */ + SieveCachePolicy(int64_t capacity); + + SieveCachePolicy() = default; + + /** + * @brief The policy query function. + * @param keys The keys to query the cache. 
+ * + * @return (positions, indices, missing_keys), where positions has the + * locations of the keys which were found in the cache, missing_keys has the + * keys that were not found and indices is defined such that + * keys[indices[:positions.size(0)]] gives us the found keys and + * keys[indices[positions.size(0):]] is identical to missing_keys. + */ + std::tuple Query( + torch::Tensor keys); + + /** + * @brief The policy replace function. + * @param keys The keys to query the cache. + * + * @return positions tensor is returned holding the locations of the + * replaced entries in the cache. + */ + torch::Tensor Replace(torch::Tensor keys); + + std::optional Read(int64_t key) { + auto it = key_to_cache_key_.find(key); + if (it != key_to_cache_key_.end()) { + auto& cache_key = *it->second; + cache_key.SetFreq(); + return cache_key.getPos(); + } + return std::nullopt; + } + + int64_t Insert(int64_t key) { + const auto pos = Evict(); + queue_.push_front(CacheKey(key, pos)); + key_to_cache_key_[key] = queue_.begin(); + return pos; + } + + private: + int64_t Evict() { + // If the cache has space, get an unused slot otherwise perform eviction. + if (cache_usage_ < capacity_) return cache_usage_++; + --hand_; + while (hand_->getFreq()) { + hand_->ResetFreq(); + if (hand_ == queue_.begin()) hand_ = queue_.end(); + --hand_; + } + TORCH_CHECK(key_to_cache_key_.erase(hand_->getKey())); + const auto pos = hand_->getPos(); + const auto temp = hand_; + if (hand_ == queue_.begin()) { + hand_ = queue_.end(); + } else { + ++hand_; + } + queue_.erase(temp); + return pos; + } + + std::list queue_; + decltype(queue_)::iterator hand_; + int64_t capacity_; + int64_t cache_usage_; + phmap::flat_hash_map key_to_cache_key_; +}; + } // namespace storage } // namespace graphbolt diff --git a/graphbolt/src/partitioned_cache_policy.cc b/graphbolt/src/partitioned_cache_policy.cc index 04e735341324..c35a16ff939b 100644 --- a/graphbolt/src/partitioned_cache_policy.cc +++ b/graphbolt/src/partitioned_cache_policy.cc @@ -210,6 +210,8 @@ c10::intrusive_ptr PartitionedCachePolicy::Create( template c10::intrusive_ptr PartitionedCachePolicy::Create(int64_t, int64_t); +template c10::intrusive_ptr + PartitionedCachePolicy::Create(int64_t, int64_t); } // namespace storage } // namespace graphbolt diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index d1d25d771e90..d6d5df9c2769 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -103,6 +103,9 @@ TORCH_LIBRARY(graphbolt, m) { m.def( "s3_fifo_cache_policy", &storage::PartitionedCachePolicy::Create); + m.def( + "sieve_cache_policy", + &storage::PartitionedCachePolicy::Create); m.class_("FeatureCache") .def("query", &storage::FeatureCache::Query) .def("replace", &storage::FeatureCache::Replace); diff --git a/python/dgl/graphbolt/impl/feature_cache.py b/python/dgl/graphbolt/impl/feature_cache.py index 9fbd5e32f7e9..3fc36e309807 100644 --- a/python/dgl/graphbolt/impl/feature_cache.py +++ b/python/dgl/graphbolt/impl/feature_cache.py @@ -3,7 +3,10 @@ __all__ = ["FeatureCache"] -caching_policies = {"s3-fifo": torch.ops.graphbolt.s3_fifo_cache_policy} +caching_policies = { + "s3-fifo": torch.ops.graphbolt.s3_fifo_cache_policy, + "sieve": torch.ops.graphbolt.sieve_cache_policy, +} class FeatureCache(object): @@ -18,10 +21,10 @@ class FeatureCache(object): num_parts: int, optional The number of cache partitions for parallelism. Default is 1. policy: str, optional - The cache policy to be used. Default is "s3-fifo". 
+ The cache policy. Default is "sieve". "s3-fifo" is also available. """ - def __init__(self, cache_shape, dtype, num_parts=1, policy="s3-fifo"): + def __init__(self, cache_shape, dtype, num_parts=1, policy="sieve"): assert ( policy in caching_policies ), f"{list(caching_policies.keys())} are the available caching policies." diff --git a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py index 8e28e7dba062..0395c26a21d7 100644 --- a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py +++ b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py @@ -23,11 +23,12 @@ ) @pytest.mark.parametrize("feature_size", [2, 16]) @pytest.mark.parametrize("num_parts", [1, 2]) -def test_feature_cache(dtype, feature_size, num_parts): +@pytest.mark.parametrize("policy", ["s3-fifo", "sieve"]) +def test_feature_cache(dtype, feature_size, num_parts, policy): cache_size = 32 * num_parts a = torch.randint(0, 2, [1024, feature_size], dtype=dtype) cache = gb.impl.FeatureCache( - (cache_size,) + a.shape[1:], a.dtype, num_parts + (cache_size,) + a.shape[1:], a.dtype, num_parts, policy ) keys = torch.tensor([0, 1]) From 18ff8306587dc523e3dffb981beb4c4ee40c21c8 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Mon, 8 Jul 2024 05:29:24 -0400 Subject: [PATCH 04/12] fix the SieveCachePolicy comment. --- graphbolt/src/cache_policy.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphbolt/src/cache_policy.h b/graphbolt/src/cache_policy.h index 5b632943e5a3..f06c4b3f5be3 100644 --- a/graphbolt/src/cache_policy.h +++ b/graphbolt/src/cache_policy.h @@ -277,7 +277,7 @@ class S3FifoCachePolicy : public BaseCachePolicy { class SieveCachePolicy : public BaseCachePolicy { public: /** - * @brief Constructor for the S3FifoCachePolicy class. + * @brief Constructor for the SieveCachePolicy class. * * @param capacity The capacity of the cache in terms of # elements. */ From f0f27ba17e2f8fdf021888f16390a858d08af008 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Mon, 8 Jul 2024 05:44:21 -0400 Subject: [PATCH 05/12] add LRU. 
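The LRU policy keeps a single std::list of CacheKey entries ordered by
recency, most recently used at the front, plus a hash map from key to list
iterator. Read() moves a hit entry to the front, Insert() pushes new entries
to the front, and eviction pops from the back once the cache is full. It is
registered as the "lru" option of FeatureCache.

A minimal standalone sketch of that list plus iterator-map pattern, for
illustration only; SimpleLru and the plain container types are hypothetical,
while the class added below stores CacheKey entries and returns cache slot
positions instead:

    #include <cstddef>
    #include <cstdint>
    #include <list>
    #include <unordered_map>

    struct SimpleLru {
      std::list<int64_t> queue;  // front = most recently used.
      std::unordered_map<int64_t, std::list<int64_t>::iterator> map;
      std::size_t capacity;

      explicit SimpleLru(std::size_t cap) : capacity(cap) {}

      // Returns true on a cache hit.
      bool Access(int64_t key) {
        auto it = map.find(key);
        if (it != map.end()) {
          // Splice the hit node to the front in O(1); iterators stay valid.
          queue.splice(queue.begin(), queue, it->second);
          return true;
        }
        if (map.size() == capacity) {  // Evict the least recently used entry.
          map.erase(queue.back());
          queue.pop_back();
        }
        queue.push_front(key);
        map[key] = queue.begin();
        return false;
      }
    };

The implementation below erases and re-inserts the CacheKey at the front of
the queue rather than splicing, but the observable behavior is the same.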
--- graphbolt/src/cache_policy.cc | 15 ++++ graphbolt/src/cache_policy.h | 73 +++++++++++++++++++ graphbolt/src/partitioned_cache_policy.cc | 2 + graphbolt/src/python_binding.cc | 3 + python/dgl/graphbolt/impl/feature_cache.py | 4 +- .../graphbolt/impl/test_feature_cache.py | 2 +- 6 files changed, 97 insertions(+), 2 deletions(-) diff --git a/graphbolt/src/cache_policy.cc b/graphbolt/src/cache_policy.cc index c9ac7f66b0c4..43bb770b6898 100644 --- a/graphbolt/src/cache_policy.cc +++ b/graphbolt/src/cache_policy.cc @@ -110,5 +110,20 @@ torch::Tensor SieveCachePolicy::Replace(torch::Tensor keys) { return ReplaceImpl(*this, keys); } +LruCachePolicy::LruCachePolicy(int64_t capacity) + : capacity_(capacity), cache_usage_(0) { + TORCH_CHECK(capacity > 0, "Capacity needs to be positive."); + key_to_cache_key_.reserve(capacity); +} + +std::tuple LruCachePolicy::Query( + torch::Tensor keys) { + return QueryImpl(*this, keys); +} + +torch::Tensor LruCachePolicy::Replace(torch::Tensor keys) { + return ReplaceImpl(*this, keys); +} + } // namespace storage } // namespace graphbolt diff --git a/graphbolt/src/cache_policy.h b/graphbolt/src/cache_policy.h index f06c4b3f5be3..90006b9d8502 100644 --- a/graphbolt/src/cache_policy.h +++ b/graphbolt/src/cache_policy.h @@ -353,6 +353,79 @@ class SieveCachePolicy : public BaseCachePolicy { phmap::flat_hash_map key_to_cache_key_; }; +/** + * @brief LeastRecentlyUsed is a simple, scalable FIFObased algorithm with a + * single static queue. + **/ +class LruCachePolicy : public BaseCachePolicy { + public: + /** + * @brief Constructor for the LruCachePolicy class. + * + * @param capacity The capacity of the cache in terms of # elements. + */ + LruCachePolicy(int64_t capacity); + + LruCachePolicy() = default; + + /** + * @brief The policy query function. + * @param keys The keys to query the cache. + * + * @return (positions, indices, missing_keys), where positions has the + * locations of the keys which were found in the cache, missing_keys has the + * keys that were not found and indices is defined such that + * keys[indices[:positions.size(0)]] gives us the found keys and + * keys[indices[positions.size(0):]] is identical to missing_keys. + */ + std::tuple Query( + torch::Tensor keys); + + /** + * @brief The policy replace function. + * @param keys The keys to query the cache. + * + * @return positions tensor is returned holding the locations of the + * replaced entries in the cache. + */ + torch::Tensor Replace(torch::Tensor keys); + + std::optional Read(int64_t key) { + auto it = key_to_cache_key_.find(key); + if (it != key_to_cache_key_.end()) { + const auto cache_key = *it->second; + queue_.erase(it->second); + queue_.push_front(cache_key); + it->second = queue_.begin(); + return cache_key.getPos(); + } + return std::nullopt; + } + + int64_t Insert(int64_t key) { + const auto pos = Evict(); + queue_.push_front(CacheKey(key, pos)); + key_to_cache_key_[key] = queue_.begin(); + return pos; + } + + private: + int64_t Evict() { + // If the cache has space, get an unused slot otherwise perform eviction. 
+ if (cache_usage_ < capacity_) return cache_usage_++; + const auto& cache_key = queue_.back(); + TORCH_CHECK(key_to_cache_key_.erase(cache_key.getKey())); + const auto pos = cache_key.getPos(); + queue_.pop_back(); + return pos; + } + + std::list queue_; + int64_t capacity_; + int64_t cache_usage_; + phmap::flat_hash_map key_to_cache_key_; +}; + } // namespace storage } // namespace graphbolt diff --git a/graphbolt/src/partitioned_cache_policy.cc b/graphbolt/src/partitioned_cache_policy.cc index c35a16ff939b..e486e7b58941 100644 --- a/graphbolt/src/partitioned_cache_policy.cc +++ b/graphbolt/src/partitioned_cache_policy.cc @@ -212,6 +212,8 @@ template c10::intrusive_ptr PartitionedCachePolicy::Create(int64_t, int64_t); template c10::intrusive_ptr PartitionedCachePolicy::Create(int64_t, int64_t); +template c10::intrusive_ptr + PartitionedCachePolicy::Create(int64_t, int64_t); } // namespace storage } // namespace graphbolt diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index d6d5df9c2769..943b35843628 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -106,6 +106,9 @@ TORCH_LIBRARY(graphbolt, m) { m.def( "sieve_cache_policy", &storage::PartitionedCachePolicy::Create); + m.def( + "lru_cache_policy", + &storage::PartitionedCachePolicy::Create); m.class_("FeatureCache") .def("query", &storage::FeatureCache::Query) .def("replace", &storage::FeatureCache::Replace); diff --git a/python/dgl/graphbolt/impl/feature_cache.py b/python/dgl/graphbolt/impl/feature_cache.py index 3fc36e309807..05c25e310955 100644 --- a/python/dgl/graphbolt/impl/feature_cache.py +++ b/python/dgl/graphbolt/impl/feature_cache.py @@ -6,6 +6,7 @@ caching_policies = { "s3-fifo": torch.ops.graphbolt.s3_fifo_cache_policy, "sieve": torch.ops.graphbolt.sieve_cache_policy, + "lru": torch.ops.graphbolt.lru_cache_policy, } @@ -21,7 +22,8 @@ class FeatureCache(object): num_parts: int, optional The number of cache partitions for parallelism. Default is 1. policy: str, optional - The cache policy. Default is "sieve". "s3-fifo" is also available. + The cache policy. Default is "sieve". "s3-fifo" and "lru" are also + available. 
""" def __init__(self, cache_shape, dtype, num_parts=1, policy="sieve"): diff --git a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py index 0395c26a21d7..e8fb87c4a0ea 100644 --- a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py +++ b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py @@ -23,7 +23,7 @@ ) @pytest.mark.parametrize("feature_size", [2, 16]) @pytest.mark.parametrize("num_parts", [1, 2]) -@pytest.mark.parametrize("policy", ["s3-fifo", "sieve"]) +@pytest.mark.parametrize("policy", ["s3-fifo", "sieve", "lru"]) def test_feature_cache(dtype, feature_size, num_parts, policy): cache_size = 32 * num_parts a = torch.randint(0, 2, [1024, feature_size], dtype=dtype) From fb8eaf803d5062f473e0198f30ab87348e1d2499 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Mon, 8 Jul 2024 06:12:03 -0400 Subject: [PATCH 06/12] Add clock caching policy --- graphbolt/src/cache_policy.cc | 15 ++++ graphbolt/src/cache_policy.h | 81 +++++++++++++++++++ graphbolt/src/partitioned_cache_policy.cc | 2 + graphbolt/src/python_binding.cc | 3 + python/dgl/graphbolt/impl/feature_cache.py | 5 +- .../graphbolt/impl/test_feature_cache.py | 2 +- 6 files changed, 105 insertions(+), 3 deletions(-) diff --git a/graphbolt/src/cache_policy.cc b/graphbolt/src/cache_policy.cc index 43bb770b6898..1b084198c1f8 100644 --- a/graphbolt/src/cache_policy.cc +++ b/graphbolt/src/cache_policy.cc @@ -125,5 +125,20 @@ torch::Tensor LruCachePolicy::Replace(torch::Tensor keys) { return ReplaceImpl(*this, keys); } +ClockCachePolicy::ClockCachePolicy(int64_t capacity) + : capacity_(capacity), cache_usage_(0) { + TORCH_CHECK(capacity > 0, "Capacity needs to be positive."); + key_to_cache_key_.reserve(capacity); +} + +std::tuple ClockCachePolicy::Query( + torch::Tensor keys) { + return QueryImpl(*this, keys); +} + +torch::Tensor ClockCachePolicy::Replace(torch::Tensor keys) { + return ReplaceImpl(*this, keys); +} + } // namespace storage } // namespace graphbolt diff --git a/graphbolt/src/cache_policy.h b/graphbolt/src/cache_policy.h index 90006b9d8502..5ade1f21330c 100644 --- a/graphbolt/src/cache_policy.h +++ b/graphbolt/src/cache_policy.h @@ -426,6 +426,87 @@ class LruCachePolicy : public BaseCachePolicy { phmap::flat_hash_map key_to_cache_key_; }; +/** + * @brief Clock (FIFO-Reinsertion) is a simple, scalable FIFObased algorithm + * with a single static queue. + * https://people.csail.mit.edu/saltzer/Multics/MHP-Saltzer-060508/bookcases/M00s/M0104%20074-12).PDF + **/ +class ClockCachePolicy : public BaseCachePolicy { + public: + /** + * @brief Constructor for the ClockCachePolicy class. + * + * @param capacity The capacity of the cache in terms of # elements. + */ + ClockCachePolicy(int64_t capacity); + + ClockCachePolicy() = default; + + /** + * @brief The policy query function. + * @param keys The keys to query the cache. + * + * @return (positions, indices, missing_keys), where positions has the + * locations of the keys which were found in the cache, missing_keys has the + * keys that were not found and indices is defined such that + * keys[indices[:positions.size(0)]] gives us the found keys and + * keys[indices[positions.size(0):]] is identical to missing_keys. + */ + std::tuple Query( + torch::Tensor keys); + + /** + * @brief The policy replace function. + * @param keys The keys to query the cache. + * + * @return positions tensor is returned holding the locations of the + * replaced entries in the cache. 
+ */ + torch::Tensor Replace(torch::Tensor keys); + + std::optional Read(int64_t key) { + auto it = key_to_cache_key_.find(key); + if (it != key_to_cache_key_.end()) { + auto& cache_key = *it->second; + cache_key.SetFreq(); + return cache_key.getPos(); + } + return std::nullopt; + } + + int64_t Insert(int64_t key) { + const auto pos = Evict(); + queue_.push_front(CacheKey(key, pos)); + key_to_cache_key_[key] = queue_.begin(); + return pos; + } + + private: + int64_t Evict() { + // If the cache has space, get an unused slot otherwise perform eviction. + if (cache_usage_ < capacity_) return cache_usage_++; + while (true) { + auto& cache_key = queue_.back(); + if (cache_key.getFreq()) { + queue_.push_front(cache_key.ResetFreq()); + key_to_cache_key_[cache_key.getKey()] = queue_.begin(); + queue_.pop_back(); + } else + break; + } + const auto& cache_key = queue_.back(); + TORCH_CHECK(key_to_cache_key_.erase(cache_key.getKey())); + const auto pos = cache_key.getPos(); + queue_.pop_back(); + return pos; + } + + std::list queue_; + int64_t capacity_; + int64_t cache_usage_; + phmap::flat_hash_map key_to_cache_key_; +}; + } // namespace storage } // namespace graphbolt diff --git a/graphbolt/src/partitioned_cache_policy.cc b/graphbolt/src/partitioned_cache_policy.cc index e486e7b58941..5c98a90103df 100644 --- a/graphbolt/src/partitioned_cache_policy.cc +++ b/graphbolt/src/partitioned_cache_policy.cc @@ -214,6 +214,8 @@ template c10::intrusive_ptr PartitionedCachePolicy::Create(int64_t, int64_t); template c10::intrusive_ptr PartitionedCachePolicy::Create(int64_t, int64_t); +template c10::intrusive_ptr + PartitionedCachePolicy::Create(int64_t, int64_t); } // namespace storage } // namespace graphbolt diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index 943b35843628..6cbe7df6a717 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -109,6 +109,9 @@ TORCH_LIBRARY(graphbolt, m) { m.def( "lru_cache_policy", &storage::PartitionedCachePolicy::Create); + m.def( + "clock_cache_policy", + &storage::PartitionedCachePolicy::Create); m.class_("FeatureCache") .def("query", &storage::FeatureCache::Query) .def("replace", &storage::FeatureCache::Replace); diff --git a/python/dgl/graphbolt/impl/feature_cache.py b/python/dgl/graphbolt/impl/feature_cache.py index 05c25e310955..dca2ac2fadfd 100644 --- a/python/dgl/graphbolt/impl/feature_cache.py +++ b/python/dgl/graphbolt/impl/feature_cache.py @@ -7,6 +7,7 @@ "s3-fifo": torch.ops.graphbolt.s3_fifo_cache_policy, "sieve": torch.ops.graphbolt.sieve_cache_policy, "lru": torch.ops.graphbolt.lru_cache_policy, + "clock": torch.ops.graphbolt.clock_cache_policy, } @@ -22,8 +23,8 @@ class FeatureCache(object): num_parts: int, optional The number of cache partitions for parallelism. Default is 1. policy: str, optional - The cache policy. Default is "sieve". "s3-fifo" and "lru" are also - available. + The cache policy. Default is "sieve". "s3-fifo", "lru" and "clock" are + also available. 
""" def __init__(self, cache_shape, dtype, num_parts=1, policy="sieve"): diff --git a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py index e8fb87c4a0ea..ee0f5d55aa6c 100644 --- a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py +++ b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py @@ -23,7 +23,7 @@ ) @pytest.mark.parametrize("feature_size", [2, 16]) @pytest.mark.parametrize("num_parts", [1, 2]) -@pytest.mark.parametrize("policy", ["s3-fifo", "sieve", "lru"]) +@pytest.mark.parametrize("policy", ["s3-fifo", "sieve", "lru", "clock"]) def test_feature_cache(dtype, feature_size, num_parts, policy): cache_size = 32 * num_parts a = torch.randint(0, 2, [1024, feature_size], dtype=dtype) From 5247dab1bc45691876b0c7ef3784e0683c22f291 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Mon, 8 Jul 2024 06:18:33 -0400 Subject: [PATCH 07/12] minor improvement. --- graphbolt/src/cache_policy.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/graphbolt/src/cache_policy.cc b/graphbolt/src/cache_policy.cc index 1b084198c1f8..c3b02efa89d5 100644 --- a/graphbolt/src/cache_policy.cc +++ b/graphbolt/src/cache_policy.cc @@ -64,11 +64,7 @@ torch::Tensor BaseCachePolicy::ReplaceImpl( for (int64_t i = 0; i < keys.size(0); i++) { const auto key = keys_ptr[i]; const auto pos = policy.Read(key); - if (pos.has_value()) { // Already in the cache, inc freq. - positions_ptr[i] = *pos; - } else { - positions_ptr[i] = policy.Insert(key); - } + positions_ptr[i] = pos.has_value() ? *pos : policy.Insert(key); } })); return positions; From 237268e80d1d97b9d98f2d7bd8d6b1061c7576cc Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Mon, 8 Jul 2024 12:43:49 -0400 Subject: [PATCH 08/12] Switch to CircularQueue for `CLOCK`. --- graphbolt/src/cache_policy.h | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/graphbolt/src/cache_policy.h b/graphbolt/src/cache_policy.h index 5ade1f21330c..88f127d63e01 100644 --- a/graphbolt/src/cache_policy.h +++ b/graphbolt/src/cache_policy.h @@ -476,8 +476,7 @@ class ClockCachePolicy : public BaseCachePolicy { int64_t Insert(int64_t key) { const auto pos = Evict(); - queue_.push_front(CacheKey(key, pos)); - key_to_cache_key_[key] = queue_.begin(); + key_to_cache_key_[key] = queue_.Push(CacheKey(key, pos)); return pos; } @@ -485,26 +484,23 @@ class ClockCachePolicy : public BaseCachePolicy { int64_t Evict() { // If the cache has space, get an unused slot otherwise perform eviction. 
if (cache_usage_ < capacity_) return cache_usage_++; + CacheKey cache_key; while (true) { - auto& cache_key = queue_.back(); + cache_key = queue_.Pop(); if (cache_key.getFreq()) { - queue_.push_front(cache_key.ResetFreq()); - key_to_cache_key_[cache_key.getKey()] = queue_.begin(); - queue_.pop_back(); + key_to_cache_key_[cache_key.getKey()] = + queue_.Push(cache_key.ResetFreq()); } else break; } - const auto& cache_key = queue_.back(); TORCH_CHECK(key_to_cache_key_.erase(cache_key.getKey())); - const auto pos = cache_key.getPos(); - queue_.pop_back(); - return pos; + return cache_key.getPos(); } - std::list queue_; + CircularQueue queue_; int64_t capacity_; int64_t cache_usage_; - phmap::flat_hash_map key_to_cache_key_; + phmap::flat_hash_map key_to_cache_key_; }; } // namespace storage From f70a3a236d9f46f2dfbb2405bdcf7fe005f562ae Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Mon, 8 Jul 2024 12:48:10 -0400 Subject: [PATCH 09/12] initialize CircularQueue. --- graphbolt/src/cache_policy.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphbolt/src/cache_policy.cc b/graphbolt/src/cache_policy.cc index c3b02efa89d5..13f1d36107a2 100644 --- a/graphbolt/src/cache_policy.cc +++ b/graphbolt/src/cache_policy.cc @@ -122,7 +122,7 @@ torch::Tensor LruCachePolicy::Replace(torch::Tensor keys) { } ClockCachePolicy::ClockCachePolicy(int64_t capacity) - : capacity_(capacity), cache_usage_(0) { + : queue_(capacity), capacity_(capacity), cache_usage_(0) { TORCH_CHECK(capacity > 0, "Capacity needs to be positive."); key_to_cache_key_.reserve(capacity); } From 1f1332c5a03c5ee0686017f3c5ee448652a29bf5 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Mon, 8 Jul 2024 17:02:11 -0400 Subject: [PATCH 10/12] make output pinned if input is pinned. 
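FeatureCache.query() previously took an explicit pin_memory flag that every
caller had to thread through. The C++ side now derives pinned-ness from the
input instead: when the keys tensor is pinned, the positions, indices,
missing_keys and values tensors are allocated pinned as well, so asynchronous
copies to the GPU keep working while the Python query() signature loses the
extra argument. Replace() additionally checks that the requested positions
are all distinct and fails loudly when the cache capacity is too small to
hold a whole batch, instead of silently truncating it.

The allocation idiom applied throughout the patch looks like this (taken from
the Query path below):

    auto positions = torch::empty_like(
        keys,
        keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));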
--- graphbolt/src/cache_policy.cc | 25 ++++++++++++++++------ graphbolt/src/feature_cache.cc | 8 ++----- graphbolt/src/feature_cache.h | 3 +-- graphbolt/src/partitioned_cache_policy.cc | 14 +++++++----- python/dgl/graphbolt/impl/feature_cache.py | 4 +--- 5 files changed, 32 insertions(+), 22 deletions(-) diff --git a/graphbolt/src/cache_policy.cc b/graphbolt/src/cache_policy.cc index 13f1d36107a2..62c92ec901cc 100644 --- a/graphbolt/src/cache_policy.cc +++ b/graphbolt/src/cache_policy.cc @@ -25,9 +25,14 @@ namespace storage { template std::tuple BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) { - auto positions = torch::empty_like(keys, keys.options().dtype(torch::kInt64)); - auto indices = torch::empty_like(keys, keys.options().dtype(torch::kInt64)); - auto missing_keys = torch::empty_like(keys); + auto positions = torch::empty_like( + keys, + keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned())); + auto indices = torch::empty_like( + keys, + keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned())); + auto missing_keys = + torch::empty_like(keys, keys.options().pinned_memory(keys.is_pinned())); int64_t found_cnt = 0; int64_t missing_cnt = keys.size(0); AT_DISPATCH_INDEX_TYPES( @@ -56,15 +61,23 @@ BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) { template torch::Tensor BaseCachePolicy::ReplaceImpl( CachePolicy& policy, torch::Tensor keys) { - auto positions = torch::empty_like(keys, keys.options().dtype(torch::kInt64)); + auto positions = torch::empty_like( + keys, + keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned())); AT_DISPATCH_INDEX_TYPES( keys.scalar_type(), "S3FifoCachePolicy::Replace", ([&] { auto keys_ptr = keys.data_ptr(); auto positions_ptr = positions.data_ptr(); + phmap::flat_hash_set position_set; + position_set.reserve(keys.size(0)); for (int64_t i = 0; i < keys.size(0); i++) { const auto key = keys_ptr[i]; - const auto pos = policy.Read(key); - positions_ptr[i] = pos.has_value() ? *pos : policy.Insert(key); + const auto pos_optional = policy.Read(key); + const auto pos = pos_optional ? 
*pos_optional : policy.Insert(key); + positions_ptr[i] = pos; + TORCH_CHECK( + std::get<1>(position_set.insert(pos)), + "Can't insert all, larger cache capacity is needed."); } })); return positions; diff --git a/graphbolt/src/feature_cache.cc b/graphbolt/src/feature_cache.cc index ffe87f8bd377..c532a0efe79d 100644 --- a/graphbolt/src/feature_cache.cc +++ b/graphbolt/src/feature_cache.cc @@ -29,8 +29,8 @@ FeatureCache::FeatureCache( : tensor_(torch::empty(shape, c10::TensorOptions().dtype(dtype))) {} torch::Tensor FeatureCache::Query( - torch::Tensor positions, torch::Tensor indices, int64_t size, - bool pin_memory) { + torch::Tensor positions, torch::Tensor indices, int64_t size) { + const bool pin_memory = positions.is_pinned() || indices.is_pinned(); std::vector output_shape{ tensor_.sizes().begin(), tensor_.sizes().end()}; output_shape[0] = size; @@ -53,10 +53,6 @@ torch::Tensor FeatureCache::Query( } void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) { - if (positions.size(0) > tensor_.size(0)) { - positions = positions.slice(0, 0, tensor_.size(0)); - values = values.slice(0, 0, tensor_.size(0)); - } const auto row_bytes = values.slice(0, 0, 1).numel() * values.element_size(); auto values_ptr = reinterpret_cast(values.data_ptr()); const auto tensor_ptr = reinterpret_cast(tensor_.data_ptr()); diff --git a/graphbolt/src/feature_cache.h b/graphbolt/src/feature_cache.h index c97b3ef7848e..880e7bb37cdf 100644 --- a/graphbolt/src/feature_cache.h +++ b/graphbolt/src/feature_cache.h @@ -55,8 +55,7 @@ struct FeatureCache : public torch::CustomClassHolder { * is true. */ torch::Tensor Query( - torch::Tensor positions, torch::Tensor indices, int64_t size, - bool pin_memory); + torch::Tensor positions, torch::Tensor indices, int64_t size); /** * @brief The cache replace function. 
diff --git a/graphbolt/src/partitioned_cache_policy.cc b/graphbolt/src/partitioned_cache_policy.cc index 5c98a90103df..21cb8ef4928f 100644 --- a/graphbolt/src/partitioned_cache_policy.cc +++ b/graphbolt/src/partitioned_cache_policy.cc @@ -138,10 +138,13 @@ PartitionedCachePolicy::Query(torch::Tensor keys) { result_offsets, result_offsets + result_offsets_tensor.size(0), result_offsets, 0); torch::Tensor positions = torch::empty( - result_offsets[policies_.size()], std::get<0>(results[0]).options()); - torch::Tensor output_indices = torch::empty_like(indices); + result_offsets[policies_.size()], + std::get<0>(results[0]).options().pinned_memory(keys.is_pinned())); + torch::Tensor output_indices = torch::empty_like( + indices, indices.options().pinned_memory(keys.is_pinned())); torch::Tensor missing_keys = torch::empty( - indices.size(0) - positions.size(0), std::get<2>(results[0]).options()); + indices.size(0) - positions.size(0), + std::get<2>(results[0]).options().pinned_memory(keys.is_pinned())); auto output_indices_ptr = output_indices.data_ptr(); torch::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) { if (begin == end) return; @@ -179,8 +182,9 @@ torch::Tensor PartitionedCachePolicy::Replace(torch::Tensor keys) { if (policies_.size() == 1) return policies_[0]->Replace(keys); torch::Tensor offsets, indices, permuted_keys; std::tie(offsets, indices, permuted_keys) = Partition(keys); - auto output_positions = - torch::empty_like(keys, keys.options().dtype(torch::kInt64)); + auto output_positions = torch::empty_like( + keys, + keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned())); auto offsets_ptr = offsets.data_ptr(); auto indices_ptr = indices.data_ptr(); auto output_positions_ptr = output_positions.data_ptr(); diff --git a/python/dgl/graphbolt/impl/feature_cache.py b/python/dgl/graphbolt/impl/feature_cache.py index dca2ac2fadfd..773f1a28256f 100644 --- a/python/dgl/graphbolt/impl/feature_cache.py +++ b/python/dgl/graphbolt/impl/feature_cache.py @@ -54,9 +54,7 @@ def query(self, keys): """ self.total_queries += keys.shape[0] positions, index, missing_keys = self._policy.query(keys) - values = self._cache.query( - positions, index, keys.shape[0], keys.is_pinned() - ) + values = self._cache.query(positions, index, keys.shape[0]) self.total_miss += missing_keys.shape[0] missing_index = index[positions.size(0) :] return values, missing_index, missing_keys From c312794b3593cd58a2ced2414bec0ba8581808ec Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Mon, 8 Jul 2024 21:13:39 -0400 Subject: [PATCH 11/12] Mark items currently being read as unevictable. 
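CacheKey now packs a 16-bit reference_count_ next to the frequency, key and
position bit-fields. Every Read() hit calls StartUse() to bump the count, and
Query() additionally returns a found_keys tensor so that the reader can hand
those keys back through the new ReadingCompleted() method, which calls
Unmark()/EndUse() to drop the counts again. Every eviction path (S3-FIFO's
small and main queues, SIEVE's hand, LRU's tail and CLOCK's sweep) now skips
entries whose InUse() is true, so a slot can no longer be reassigned while
another thread may still be copying data out of it.

The intended calling pattern, sketched under the assumption that policy is
one of the policies touched in this series (names as in the patch, error
handling omitted):

    // Query() now returns a 4-tuple; found_keys is what must be given back.
    auto [positions, indices, missing_keys, found_keys] = policy.Query(keys);
    // ... read the cached rows at `positions` out of the feature cache ...
    policy.ReadingCompleted(found_keys);  // Entries become evictable again.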
--- graphbolt/src/cache_policy.cc | 57 +++++-- graphbolt/src/cache_policy.h | 178 +++++++++++---------- graphbolt/src/partitioned_cache_policy.cc | 34 +++- graphbolt/src/partitioned_cache_policy.h | 8 +- graphbolt/src/python_binding.cc | 5 +- python/dgl/graphbolt/impl/feature_cache.py | 3 +- 6 files changed, 183 insertions(+), 102 deletions(-) diff --git a/graphbolt/src/cache_policy.cc b/graphbolt/src/cache_policy.cc index 62c92ec901cc..94324d62e2c8 100644 --- a/graphbolt/src/cache_policy.cc +++ b/graphbolt/src/cache_policy.cc @@ -23,7 +23,7 @@ namespace graphbolt { namespace storage { template -std::tuple +std::tuple BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) { auto positions = torch::empty_like( keys, @@ -31,31 +31,32 @@ BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) { auto indices = torch::empty_like( keys, keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned())); - auto missing_keys = + auto filtered_keys = torch::empty_like(keys, keys.options().pinned_memory(keys.is_pinned())); int64_t found_cnt = 0; int64_t missing_cnt = keys.size(0); AT_DISPATCH_INDEX_TYPES( - keys.scalar_type(), "S3FifoCachePolicy::Query::DispatchForKeys", ([&] { + keys.scalar_type(), "BaseCachePolicy::Query::DispatchForKeys", ([&] { auto keys_ptr = keys.data_ptr(); auto positions_ptr = positions.data_ptr(); auto indices_ptr = indices.data_ptr(); - auto missing_keys_ptr = missing_keys.data_ptr(); + auto filtered_keys_ptr = filtered_keys.data_ptr(); for (int64_t i = 0; i < keys.size(0); i++) { const auto key = keys_ptr[i]; auto pos = policy.Read(key); if (pos.has_value()) { positions_ptr[found_cnt] = *pos; + filtered_keys_ptr[found_cnt] = key; indices_ptr[found_cnt++] = i; } else { indices_ptr[--missing_cnt] = i; - missing_keys_ptr[missing_cnt] = key; + filtered_keys_ptr[missing_cnt] = key; } } })); return { positions.slice(0, 0, found_cnt), indices, - missing_keys.slice(0, found_cnt)}; + filtered_keys.slice(0, found_cnt), filtered_keys.slice(0, 0, found_cnt)}; } template @@ -65,7 +66,7 @@ torch::Tensor BaseCachePolicy::ReplaceImpl( keys, keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned())); AT_DISPATCH_INDEX_TYPES( - keys.scalar_type(), "S3FifoCachePolicy::Replace", ([&] { + keys.scalar_type(), "BaseCachePolicy::Replace", ([&] { auto keys_ptr = keys.data_ptr(); auto positions_ptr = positions.data_ptr(); phmap::flat_hash_set position_set; @@ -83,6 +84,18 @@ torch::Tensor BaseCachePolicy::ReplaceImpl( return positions; } +template +void BaseCachePolicy::ReadingCompletedImpl( + CachePolicy& policy, torch::Tensor keys) { + AT_DISPATCH_INDEX_TYPES( + keys.scalar_type(), "BaseCachePolicy::ReadingCompleted", ([&] { + auto keys_ptr = keys.data_ptr(); + for (int64_t i = 0; i < keys.size(0); i++) { + policy.Unmark(keys_ptr[i]); + } + })); +} + S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity) : small_queue_(capacity), main_queue_(capacity), @@ -95,7 +108,7 @@ S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity) key_to_cache_key_.reserve(capacity); } -std::tuple +std::tuple S3FifoCachePolicy::Query(torch::Tensor keys) { return QueryImpl(*this, keys); } @@ -104,14 +117,18 @@ torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) { return ReplaceImpl(*this, keys); } +void S3FifoCachePolicy::ReadingCompleted(torch::Tensor keys) { + ReadingCompletedImpl(*this, keys); +} + SieveCachePolicy::SieveCachePolicy(int64_t capacity) : hand_(queue_.end()), capacity_(capacity), cache_usage_(0) { TORCH_CHECK(capacity > 0, "Capacity needs to be positive."); 
key_to_cache_key_.reserve(capacity); } -std::tuple SieveCachePolicy::Query( - torch::Tensor keys) { +std::tuple +SieveCachePolicy::Query(torch::Tensor keys) { return QueryImpl(*this, keys); } @@ -119,14 +136,18 @@ torch::Tensor SieveCachePolicy::Replace(torch::Tensor keys) { return ReplaceImpl(*this, keys); } +void SieveCachePolicy::ReadingCompleted(torch::Tensor keys) { + ReadingCompletedImpl(*this, keys); +} + LruCachePolicy::LruCachePolicy(int64_t capacity) : capacity_(capacity), cache_usage_(0) { TORCH_CHECK(capacity > 0, "Capacity needs to be positive."); key_to_cache_key_.reserve(capacity); } -std::tuple LruCachePolicy::Query( - torch::Tensor keys) { +std::tuple +LruCachePolicy::Query(torch::Tensor keys) { return QueryImpl(*this, keys); } @@ -134,14 +155,18 @@ torch::Tensor LruCachePolicy::Replace(torch::Tensor keys) { return ReplaceImpl(*this, keys); } +void LruCachePolicy::ReadingCompleted(torch::Tensor keys) { + ReadingCompletedImpl(*this, keys); +} + ClockCachePolicy::ClockCachePolicy(int64_t capacity) : queue_(capacity), capacity_(capacity), cache_usage_(0) { TORCH_CHECK(capacity > 0, "Capacity needs to be positive."); key_to_cache_key_.reserve(capacity); } -std::tuple ClockCachePolicy::Query( - torch::Tensor keys) { +std::tuple +ClockCachePolicy::Query(torch::Tensor keys) { return QueryImpl(*this, keys); } @@ -149,5 +174,9 @@ torch::Tensor ClockCachePolicy::Replace(torch::Tensor keys) { return ReplaceImpl(*this, keys); } +void ClockCachePolicy::ReadingCompleted(torch::Tensor keys) { + ReadingCompletedImpl(*this, keys); +} + } // namespace storage } // namespace graphbolt diff --git a/graphbolt/src/cache_policy.h b/graphbolt/src/cache_policy.h index 88f127d63e01..1140a14420f5 100644 --- a/graphbolt/src/cache_policy.h +++ b/graphbolt/src/cache_policy.h @@ -86,7 +86,7 @@ struct CircularQueue { struct CacheKey { CacheKey(int64_t key, int64_t position) - : freq_(0), key_(key), position_in_cache_(position) { + : freq_(0), key_(key), position_in_cache_(position), reference_count_(0) { static_assert(sizeof(CacheKey) == 2 * sizeof(int64_t)); } @@ -98,9 +98,15 @@ struct CacheKey { auto getPos() const { return position_in_cache_; } - void Increment() { freq_ = std::min(3, static_cast(freq_ + 1)); } + CacheKey& Increment() { + freq_ = std::min(3, static_cast(freq_ + 1)); + return *this; + } - void Decrement() { freq_ = std::max(0, static_cast(freq_ - 1)); } + CacheKey& Decrement() { + freq_ = std::max(0, static_cast(freq_ - 1)); + return *this; + } CacheKey& SetFreq() { freq_ = 1; @@ -112,6 +118,18 @@ struct CacheKey { return *this; } + CacheKey& StartUse() { + ++reference_count_; + return *this; + } + + CacheKey& EndUse() { + --reference_count_; + return *this; + } + + bool InUse() { return reference_count_ > 0; } + friend std::ostream& operator<<(std::ostream& os, const CacheKey& key_ref) { return os << '(' << key_ref.key_ << ", " << key_ref.freq_ << ", " << key_ref.position_in_cache_ << ")"; @@ -120,7 +138,8 @@ struct CacheKey { private: int64_t freq_ : 3; int64_t key_ : 61; - int64_t position_in_cache_; + int64_t position_in_cache_ : 48; + int64_t reference_count_ : 16; }; class BaseCachePolicy { @@ -129,14 +148,14 @@ class BaseCachePolicy { * @brief The policy query function. * @param keys The keys to query the cache. 
* - * @return (positions, indices, missing_keys), where positions has the - * locations of the keys which were found in the cache, missing_keys has the - * keys that were not found and indices is defined such that + * @return (positions, indices, missing_keys, found_keys), where positions has + * the locations of the keys which were found in the cache, missing_keys has + * the keys that were not found and indices is defined such that * keys[indices[:positions.size(0)]] gives us the found keys and * keys[indices[positions.size(0):]] is identical to missing_keys. */ - virtual std::tuple Query( - torch::Tensor keys) = 0; + virtual std::tuple + Query(torch::Tensor keys) = 0; /** * @brief The policy replace function. @@ -147,12 +166,21 @@ class BaseCachePolicy { */ virtual torch::Tensor Replace(torch::Tensor keys) = 0; + /** + * @brief A reader has finished reading these keys, so they can be evicted. + * @param keys The keys to unmark. + */ + virtual void ReadingCompleted(torch::Tensor keys) = 0; + template - static std::tuple QueryImpl( - CachePolicy& policy, torch::Tensor keys); + static std::tuple + QueryImpl(CachePolicy& policy, torch::Tensor keys); template static torch::Tensor ReplaceImpl(CachePolicy& policy, torch::Tensor keys); + + template + static void ReadingCompletedImpl(CachePolicy& policy, torch::Tensor keys); }; /** @@ -171,27 +199,21 @@ class S3FifoCachePolicy : public BaseCachePolicy { S3FifoCachePolicy() = default; /** - * @brief The policy query function. - * @param keys The keys to query the cache. - * - * @return (positions, indices, missing_keys), where positions has the - * locations of the keys which were found in the cache, missing_keys has the - * keys that were not found and indices is defined such that - * keys[indices[:positions.size(0)]] gives us the found keys and - * keys[indices[positions.size(0):]] is identical to missing_keys. + * @brief See BaseCachePolicy::Query. */ - std::tuple Query( + std::tuple Query( torch::Tensor keys); /** - * @brief The policy replace function. - * @param keys The keys to query the cache. - * - * @return positions tensor is returned holding the locations of the replaced - * entries in the cache. + * @brief See BaseCachePolicy::Replace. */ torch::Tensor Replace(torch::Tensor keys); + /** + * @brief See BaseCachePolicy::ReadingCompleted. 
+ */ + void ReadingCompleted(torch::Tensor keys); + friend std::ostream& operator<<( std::ostream& os, const S3FifoCachePolicy& policy) { return os << "small_queue_: " << policy.small_queue_ << "\n" @@ -204,8 +226,7 @@ class S3FifoCachePolicy : public BaseCachePolicy { auto it = key_to_cache_key_.find(key); if (it != key_to_cache_key_.end()) { auto& cache_key = *it->second; - cache_key.Increment(); - return cache_key.getPos(); + return cache_key.Increment().StartUse().getPos(); } return std::nullopt; } @@ -218,12 +239,14 @@ class S3FifoCachePolicy : public BaseCachePolicy { return pos; } + void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); } + private: int64_t EvictMainQueue() { while (true) { auto evicted = main_queue_.Pop(); auto it = key_to_cache_key_.find(evicted.getKey()); - if (evicted.getFreq() > 0) { + if (evicted.getFreq() > 0 || evicted.InUse()) { evicted.Decrement(); it->second = main_queue_.Push(evicted); } else { @@ -238,7 +261,9 @@ class S3FifoCachePolicy : public BaseCachePolicy { size--) { auto evicted = small_queue_.Pop(); auto it = key_to_cache_key_.find(evicted.getKey()); - if (evicted.getFreq() <= 0) { + if (evicted.getFreq() > 0 || evicted.InUse()) { + it->second = main_queue_.Push(evicted.ResetFreq()); + } else { key_to_cache_key_.erase(it); const auto evicted_key = evicted.getKey(); if (ghost_queue_.IsFull()) { @@ -247,8 +272,6 @@ class S3FifoCachePolicy : public BaseCachePolicy { ghost_set_.insert(evicted_key); ghost_queue_.Push(evicted_key); return evicted.getPos(); - } else { - it->second = main_queue_.Push(evicted.ResetFreq()); } } return -1; @@ -286,33 +309,26 @@ class SieveCachePolicy : public BaseCachePolicy { SieveCachePolicy() = default; /** - * @brief The policy query function. - * @param keys The keys to query the cache. - * - * @return (positions, indices, missing_keys), where positions has the - * locations of the keys which were found in the cache, missing_keys has the - * keys that were not found and indices is defined such that - * keys[indices[:positions.size(0)]] gives us the found keys and - * keys[indices[positions.size(0):]] is identical to missing_keys. + * @brief See BaseCachePolicy::Query. */ - std::tuple Query( + std::tuple Query( torch::Tensor keys); /** - * @brief The policy replace function. - * @param keys The keys to query the cache. - * - * @return positions tensor is returned holding the locations of the - * replaced entries in the cache. + * @brief See BaseCachePolicy::Replace. */ torch::Tensor Replace(torch::Tensor keys); + /** + * @brief See BaseCachePolicy::ReadingCompleted. + */ + void ReadingCompleted(torch::Tensor keys); + std::optional Read(int64_t key) { auto it = key_to_cache_key_.find(key); if (it != key_to_cache_key_.end()) { auto& cache_key = *it->second; - cache_key.SetFreq(); - return cache_key.getPos(); + return cache_key.SetFreq().StartUse().getPos(); } return std::nullopt; } @@ -324,12 +340,14 @@ class SieveCachePolicy : public BaseCachePolicy { return pos; } + void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); } + private: int64_t Evict() { // If the cache has space, get an unused slot otherwise perform eviction. if (cache_usage_ < capacity_) return cache_usage_++; --hand_; - while (hand_->getFreq()) { + while (hand_->getFreq() || hand_->InUse()) { hand_->ResetFreq(); if (hand_ == queue_.begin()) hand_ = queue_.end(); --hand_; @@ -369,33 +387,27 @@ class LruCachePolicy : public BaseCachePolicy { LruCachePolicy() = default; /** - * @brief The policy query function. 
-   * @param keys The keys to query the cache.
-   *
-   * @return (positions, indices, missing_keys), where positions has the
-   * locations of the keys which were found in the cache, missing_keys has the
-   * keys that were not found and indices is defined such that
-   * keys[indices[:positions.size(0)]] gives us the found keys and
-   * keys[indices[positions.size(0):]] is identical to missing_keys.
+   * @brief See BaseCachePolicy::Query.
    */
-  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> Query(
+  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
       torch::Tensor keys);
 
   /**
-   * @brief The policy replace function.
-   * @param keys The keys to query the cache.
-   *
-   * @return positions tensor is returned holding the locations of the
-   * replaced entries in the cache.
+   * @brief See BaseCachePolicy::Replace.
    */
   torch::Tensor Replace(torch::Tensor keys);
 
+  /**
+   * @brief See BaseCachePolicy::ReadingCompleted.
+   */
+  void ReadingCompleted(torch::Tensor keys);
+
   std::optional<int64_t> Read(int64_t key) {
     auto it = key_to_cache_key_.find(key);
     if (it != key_to_cache_key_.end()) {
-      const auto cache_key = *it->second;
+      auto cache_key = *it->second;
       queue_.erase(it->second);
-      queue_.push_front(cache_key);
+      queue_.push_front(cache_key.StartUse());
       it->second = queue_.begin();
       return cache_key.getPos();
     }
@@ -409,10 +421,19 @@ class LruCachePolicy : public BaseCachePolicy {
     return pos;
   }
 
+  void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
+
  private:
   int64_t Evict() {
     // If the cache has space, get an unused slot otherwise perform eviction.
     if (cache_usage_ < capacity_) return cache_usage_++;
+    // Do not evict items that are still in use.
+    for (auto cache_key = queue_.back(); cache_key.InUse();
+         cache_key = queue_.back()) {
+      queue_.pop_back();
+      queue_.push_front(cache_key);
+      key_to_cache_key_[cache_key.getKey()] = queue_.begin();
+    }
     const auto& cache_key = queue_.back();
     TORCH_CHECK(key_to_cache_key_.erase(cache_key.getKey()));
     const auto pos = cache_key.getPos();
@@ -443,33 +464,26 @@ class ClockCachePolicy : public BaseCachePolicy {
   ClockCachePolicy() = default;
 
   /**
-   * @brief The policy query function.
-   * @param keys The keys to query the cache.
-   *
-   * @return (positions, indices, missing_keys), where positions has the
-   * locations of the keys which were found in the cache, missing_keys has the
-   * keys that were not found and indices is defined such that
-   * keys[indices[:positions.size(0)]] gives us the found keys and
-   * keys[indices[positions.size(0):]] is identical to missing_keys.
+   * @brief See BaseCachePolicy::Query.
    */
-  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> Query(
+  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
       torch::Tensor keys);
 
   /**
-   * @brief The policy replace function.
-   * @param keys The keys to query the cache.
-   *
-   * @return positions tensor is returned holding the locations of the
-   * replaced entries in the cache.
+   * @brief See BaseCachePolicy::Replace.
   */
   torch::Tensor Replace(torch::Tensor keys);
 
+  /**
+   * @brief See BaseCachePolicy::ReadingCompleted.
+   */
+  void ReadingCompleted(torch::Tensor keys);
+
   std::optional<int64_t> Read(int64_t key) {
     auto it = key_to_cache_key_.find(key);
     if (it != key_to_cache_key_.end()) {
       auto& cache_key = *it->second;
-      cache_key.SetFreq();
-      return cache_key.getPos();
+      return cache_key.SetFreq().StartUse().getPos();
     }
     return std::nullopt;
   }
@@ -480,6 +494,8 @@ class ClockCachePolicy : public BaseCachePolicy {
     return pos;
   }
 
+  void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
+
  private:
   int64_t Evict() {
     // If the cache has space, get an unused slot otherwise perform eviction.
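
The pattern repeated across these hunks is a reader pin: every policy's Read() now chains .StartUse() on the entry it hands out, every Evict()/EvictMainQueue() skips entries whose InUse() is true, and Unmark()/ReadingCompleted() call EndUse() once the caller has copied the cached value out. Below is a minimal standalone sketch of that idea, assuming a simplified CLOCK-style queue; Entry, ToyClockCache and EvictOne are illustrative names, not GraphBolt's API.

// Minimal standalone sketch of the reader-pin idea (illustrative names only,
// not GraphBolt's API): a CLOCK-style sweep that refuses to evict entries
// still pinned by a reader, mirroring the `getFreq() || InUse()` checks above.
#include <cassert>
#include <cstdint>
#include <deque>
#include <iostream>

struct Entry {
  int64_t key;
  bool freq = false;    // Referenced since the last sweep.
  int32_t readers = 0;  // Reads still in flight (StartUse/EndUse counter).
  bool InUse() const { return readers > 0; }
};

struct ToyClockCache {
  std::deque<Entry> queue;

  // Pop from the front; anything referenced or pinned is cycled back instead
  // of being evicted. Assumes at least one entry is evictable.
  int64_t EvictOne() {
    while (true) {
      Entry entry = queue.front();
      queue.pop_front();
      if (entry.freq || entry.InUse()) {
        entry.freq = false;      // Second chance: clear the reference bit.
        queue.push_back(entry);  // Pinned entries simply survive the sweep.
      } else {
        return entry.key;        // No reader holds it, so it is safe to evict.
      }
    }
  }
};

int main() {
  ToyClockCache cache;
  cache.queue = {{1, false, 1}, {2, true, 0}, {3, false, 0}};
  // Key 1 is pinned by a reader and key 2 was recently referenced,
  // so key 3 is evicted first.
  assert(cache.EvictOne() == 3);
  cache.queue.front().readers = 0;  // The reader of key 1 finishes (EndUse).
  std::cout << "evicted: " << cache.EvictOne() << "\n";  // evicted: 1
  return 0;
}

The point of the pin is that a slot handed out by Query must not be reused for a different key while a reader is still copying from it; the eviction loops therefore treat pinned entries exactly like recently referenced ones and move on.
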
@@ -487,7 +503,7 @@ class ClockCachePolicy : public BaseCachePolicy {
     CacheKey cache_key;
     while (true) {
       cache_key = queue_.Pop();
-      if (cache_key.getFreq()) {
+      if (cache_key.getFreq() || cache_key.InUse()) {
         key_to_cache_key_[cache_key.getKey()] =
             queue_.Push(cache_key.ResetFreq());
       } else
diff --git a/graphbolt/src/partitioned_cache_policy.cc b/graphbolt/src/partitioned_cache_policy.cc
index 21cb8ef4928f..37ab282a7029 100644
--- a/graphbolt/src/partitioned_cache_policy.cc
+++ b/graphbolt/src/partitioned_cache_policy.cc
@@ -112,15 +112,16 @@ PartitionedCachePolicy::Partition(torch::Tensor keys) {
       permuted_keys};
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 PartitionedCachePolicy::Query(torch::Tensor keys) {
   if (policies_.size() == 1) return policies_[0]->Query(keys);
   torch::Tensor offsets, indices, permuted_keys;
   std::tie(offsets, indices, permuted_keys) = Partition(keys);
   auto offsets_ptr = offsets.data_ptr<int64_t>();
   auto indices_ptr = indices.data_ptr<int64_t>();
-  std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> results(
-      policies_.size());
+  std::vector<
+      std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
+      results(policies_.size());
   torch::Tensor result_offsets_tensor =
       torch::empty(policies_.size() * 2 + 1, offsets.options());
   auto result_offsets = result_offsets_tensor.data_ptr<int64_t>();
@@ -145,6 +146,9 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
   torch::Tensor missing_keys = torch::empty(
       indices.size(0) - positions.size(0),
       std::get<2>(results[0]).options().pinned_memory(keys.is_pinned()));
+  torch::Tensor found_keys = torch::empty(
+      positions.size(0),
+      std::get<3>(results[0]).options().pinned_memory(keys.is_pinned()));
   auto output_indices_ptr = output_indices.data_ptr<int64_t>();
   torch::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
     if (begin == end) return;
@@ -162,6 +166,11 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
         selected_positions_ptr, selected_positions_ptr + num_selected,
         positions.data_ptr<int64_t>() + begin,
         [off = tid * capacity_ / policies_.size()](auto x) { return x + off; });
+    std::memcpy(
+        reinterpret_cast<std::byte*>(found_keys.data_ptr()) +
+            begin * found_keys.element_size(),
+        std::get<3>(results[tid]).data_ptr(),
+        num_selected * found_keys.element_size());
     begin = result_offsets[policies_.size() + tid];
     end = result_offsets[policies_.size() + tid + 1];
     const auto num_missing = end - begin;
@@ -175,7 +184,7 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
         std::get<2>(results[tid]).data_ptr(),
         num_missing * missing_keys.element_size());
   });
-  return std::make_tuple(positions, output_indices, missing_keys);
+  return std::make_tuple(positions, output_indices, missing_keys, found_keys);
 }
 
 torch::Tensor PartitionedCachePolicy::Replace(torch::Tensor keys) {
@@ -204,6 +213,23 @@ torch::Tensor PartitionedCachePolicy::Replace(torch::Tensor keys) {
   return output_positions;
 }
 
+void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
+  if (policies_.size() == 1) {
+    policies_[0]->ReadingCompleted(keys);
+    return;
+  }
+  torch::Tensor offsets, indices, permuted_keys;
+  std::tie(offsets, indices, permuted_keys) = Partition(keys);
+  auto offsets_ptr = offsets.data_ptr<int64_t>();
+  torch::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
+    if (begin == end) return;
+    const auto tid = begin;
+    begin = offsets_ptr[tid];
+    end = offsets_ptr[tid + 1];
+    policies_.at(tid)->ReadingCompleted(permuted_keys.slice(0, begin, end));
+  });
+}
+
 template <typename CachePolicy>
 c10::intrusive_ptr<PartitionedCachePolicy> PartitionedCachePolicy::Create(
     int64_t capacity, int64_t num_partitions) {
diff --git a/graphbolt/src/partitioned_cache_policy.h b/graphbolt/src/partitioned_cache_policy.h
index c3d42b3f82bc..0d343a87438e 100644
--- a/graphbolt/src/partitioned_cache_policy.h
+++ b/graphbolt/src/partitioned_cache_policy.h
@@ -62,7 +62,7 @@ class PartitionedCachePolicy : public BaseCachePolicy,
    * keys[indices[:positions.size(0)]] gives us the found keys and
    * keys[indices[positions.size(0):]] is identical to missing_keys.
    */
-  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> Query(
+  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
       torch::Tensor keys);
 
   /**
@@ -74,6 +74,12 @@
    */
   torch::Tensor Replace(torch::Tensor keys);
 
+  /**
+   * @brief A reader has finished reading these keys, so they can be evicted.
+   * @param keys The keys to unmark.
+   */
+  void ReadingCompleted(torch::Tensor keys);
+
   template <typename CachePolicy>
   static c10::intrusive_ptr<PartitionedCachePolicy> Create(
       int64_t capacity, int64_t num_partitions);
diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc
index 6cbe7df6a717..d3c3789abb1b 100644
--- a/graphbolt/src/python_binding.cc
+++ b/graphbolt/src/python_binding.cc
@@ -99,7 +99,10 @@ TORCH_LIBRARY(graphbolt, m) {
   m.def("fused_csc_sampling_graph", &FusedCSCSamplingGraph::Create);
   m.class_<storage::PartitionedCachePolicy>("PartitionedCachePolicy")
       .def("query", &storage::PartitionedCachePolicy::Query)
-      .def("replace", &storage::PartitionedCachePolicy::Replace);
+      .def("replace", &storage::PartitionedCachePolicy::Replace)
+      .def(
+          "reading_completed",
+          &storage::PartitionedCachePolicy::ReadingCompleted);
   m.def(
       "s3_fifo_cache_policy",
       &storage::PartitionedCachePolicy::Create<storage::S3FifoCachePolicy>);
diff --git a/python/dgl/graphbolt/impl/feature_cache.py b/python/dgl/graphbolt/impl/feature_cache.py
index 773f1a28256f..740dc223edfc 100644
--- a/python/dgl/graphbolt/impl/feature_cache.py
+++ b/python/dgl/graphbolt/impl/feature_cache.py
@@ -53,8 +53,9 @@ def query(self, keys):
             pinned, then the returned values tensor is pinned as well.
         """
         self.total_queries += keys.shape[0]
-        positions, index, missing_keys = self._policy.query(keys)
+        positions, index, missing_keys, found_keys = self._policy.query(keys)
         values = self._cache.query(positions, index, keys.shape[0])
+        self._policy.reading_completed(found_keys)
         self.total_miss += missing_keys.shape[0]
         missing_index = index[positions.size(0) :]
         return values, missing_index, missing_keys

From d73e3381fdcf22c59a8994b8b87e9da6f8a641f7 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin
Date: Tue, 9 Jul 2024 01:06:49 -0400
Subject: [PATCH 12/12] make the derived member implementations protected.

---
 graphbolt/src/cache_policy.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/graphbolt/src/cache_policy.h b/graphbolt/src/cache_policy.h
index 1140a14420f5..8ce1bf6df6d4 100644
--- a/graphbolt/src/cache_policy.h
+++ b/graphbolt/src/cache_policy.h
@@ -172,6 +172,7 @@ class BaseCachePolicy {
    */
   virtual void ReadingCompleted(torch::Tensor keys) = 0;
 
+ protected:
   template <typename CachePolicy>
   static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
   QueryImpl(CachePolicy& policy, torch::Tensor keys);
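
Taken together with the feature_cache.py hunk, the contract these patches establish for callers is: query() now also returns found_keys, the caller copies the cached values while those keys are pinned, then calls reading_completed(found_keys) to release them, and finally admits the misses with replace(missing_keys). The following standalone sketch shows that call sequence with an illustrative ToyPolicy; it is not the GraphBolt binding itself, just the same lifecycle under simplified types.

// Standalone sketch of the call sequence this patch series establishes
// (illustrative ToyPolicy only, not the GraphBolt API): Query() pins every key
// it finds and returns those found keys; after the cached values have been
// copied out, the caller releases the pins with ReadingCompleted(found), and
// finally admits the misses with Replace(missing).
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

class ToyPolicy {
 public:
  // Returns (found_keys, missing_keys); found keys stay pinned until released.
  std::pair<std::vector<int64_t>, std::vector<int64_t>> Query(
      const std::vector<int64_t>& keys) {
    std::vector<int64_t> found, missing;
    for (const auto key : keys) {
      if (cache_.count(key)) {
        ++pins_[key];  // Reader pin: eviction must skip this key for now.
        found.push_back(key);
      } else {
        missing.push_back(key);
      }
    }
    return {found, missing};
  }

  // Admit keys that missed; a real policy would also return their slots.
  void Replace(const std::vector<int64_t>& keys) {
    for (const auto key : keys) cache_.insert(key);
  }

  // The reader is done with these keys, so they become evictable again.
  void ReadingCompleted(const std::vector<int64_t>& keys) {
    for (const auto key : keys) --pins_[key];
  }

 private:
  std::unordered_set<int64_t> cache_{1, 2};
  std::unordered_map<int64_t, int32_t> pins_;
};

int main() {
  ToyPolicy policy;
  auto [found, missing] = policy.Query({1, 2, 3});
  // ... copy the cached values for `found` while they are pinned ...
  policy.ReadingCompleted(found);  // mirrors feature_cache.py's new call
  policy.Replace(missing);         // then admit the keys that missed
  std::cout << found.size() << " hit(s), " << missing.size() << " miss(es)\n";
  return 0;
}

Releasing the pins immediately after the copy, as feature_cache.py does right after self._cache.query(...), keeps the window in which entries are unevictable as short as possible.
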