diff --git a/graphbolt/src/cache_policy.cc b/graphbolt/src/cache_policy.cc index 0c9908378931..94324d62e2c8 100644 --- a/graphbolt/src/cache_policy.cc +++ b/graphbolt/src/cache_policy.cc @@ -22,79 +22,161 @@ namespace graphbolt { namespace storage { -S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity) - : small_queue_(capacity / 10), - main_queue_(capacity - capacity / 10), - ghost_queue_time_(0), - capacity_(capacity), - cache_usage_(0) {} - -std::tuple -S3FifoCachePolicy::Query(torch::Tensor keys) { - auto positions = torch::empty_like(keys, keys.options().dtype(torch::kInt64)); - auto indices = torch::empty_like(keys, keys.options().dtype(torch::kInt64)); - auto missing_keys = torch::empty_like(keys); +template +std::tuple +BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) { + auto positions = torch::empty_like( + keys, + keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned())); + auto indices = torch::empty_like( + keys, + keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned())); + auto filtered_keys = + torch::empty_like(keys, keys.options().pinned_memory(keys.is_pinned())); int64_t found_cnt = 0; int64_t missing_cnt = keys.size(0); AT_DISPATCH_INDEX_TYPES( - keys.scalar_type(), "S3FifoCachePolicy::Query::DispatchForKeys", ([&] { + keys.scalar_type(), "BaseCachePolicy::Query::DispatchForKeys", ([&] { auto keys_ptr = keys.data_ptr(); auto positions_ptr = positions.data_ptr(); auto indices_ptr = indices.data_ptr(); - auto missing_keys_ptr = missing_keys.data_ptr(); + auto filtered_keys_ptr = filtered_keys.data_ptr(); for (int64_t i = 0; i < keys.size(0); i++) { const auto key = keys_ptr[i]; - auto it = key_to_cache_key_.find(key); - if (it != key_to_cache_key_.end()) { - auto& cache_key = *it->second; - cache_key.Increment(); - positions_ptr[found_cnt] = cache_key.getPos(); + auto pos = policy.Read(key); + if (pos.has_value()) { + positions_ptr[found_cnt] = *pos; + filtered_keys_ptr[found_cnt] = key; indices_ptr[found_cnt++] = i; } else { indices_ptr[--missing_cnt] = i; - missing_keys_ptr[missing_cnt] = key; + filtered_keys_ptr[missing_cnt] = key; } } })); return { positions.slice(0, 0, found_cnt), indices, - missing_keys.slice(0, found_cnt)}; + filtered_keys.slice(0, found_cnt), filtered_keys.slice(0, 0, found_cnt)}; } -torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) { - auto positions = torch::empty_like(keys, keys.options().dtype(torch::kInt64)); +template +torch::Tensor BaseCachePolicy::ReplaceImpl( + CachePolicy& policy, torch::Tensor keys) { + auto positions = torch::empty_like( + keys, + keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned())); AT_DISPATCH_INDEX_TYPES( - keys.scalar_type(), "S3FifoCachePolicy::Replace", ([&] { + keys.scalar_type(), "BaseCachePolicy::Replace", ([&] { auto keys_ptr = keys.data_ptr(); auto positions_ptr = positions.data_ptr(); + phmap::flat_hash_set position_set; + position_set.reserve(keys.size(0)); for (int64_t i = 0; i < keys.size(0); i++) { const auto key = keys_ptr[i]; - auto it = key_to_cache_key_.find(key); - if (it != - key_to_cache_key_.end()) { // Already in the cache, inc freq. - auto& cache_key = *it->second; - cache_key.Increment(); - positions_ptr[i] = cache_key.getPos(); - } else { - const auto in_ghost_queue = InGhostQueue(key); - auto& queue = in_ghost_queue ? main_queue_ : small_queue_; - int64_t pos; - if (queue.IsFull()) { - // When the queue is full, we need to make a space by evicting. - // Inside ghost queue means insertion into M, otherwise S. 
- pos = (in_ghost_queue ? EvictMainQueue() : EvictSmallQueue()); - } else { // If the cache is not full yet, get an unused empty slot. - pos = cache_usage_++; - } - TORCH_CHECK(0 <= pos && pos < capacity_, "Position out of bounds!"); - key_to_cache_key_[key] = queue.Push(CacheKey(key, pos)); - positions_ptr[i] = pos; - } + const auto pos_optional = policy.Read(key); + const auto pos = pos_optional ? *pos_optional : policy.Insert(key); + positions_ptr[i] = pos; + TORCH_CHECK( + std::get<1>(position_set.insert(pos)), + "Can't insert all, larger cache capacity is needed."); } })); - TrimGhostQueue(); return positions; } +template +void BaseCachePolicy::ReadingCompletedImpl( + CachePolicy& policy, torch::Tensor keys) { + AT_DISPATCH_INDEX_TYPES( + keys.scalar_type(), "BaseCachePolicy::ReadingCompleted", ([&] { + auto keys_ptr = keys.data_ptr(); + for (int64_t i = 0; i < keys.size(0); i++) { + policy.Unmark(keys_ptr[i]); + } + })); +} + +S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity) + : small_queue_(capacity), + main_queue_(capacity), + ghost_queue_(capacity - capacity / 10), + capacity_(capacity), + cache_usage_(0), + small_queue_size_target_(capacity / 10) { + TORCH_CHECK(small_queue_size_target_ > 0, "Capacity is not large enough."); + ghost_set_.reserve(ghost_queue_.Capacity()); + key_to_cache_key_.reserve(capacity); +} + +std::tuple +S3FifoCachePolicy::Query(torch::Tensor keys) { + return QueryImpl(*this, keys); +} + +torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) { + return ReplaceImpl(*this, keys); +} + +void S3FifoCachePolicy::ReadingCompleted(torch::Tensor keys) { + ReadingCompletedImpl(*this, keys); +} + +SieveCachePolicy::SieveCachePolicy(int64_t capacity) + : hand_(queue_.end()), capacity_(capacity), cache_usage_(0) { + TORCH_CHECK(capacity > 0, "Capacity needs to be positive."); + key_to_cache_key_.reserve(capacity); +} + +std::tuple +SieveCachePolicy::Query(torch::Tensor keys) { + return QueryImpl(*this, keys); +} + +torch::Tensor SieveCachePolicy::Replace(torch::Tensor keys) { + return ReplaceImpl(*this, keys); +} + +void SieveCachePolicy::ReadingCompleted(torch::Tensor keys) { + ReadingCompletedImpl(*this, keys); +} + +LruCachePolicy::LruCachePolicy(int64_t capacity) + : capacity_(capacity), cache_usage_(0) { + TORCH_CHECK(capacity > 0, "Capacity needs to be positive."); + key_to_cache_key_.reserve(capacity); +} + +std::tuple +LruCachePolicy::Query(torch::Tensor keys) { + return QueryImpl(*this, keys); +} + +torch::Tensor LruCachePolicy::Replace(torch::Tensor keys) { + return ReplaceImpl(*this, keys); +} + +void LruCachePolicy::ReadingCompleted(torch::Tensor keys) { + ReadingCompletedImpl(*this, keys); +} + +ClockCachePolicy::ClockCachePolicy(int64_t capacity) + : queue_(capacity), capacity_(capacity), cache_usage_(0) { + TORCH_CHECK(capacity > 0, "Capacity needs to be positive."); + key_to_cache_key_.reserve(capacity); +} + +std::tuple +ClockCachePolicy::Query(torch::Tensor keys) { + return QueryImpl(*this, keys); +} + +torch::Tensor ClockCachePolicy::Replace(torch::Tensor keys) { + return ReplaceImpl(*this, keys); +} + +void ClockCachePolicy::ReadingCompleted(torch::Tensor keys) { + ReadingCompletedImpl(*this, keys); +} + } // namespace storage } // namespace graphbolt diff --git a/graphbolt/src/cache_policy.h b/graphbolt/src/cache_policy.h index c07f89ce97ac..8ce1bf6df6d4 100644 --- a/graphbolt/src/cache_policy.h +++ b/graphbolt/src/cache_policy.h @@ -53,6 +53,12 @@ struct CircularQueue { return diff == 0 || diff == capacity_; } + auto Size() const { + 
auto diff = tail_ - head_; + if (diff < 0) diff += capacity_; + return diff; + } + friend std::ostream& operator<<( std::ostream& os, const CircularQueue& queue) { for (auto i = queue.head_; i != queue.tail_; queue.PostIncrement(i)) { @@ -63,7 +69,7 @@ struct CircularQueue { bool IsEmpty() const { return tail_ == head_; } - int64_t Capacity() const { return capacity_ - 1; } + auto Capacity() const { return capacity_ - 1; } private: int64_t PostIncrement(int64_t& i) const { @@ -80,7 +86,7 @@ struct CircularQueue { struct CacheKey { CacheKey(int64_t key, int64_t position) - : freq_(0), key_(key), position_in_cache_(position) { + : freq_(0), key_(key), position_in_cache_(position), reference_count_(0) { static_assert(sizeof(CacheKey) == 2 * sizeof(int64_t)); } @@ -92,9 +98,37 @@ struct CacheKey { auto getPos() const { return position_in_cache_; } - void Increment() { freq_ = std::min(3, static_cast(freq_ + 1)); } + CacheKey& Increment() { + freq_ = std::min(3, static_cast(freq_ + 1)); + return *this; + } + + CacheKey& Decrement() { + freq_ = std::max(0, static_cast(freq_ - 1)); + return *this; + } + + CacheKey& SetFreq() { + freq_ = 1; + return *this; + } - void Decrement() { freq_ = std::max(0, static_cast(freq_ - 1)); } + CacheKey& ResetFreq() { + freq_ = 0; + return *this; + } + + CacheKey& StartUse() { + ++reference_count_; + return *this; + } + + CacheKey& EndUse() { + --reference_count_; + return *this; + } + + bool InUse() { return reference_count_ > 0; } friend std::ostream& operator<<(std::ostream& os, const CacheKey& key_ref) { return os << '(' << key_ref.key_ << ", " << key_ref.freq_ << ", " @@ -104,7 +138,8 @@ struct CacheKey { private: int64_t freq_ : 3; int64_t key_ : 61; - int64_t position_in_cache_; + int64_t position_in_cache_ : 48; + int64_t reference_count_ : 16; }; class BaseCachePolicy { @@ -113,14 +148,14 @@ class BaseCachePolicy { * @brief The policy query function. * @param keys The keys to query the cache. * - * @return (positions, indices, missing_keys), where positions has the - * locations of the keys which were found in the cache, missing_keys has the - * keys that were not found and indices is defined such that + * @return (positions, indices, missing_keys, found_keys), where positions has + * the locations of the keys which were found in the cache, missing_keys has + * the keys that were not found and indices is defined such that * keys[indices[:positions.size(0)]] gives us the found keys and * keys[indices[positions.size(0):]] is identical to missing_keys. */ - virtual std::tuple Query( - torch::Tensor keys) = 0; + virtual std::tuple + Query(torch::Tensor keys) = 0; /** * @brief The policy replace function. @@ -130,6 +165,23 @@ class BaseCachePolicy { * entries in the cache. */ virtual torch::Tensor Replace(torch::Tensor keys) = 0; + + /** + * @brief A reader has finished reading these keys, so they can be evicted. + * @param keys The keys to unmark. + */ + virtual void ReadingCompleted(torch::Tensor keys) = 0; + + protected: + template + static std::tuple + QueryImpl(CachePolicy& policy, torch::Tensor keys); + + template + static torch::Tensor ReplaceImpl(CachePolicy& policy, torch::Tensor keys); + + template + static void ReadingCompletedImpl(CachePolicy& policy, torch::Tensor keys); }; /** @@ -148,42 +200,54 @@ class S3FifoCachePolicy : public BaseCachePolicy { S3FifoCachePolicy() = default; /** - * @brief The policy query function. - * @param keys The keys to query the cache. 
- * - * @return (positions, indices, missing_keys), where positions has the - * locations of the keys which were found in the cache, missing_keys has the - * keys that were not found and indices is defined such that - * keys[indices[:positions.size(0)]] gives us the found keys and - * keys[indices[positions.size(0):]] is identical to missing_keys. + * @brief See BaseCachePolicy::Query. */ - std::tuple Query( + std::tuple Query( torch::Tensor keys); /** - * @brief The policy replace function. - * @param keys The keys to query the cache. - * - * @return positions tensor is returned holding the locations of the replaced - * entries in the cache. + * @brief See BaseCachePolicy::Replace. */ torch::Tensor Replace(torch::Tensor keys); + /** + * @brief See BaseCachePolicy::ReadingCompleted. + */ + void ReadingCompleted(torch::Tensor keys); + friend std::ostream& operator<<( std::ostream& os, const S3FifoCachePolicy& policy) { return os << "small_queue_: " << policy.small_queue_ << "\n" << "main_queue_: " << policy.main_queue_ << "\n" << "cache_usage_: " << policy.cache_usage_ << "\n" - << "ghost_queue_time_: " << policy.ghost_queue_time_ << "\n" << "capacity_: " << policy.capacity_ << "\n"; } + std::optional Read(int64_t key) { + auto it = key_to_cache_key_.find(key); + if (it != key_to_cache_key_.end()) { + auto& cache_key = *it->second; + return cache_key.Increment().StartUse().getPos(); + } + return std::nullopt; + } + + int64_t Insert(int64_t key) { + const auto pos = Evict(); + const auto in_ghost_queue = ghost_set_.erase(key); + auto& queue = in_ghost_queue ? main_queue_ : small_queue_; + key_to_cache_key_[key] = queue.Push(CacheKey(key, pos)); + return pos; + } + + void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); } + private: int64_t EvictMainQueue() { while (true) { auto evicted = main_queue_.Pop(); auto it = key_to_cache_key_.find(evicted.getKey()); - if (evicted.getFreq() > 0) { + if (evicted.getFreq() > 0 || evicted.InUse()) { evicted.Decrement(); it->second = main_queue_.Push(evicted); } else { @@ -194,47 +258,266 @@ class S3FifoCachePolicy : public BaseCachePolicy { } int64_t EvictSmallQueue() { - auto evicted = small_queue_.Pop(); - const auto is_M_full = main_queue_.IsFull(); - auto it = key_to_cache_key_.find(evicted.getKey()); - if (evicted.getFreq() <= 0 && is_M_full) { - key_to_cache_key_.erase(it); - // No overflow is expected for any GNN workload. - ghost_map_[evicted.getKey()] = ghost_queue_time_++; - return evicted.getPos(); + for (auto size = small_queue_.Size(); size > small_queue_size_target_; + size--) { + auto evicted = small_queue_.Pop(); + auto it = key_to_cache_key_.find(evicted.getKey()); + if (evicted.getFreq() > 0 || evicted.InUse()) { + it->second = main_queue_.Push(evicted.ResetFreq()); + } else { + key_to_cache_key_.erase(it); + const auto evicted_key = evicted.getKey(); + if (ghost_queue_.IsFull()) { + ghost_set_.erase(ghost_queue_.Pop()); + } + ghost_set_.insert(evicted_key); + ghost_queue_.Push(evicted_key); + return evicted.getPos(); + } + } + return -1; + } + + int64_t Evict() { + // If the cache has space, get an unused slot otherwise perform eviction. + if (cache_usage_ < capacity_) return cache_usage_++; + const auto pos = EvictSmallQueue(); + return pos >= 0 ? 
pos : EvictMainQueue();
+  }
+
+  CircularQueue<CacheKey> small_queue_, main_queue_;
+  CircularQueue<int64_t> ghost_queue_;
+  int64_t capacity_;
+  int64_t cache_usage_;
+  int64_t small_queue_size_target_;
+  phmap::flat_hash_set<int64_t> ghost_set_;
+  phmap::flat_hash_map<int64_t, CacheKey*> key_to_cache_key_;
+};
+
+/**
+ * @brief SIEVE is a simple, scalable FIFO-based algorithm with a single static
+ * queue. https://www.usenix.org/system/files/nsdi24-zhang-yazhuo.pdf
+ **/
+class SieveCachePolicy : public BaseCachePolicy {
+ public:
+  /**
+   * @brief Constructor for the SieveCachePolicy class.
+   *
+   * @param capacity The capacity of the cache in terms of # elements.
+   */
+  SieveCachePolicy(int64_t capacity);
+
+  SieveCachePolicy() = default;
+
+  /**
+   * @brief See BaseCachePolicy::Query.
+   */
+  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
+      torch::Tensor keys);
+
+  /**
+   * @brief See BaseCachePolicy::Replace.
+   */
+  torch::Tensor Replace(torch::Tensor keys);
+
+  /**
+   * @brief See BaseCachePolicy::ReadingCompleted.
+   */
+  void ReadingCompleted(torch::Tensor keys);
+
+  std::optional<int64_t> Read(int64_t key) {
+    auto it = key_to_cache_key_.find(key);
+    if (it != key_to_cache_key_.end()) {
+      auto& cache_key = *it->second;
+      return cache_key.SetFreq().StartUse().getPos();
+    }
+    return std::nullopt;
+  }
+
+  int64_t Insert(int64_t key) {
+    const auto pos = Evict();
+    queue_.push_front(CacheKey(key, pos));
+    key_to_cache_key_[key] = queue_.begin();
+    return pos;
+  }
+
+  void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
+
+ private:
+  int64_t Evict() {
+    // If the cache has space, get an unused slot otherwise perform eviction.
+    if (cache_usage_ < capacity_) return cache_usage_++;
+    --hand_;
+    while (hand_->getFreq() || hand_->InUse()) {
+      hand_->ResetFreq();
+      if (hand_ == queue_.begin()) hand_ = queue_.end();
+      --hand_;
+    }
+    TORCH_CHECK(key_to_cache_key_.erase(hand_->getKey()));
+    const auto pos = hand_->getPos();
+    const auto temp = hand_;
+    if (hand_ == queue_.begin()) {
+      hand_ = queue_.end();
     } else {
-      const auto pos = is_M_full ? EvictMainQueue() : cache_usage_++;
-      it->second = main_queue_.Push(evicted);
-      return pos;
+      ++hand_;
     }
+    queue_.erase(temp);
+    return pos;
   }

-  // Is inside the ghost queue.
-  bool InGhostQueue(int64_t key) const {
-    auto it = ghost_map_.find(key);
-    return it != ghost_map_.end() &&
-           ghost_queue_time_ - it->second <= main_queue_.Capacity();
+  std::list<CacheKey> queue_;
+  decltype(queue_)::iterator hand_;
+  int64_t capacity_;
+  int64_t cache_usage_;
+  phmap::flat_hash_map<int64_t, std::list<CacheKey>::iterator> key_to_cache_key_;
+};
+
+/**
+ * @brief LeastRecentlyUsed (LRU) evicts the least recently used item, tracked
+ * with a single queue that is reordered on every access.
+ **/
+class LruCachePolicy : public BaseCachePolicy {
+ public:
+  /**
+   * @brief Constructor for the LruCachePolicy class.
+   *
+   * @param capacity The capacity of the cache in terms of # elements.
+   */
+  LruCachePolicy(int64_t capacity);
+
+  LruCachePolicy() = default;
+
+  /**
+   * @brief See BaseCachePolicy::Query.
+   */
+  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
+      torch::Tensor keys);
+
+  /**
+   * @brief See BaseCachePolicy::Replace.
+   */
+  torch::Tensor Replace(torch::Tensor keys);
+
+  /**
+   * @brief See BaseCachePolicy::ReadingCompleted.
+   */
+  void ReadingCompleted(torch::Tensor keys);
+
+  std::optional<int64_t> Read(int64_t key) {
+    auto it = key_to_cache_key_.find(key);
+    if (it != key_to_cache_key_.end()) {
+      auto cache_key = *it->second;
+      queue_.erase(it->second);
+      queue_.push_front(cache_key.StartUse());
+      it->second = queue_.begin();
+      return cache_key.getPos();
+    }
+    return std::nullopt;
   }

-  void TrimGhostQueue() {
-    if (static_cast<int64_t>(ghost_map_.size()) >= 2 * main_queue_.Capacity()) {
-      // Here, we ensure that the ghost_map_ does not grow too much.
-      phmap::priv::erase_if(ghost_map_, [&](const auto& key_value) {
-        const auto timestamp = key_value.second;
-        return ghost_queue_time_ - timestamp > main_queue_.Capacity();
-      });
+  int64_t Insert(int64_t key) {
+    const auto pos = Evict();
+    queue_.push_front(CacheKey(key, pos));
+    key_to_cache_key_[key] = queue_.begin();
+    return pos;
+  }
+
+  void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
+
+ private:
+  int64_t Evict() {
+    // If the cache has space, get an unused slot otherwise perform eviction.
+    if (cache_usage_ < capacity_) return cache_usage_++;
+    // Do not evict items that are still in use.
+    for (auto cache_key = queue_.back(); cache_key.InUse();
+         cache_key = queue_.back()) {
+      queue_.pop_back();
+      queue_.push_front(cache_key);
+      key_to_cache_key_[cache_key.getKey()] = queue_.begin();
     }
+    const auto& cache_key = queue_.back();
+    TORCH_CHECK(key_to_cache_key_.erase(cache_key.getKey()));
+    const auto pos = cache_key.getPos();
+    queue_.pop_back();
+    return pos;
   }

-  CircularQueue<CacheKey> small_queue_, main_queue_;
-  phmap::flat_hash_map<int64_t, CacheKey*> key_to_cache_key_;
-  phmap::flat_hash_map<int64_t, int64_t> ghost_map_;
-  // Keeps track of the number of insertions into the ghost queue so far. If an
-  // item's time of insertion is older than main_queue_.Capacity(), then it is
-  // not really considered in the ghost queue.
-  int64_t ghost_queue_time_;
+  std::list<CacheKey> queue_;
   int64_t capacity_;
   int64_t cache_usage_;
+  phmap::flat_hash_map<int64_t, std::list<CacheKey>::iterator> key_to_cache_key_;
+};
+
+/**
+ * @brief Clock (FIFO-Reinsertion) is a simple, scalable FIFO-based algorithm
+ * with a single static queue.
+ * https://people.csail.mit.edu/saltzer/Multics/MHP-Saltzer-060508/bookcases/M00s/M0104%20074-12).PDF
+ **/
+class ClockCachePolicy : public BaseCachePolicy {
+ public:
+  /**
+   * @brief Constructor for the ClockCachePolicy class.
+   *
+   * @param capacity The capacity of the cache in terms of # elements.
+   */
+  ClockCachePolicy(int64_t capacity);
+
+  ClockCachePolicy() = default;
+
+  /**
+   * @brief See BaseCachePolicy::Query.
+   */
+  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
+      torch::Tensor keys);
+
+  /**
+   * @brief See BaseCachePolicy::Replace.
+   */
+  torch::Tensor Replace(torch::Tensor keys);
+
+  /**
+   * @brief See BaseCachePolicy::ReadingCompleted.
+   */
+  void ReadingCompleted(torch::Tensor keys);
+
+  std::optional<int64_t> Read(int64_t key) {
+    auto it = key_to_cache_key_.find(key);
+    if (it != key_to_cache_key_.end()) {
+      auto& cache_key = *it->second;
+      return cache_key.SetFreq().StartUse().getPos();
+    }
+    return std::nullopt;
+  }
+
+  int64_t Insert(int64_t key) {
+    const auto pos = Evict();
+    key_to_cache_key_[key] = queue_.Push(CacheKey(key, pos));
+    return pos;
+  }
+
+  void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
+
+ private:
+  int64_t Evict() {
+    // If the cache has space, get an unused slot otherwise perform eviction.
+ if (cache_usage_ < capacity_) return cache_usage_++; + CacheKey cache_key; + while (true) { + cache_key = queue_.Pop(); + if (cache_key.getFreq() || cache_key.InUse()) { + key_to_cache_key_[cache_key.getKey()] = + queue_.Push(cache_key.ResetFreq()); + } else + break; + } + TORCH_CHECK(key_to_cache_key_.erase(cache_key.getKey())); + return cache_key.getPos(); + } + + CircularQueue queue_; + int64_t capacity_; + int64_t cache_usage_; + phmap::flat_hash_map key_to_cache_key_; }; } // namespace storage diff --git a/graphbolt/src/feature_cache.cc b/graphbolt/src/feature_cache.cc index ffe87f8bd377..c532a0efe79d 100644 --- a/graphbolt/src/feature_cache.cc +++ b/graphbolt/src/feature_cache.cc @@ -29,8 +29,8 @@ FeatureCache::FeatureCache( : tensor_(torch::empty(shape, c10::TensorOptions().dtype(dtype))) {} torch::Tensor FeatureCache::Query( - torch::Tensor positions, torch::Tensor indices, int64_t size, - bool pin_memory) { + torch::Tensor positions, torch::Tensor indices, int64_t size) { + const bool pin_memory = positions.is_pinned() || indices.is_pinned(); std::vector output_shape{ tensor_.sizes().begin(), tensor_.sizes().end()}; output_shape[0] = size; @@ -53,10 +53,6 @@ torch::Tensor FeatureCache::Query( } void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) { - if (positions.size(0) > tensor_.size(0)) { - positions = positions.slice(0, 0, tensor_.size(0)); - values = values.slice(0, 0, tensor_.size(0)); - } const auto row_bytes = values.slice(0, 0, 1).numel() * values.element_size(); auto values_ptr = reinterpret_cast(values.data_ptr()); const auto tensor_ptr = reinterpret_cast(tensor_.data_ptr()); diff --git a/graphbolt/src/feature_cache.h b/graphbolt/src/feature_cache.h index c97b3ef7848e..880e7bb37cdf 100644 --- a/graphbolt/src/feature_cache.h +++ b/graphbolt/src/feature_cache.h @@ -55,8 +55,7 @@ struct FeatureCache : public torch::CustomClassHolder { * is true. */ torch::Tensor Query( - torch::Tensor positions, torch::Tensor indices, int64_t size, - bool pin_memory); + torch::Tensor positions, torch::Tensor indices, int64_t size); /** * @brief The cache replace function. 
diff --git a/graphbolt/src/partitioned_cache_policy.cc b/graphbolt/src/partitioned_cache_policy.cc index 04e735341324..37ab282a7029 100644 --- a/graphbolt/src/partitioned_cache_policy.cc +++ b/graphbolt/src/partitioned_cache_policy.cc @@ -112,15 +112,16 @@ PartitionedCachePolicy::Partition(torch::Tensor keys) { permuted_keys}; } -std::tuple +std::tuple PartitionedCachePolicy::Query(torch::Tensor keys) { if (policies_.size() == 1) return policies_[0]->Query(keys); torch::Tensor offsets, indices, permuted_keys; std::tie(offsets, indices, permuted_keys) = Partition(keys); auto offsets_ptr = offsets.data_ptr(); auto indices_ptr = indices.data_ptr(); - std::vector> results( - policies_.size()); + std::vector< + std::tuple> + results(policies_.size()); torch::Tensor result_offsets_tensor = torch::empty(policies_.size() * 2 + 1, offsets.options()); auto result_offsets = result_offsets_tensor.data_ptr(); @@ -138,10 +139,16 @@ PartitionedCachePolicy::Query(torch::Tensor keys) { result_offsets, result_offsets + result_offsets_tensor.size(0), result_offsets, 0); torch::Tensor positions = torch::empty( - result_offsets[policies_.size()], std::get<0>(results[0]).options()); - torch::Tensor output_indices = torch::empty_like(indices); + result_offsets[policies_.size()], + std::get<0>(results[0]).options().pinned_memory(keys.is_pinned())); + torch::Tensor output_indices = torch::empty_like( + indices, indices.options().pinned_memory(keys.is_pinned())); torch::Tensor missing_keys = torch::empty( - indices.size(0) - positions.size(0), std::get<2>(results[0]).options()); + indices.size(0) - positions.size(0), + std::get<2>(results[0]).options().pinned_memory(keys.is_pinned())); + torch::Tensor found_keys = torch::empty( + positions.size(0), + std::get<3>(results[0]).options().pinned_memory(keys.is_pinned())); auto output_indices_ptr = output_indices.data_ptr(); torch::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) { if (begin == end) return; @@ -159,6 +166,11 @@ PartitionedCachePolicy::Query(torch::Tensor keys) { selected_positions_ptr, selected_positions_ptr + num_selected, positions.data_ptr() + begin, [off = tid * capacity_ / policies_.size()](auto x) { return x + off; }); + std::memcpy( + reinterpret_cast(found_keys.data_ptr()) + + begin * found_keys.element_size(), + std::get<3>(results[tid]).data_ptr(), + num_selected * found_keys.element_size()); begin = result_offsets[policies_.size() + tid]; end = result_offsets[policies_.size() + tid + 1]; const auto num_missing = end - begin; @@ -172,15 +184,16 @@ PartitionedCachePolicy::Query(torch::Tensor keys) { std::get<2>(results[tid]).data_ptr(), num_missing * missing_keys.element_size()); }); - return std::make_tuple(positions, output_indices, missing_keys); + return std::make_tuple(positions, output_indices, missing_keys, found_keys); } torch::Tensor PartitionedCachePolicy::Replace(torch::Tensor keys) { if (policies_.size() == 1) return policies_[0]->Replace(keys); torch::Tensor offsets, indices, permuted_keys; std::tie(offsets, indices, permuted_keys) = Partition(keys); - auto output_positions = - torch::empty_like(keys, keys.options().dtype(torch::kInt64)); + auto output_positions = torch::empty_like( + keys, + keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned())); auto offsets_ptr = offsets.data_ptr(); auto indices_ptr = indices.data_ptr(); auto output_positions_ptr = output_positions.data_ptr(); @@ -200,6 +213,23 @@ torch::Tensor PartitionedCachePolicy::Replace(torch::Tensor keys) { return output_positions; } +void 
PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) { + if (policies_.size() == 1) { + policies_[0]->ReadingCompleted(keys); + return; + } + torch::Tensor offsets, indices, permuted_keys; + std::tie(offsets, indices, permuted_keys) = Partition(keys); + auto offsets_ptr = offsets.data_ptr(); + torch::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) { + if (begin == end) return; + const auto tid = begin; + begin = offsets_ptr[tid]; + end = offsets_ptr[tid + 1]; + policies_.at(tid)->ReadingCompleted(permuted_keys.slice(0, begin, end)); + }); +} + template c10::intrusive_ptr PartitionedCachePolicy::Create( int64_t capacity, int64_t num_partitions) { @@ -210,6 +240,12 @@ c10::intrusive_ptr PartitionedCachePolicy::Create( template c10::intrusive_ptr PartitionedCachePolicy::Create(int64_t, int64_t); +template c10::intrusive_ptr + PartitionedCachePolicy::Create(int64_t, int64_t); +template c10::intrusive_ptr + PartitionedCachePolicy::Create(int64_t, int64_t); +template c10::intrusive_ptr + PartitionedCachePolicy::Create(int64_t, int64_t); } // namespace storage } // namespace graphbolt diff --git a/graphbolt/src/partitioned_cache_policy.h b/graphbolt/src/partitioned_cache_policy.h index c3d42b3f82bc..0d343a87438e 100644 --- a/graphbolt/src/partitioned_cache_policy.h +++ b/graphbolt/src/partitioned_cache_policy.h @@ -62,7 +62,7 @@ class PartitionedCachePolicy : public BaseCachePolicy, * keys[indices[:positions.size(0)]] gives us the found keys and * keys[indices[positions.size(0):]] is identical to missing_keys. */ - std::tuple Query( + std::tuple Query( torch::Tensor keys); /** @@ -74,6 +74,12 @@ class PartitionedCachePolicy : public BaseCachePolicy, */ torch::Tensor Replace(torch::Tensor keys); + /** + * @brief A reader has finished reading these keys, so they can be evicted. + * @param keys The keys to unmark. 
+ */ + void ReadingCompleted(torch::Tensor keys); + template static c10::intrusive_ptr Create( int64_t capacity, int64_t num_partitions); diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index d1d25d771e90..d3c3789abb1b 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -99,10 +99,22 @@ TORCH_LIBRARY(graphbolt, m) { m.def("fused_csc_sampling_graph", &FusedCSCSamplingGraph::Create); m.class_("PartitionedCachePolicy") .def("query", &storage::PartitionedCachePolicy::Query) - .def("replace", &storage::PartitionedCachePolicy::Replace); + .def("replace", &storage::PartitionedCachePolicy::Replace) + .def( + "reading_completed", + &storage::PartitionedCachePolicy::ReadingCompleted); m.def( "s3_fifo_cache_policy", &storage::PartitionedCachePolicy::Create); + m.def( + "sieve_cache_policy", + &storage::PartitionedCachePolicy::Create); + m.def( + "lru_cache_policy", + &storage::PartitionedCachePolicy::Create); + m.def( + "clock_cache_policy", + &storage::PartitionedCachePolicy::Create); m.class_("FeatureCache") .def("query", &storage::FeatureCache::Query) .def("replace", &storage::FeatureCache::Replace); diff --git a/python/dgl/graphbolt/impl/feature_cache.py b/python/dgl/graphbolt/impl/feature_cache.py index 8922a8a27ebe..740dc223edfc 100644 --- a/python/dgl/graphbolt/impl/feature_cache.py +++ b/python/dgl/graphbolt/impl/feature_cache.py @@ -3,7 +3,12 @@ __all__ = ["FeatureCache"] -caching_policies = {"s3-fifo": torch.ops.graphbolt.s3_fifo_cache_policy} +caching_policies = { + "s3-fifo": torch.ops.graphbolt.s3_fifo_cache_policy, + "sieve": torch.ops.graphbolt.sieve_cache_policy, + "lru": torch.ops.graphbolt.lru_cache_policy, + "clock": torch.ops.graphbolt.clock_cache_policy, +} class FeatureCache(object): @@ -18,10 +23,11 @@ class FeatureCache(object): num_parts: int, optional The number of cache partitions for parallelism. Default is 1. policy: str, optional - The cache policy to be used. Default is "s3-fifo". + The cache policy. Default is "sieve". "s3-fifo", "lru" and "clock" are + also available. """ - def __init__(self, cache_shape, dtype, num_parts=1, policy="s3-fifo"): + def __init__(self, cache_shape, dtype, num_parts=1, policy="sieve"): assert ( policy in caching_policies ), f"{list(caching_policies.keys())} are the available caching policies." @@ -30,26 +36,26 @@ def __init__(self, cache_shape, dtype, num_parts=1, policy="s3-fifo"): self.total_miss = 0 self.total_queries = 0 - def query(self, keys, pin_memory=False): + def query(self, keys): """Queries the cache. Parameters ---------- keys : Tensor The keys to query the cache with. - pin_memory : bool, optional - Whether the output values tensor should be pinned. Default is False. Returns ------- tuple(Tensor, Tensor, Tensor) A tuple containing (values, missing_indices, missing_keys) where values[missing_indices] corresponds to cache misses that should be - filled by quering another source with missing_keys. + filled by quering another source with missing_keys. If keys is + pinned, then the returned values tensor is pinned as well. 
""" self.total_queries += keys.shape[0] - positions, index, missing_keys = self._policy.query(keys) - values = self._cache.query(positions, index, keys.shape[0], pin_memory) + positions, index, missing_keys, found_keys = self._policy.query(keys) + values = self._cache.query(positions, index, keys.shape[0]) + self._policy.reading_completed(found_keys) self.total_miss += missing_keys.shape[0] missing_index = index[positions.size(0) :] return values, missing_index, missing_keys diff --git a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py index bd5d376449d9..ee0f5d55aa6c 100644 --- a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py +++ b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py @@ -23,15 +23,16 @@ ) @pytest.mark.parametrize("feature_size", [2, 16]) @pytest.mark.parametrize("num_parts", [1, 2]) -def test_feature_cache(dtype, feature_size, num_parts): +@pytest.mark.parametrize("policy", ["s3-fifo", "sieve", "lru", "clock"]) +def test_feature_cache(dtype, feature_size, num_parts, policy): cache_size = 32 * num_parts a = torch.randint(0, 2, [1024, feature_size], dtype=dtype) cache = gb.impl.FeatureCache( - (cache_size,) + a.shape[1:], a.dtype, num_parts + (cache_size,) + a.shape[1:], a.dtype, num_parts, policy ) keys = torch.tensor([0, 1]) - values, missing_index, missing_keys = cache.query(keys, False) + values, missing_index, missing_keys = cache.query(keys) assert torch.equal( missing_keys.flip([0]) if num_parts == 1 else missing_keys.sort()[0], keys, @@ -44,8 +45,8 @@ def test_feature_cache(dtype, feature_size, num_parts): pin_memory = F._default_context_str == "gpu" - keys = torch.arange(1, 33) - values, missing_index, missing_keys = cache.query(keys, pin_memory) + keys = torch.arange(1, 33, pin_memory=pin_memory) + values, missing_index, missing_keys = cache.query(keys) assert torch.equal( missing_keys.flip([0]) if num_parts == 1 else missing_keys.sort()[0], torch.arange(2, 33), @@ -58,9 +59,7 @@ def test_feature_cache(dtype, feature_size, num_parts): assert torch.equal(values, a[keys]) values, missing_index, missing_keys = cache.query(keys) - assert torch.equal( - missing_keys.flip([0]), torch.tensor([5] if num_parts == 1 else []) - ) + assert torch.equal(missing_keys.flip([0]), torch.tensor([])) missing_values = a[missing_keys] cache.replace(missing_keys, missing_values)