[GraphBolt] Refactor S3-FIFO and add SIEVE, LRU and CLOCK. #7508

Merged (12 commits) on Jul 9, 2024
Mark items currently being read as unevictable.
mfbalin committed Jul 9, 2024
commit c312794b3593cd58a2ced2414bec0ba8581808ec
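
This commit's core idea: a key that Query reports as a cache hit must not be evicted while its data is still being copied out, so the policy marks such entries on Read and clears the mark through a new ReadingCompleted path that calls Unmark per key. Below is a minimal sketch of one way such bookkeeping can work, assuming a per-entry reference count; the actual GraphBolt cache-key bookkeeping may differ.

#include <cstdint>
#include <optional>
#include <unordered_map>

// Hypothetical sketch, not the actual GraphBolt implementation: each cached
// entry carries a reference count; Read() increments it and Unmark()
// decrements it, and eviction skips entries with a nonzero count.
struct Entry {
  int64_t position;         // offset of the cached feature slice
  int32_t reference_count;  // > 0 while a read is in flight => unevictable
};

class PinnedCache {
 public:
  std::optional<int64_t> Read(int64_t key) {
    auto it = entries_.find(key);
    if (it == entries_.end()) return std::nullopt;
    ++it->second.reference_count;  // pin for the duration of the read
    return it->second.position;
  }

  void Unmark(int64_t key) {
    auto it = entries_.find(key);
    if (it != entries_.end()) --it->second.reference_count;  // unpin
  }

  bool Evictable(int64_t key) const {
    auto it = entries_.find(key);
    return it != entries_.end() && it->second.reference_count == 0;
  }

 private:
  std::unordered_map<int64_t, Entry> entries_;
};

Under a scheme like this, Replace() would skip any slot whose reference count is nonzero when choosing eviction victims.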
graphbolt/src/cache_policy.cc (57 changes: 43 additions & 14 deletions)
@@ -23,39 +23,40 @@ namespace graphbolt {
 namespace storage {
 
 template <typename CachePolicy>
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {
   auto positions = torch::empty_like(
       keys,
       keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
   auto indices = torch::empty_like(
       keys,
       keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
-  auto missing_keys =
+  auto filtered_keys =
       torch::empty_like(keys, keys.options().pinned_memory(keys.is_pinned()));
   int64_t found_cnt = 0;
   int64_t missing_cnt = keys.size(0);
   AT_DISPATCH_INDEX_TYPES(
-      keys.scalar_type(), "S3FifoCachePolicy::Query::DispatchForKeys", ([&] {
+      keys.scalar_type(), "BaseCachePolicy::Query::DispatchForKeys", ([&] {
         auto keys_ptr = keys.data_ptr<index_t>();
         auto positions_ptr = positions.data_ptr<int64_t>();
         auto indices_ptr = indices.data_ptr<int64_t>();
-        auto missing_keys_ptr = missing_keys.data_ptr<index_t>();
+        auto filtered_keys_ptr = filtered_keys.data_ptr<index_t>();
         for (int64_t i = 0; i < keys.size(0); i++) {
           const auto key = keys_ptr[i];
           auto pos = policy.Read(key);
           if (pos.has_value()) {
             positions_ptr[found_cnt] = *pos;
+            filtered_keys_ptr[found_cnt] = key;
             indices_ptr[found_cnt++] = i;
           } else {
             indices_ptr[--missing_cnt] = i;
-            missing_keys_ptr[missing_cnt] = key;
+            filtered_keys_ptr[missing_cnt] = key;
           }
         }
       }));
   return {
       positions.slice(0, 0, found_cnt), indices,
-      missing_keys.slice(0, found_cnt)};
+      filtered_keys.slice(0, found_cnt), filtered_keys.slice(0, 0, found_cnt)};
 }
 
 template <typename CachePolicy>
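
For review context: QueryImpl now returns four tensors rather than three: the positions of the hits, a permutation of the input indices that lists hit slots first, the keys that missed, and the keys that hit (which stay pinned until reading completes). The sketch below shows one way a caller might unpack that tuple; HandleQuery and all variable names are illustrative assumptions, not API from this PR.

#include <torch/torch.h>

#include <tuple>

// Hypothetical caller-side sketch: indices[:num_found] are the input slots
// that hit the cache and indices[num_found:] are the slots that missed, so
// values fetched for missing_keys can be scattered back into key order.
template <typename Policy>
void HandleQuery(Policy& policy, torch::Tensor keys) {
  auto [positions, indices, missing_keys, found_keys] = policy.Query(keys);
  const int64_t num_found = positions.size(0);
  auto hit_slots = indices.slice(0, 0, num_found);
  auto miss_slots = indices.slice(0, num_found, keys.size(0));
  // found_keys remain marked unevictable until ReadingCompleted(found_keys).
}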
@@ -65,7 +66,7 @@ torch::Tensor BaseCachePolicy::ReplaceImpl(
       keys,
       keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
   AT_DISPATCH_INDEX_TYPES(
-      keys.scalar_type(), "S3FifoCachePolicy::Replace", ([&] {
+      keys.scalar_type(), "BaseCachePolicy::Replace", ([&] {
         auto keys_ptr = keys.data_ptr<index_t>();
         auto positions_ptr = positions.data_ptr<int64_t>();
         phmap::flat_hash_set<int64_t> position_set;
@@ -83,6 +84,18 @@
   return positions;
 }
 
+template <typename CachePolicy>
+void BaseCachePolicy::ReadingCompletedImpl(
+    CachePolicy& policy, torch::Tensor keys) {
+  AT_DISPATCH_INDEX_TYPES(
+      keys.scalar_type(), "BaseCachePolicy::ReadingCompleted", ([&] {
+        auto keys_ptr = keys.data_ptr<index_t>();
+        for (int64_t i = 0; i < keys.size(0); i++) {
+          policy.Unmark(keys_ptr[i]);
+        }
+      }));
+}
+
 S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity)
     : small_queue_(capacity),
       main_queue_(capacity),
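
ReadingCompletedImpl follows the same AT_DISPATCH_INDEX_TYPES pattern as QueryImpl and ReplaceImpl above: the macro instantiates the lambda once per supported index dtype (kInt, kLong) and binds the alias index_t inside it. A small standalone sketch of the macro's behavior; SumKeys is a hypothetical helper, not part of this PR.

#include <ATen/Dispatch.h>
#include <torch/torch.h>

// Illustrative use of AT_DISPATCH_INDEX_TYPES: the macro switches on
// keys.scalar_type() and runs the lambda with index_t bound to the
// matching C++ integer type.
int64_t SumKeys(torch::Tensor keys) {
  int64_t total = 0;
  AT_DISPATCH_INDEX_TYPES(keys.scalar_type(), "SumKeys", ([&] {
    auto keys_ptr = keys.data_ptr<index_t>();
    for (int64_t i = 0; i < keys.size(0); i++) total += keys_ptr[i];
  }));
  return total;
}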
@@ -95,7 +108,7 @@ S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity)
   key_to_cache_key_.reserve(capacity);
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 S3FifoCachePolicy::Query(torch::Tensor keys) {
   return QueryImpl(*this, keys);
 }
@@ -104,50 +117,66 @@ torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) {
   return ReplaceImpl(*this, keys);
 }
 
+void S3FifoCachePolicy::ReadingCompleted(torch::Tensor keys) {
+  ReadingCompletedImpl(*this, keys);
+}
+
 SieveCachePolicy::SieveCachePolicy(int64_t capacity)
     : hand_(queue_.end()), capacity_(capacity), cache_usage_(0) {
   TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
   key_to_cache_key_.reserve(capacity);
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> SieveCachePolicy::Query(
-    torch::Tensor keys) {
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+SieveCachePolicy::Query(torch::Tensor keys) {
   return QueryImpl(*this, keys);
 }
 
 torch::Tensor SieveCachePolicy::Replace(torch::Tensor keys) {
   return ReplaceImpl(*this, keys);
 }
 
+void SieveCachePolicy::ReadingCompleted(torch::Tensor keys) {
+  ReadingCompletedImpl(*this, keys);
+}
+
 LruCachePolicy::LruCachePolicy(int64_t capacity)
     : capacity_(capacity), cache_usage_(0) {
   TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
   key_to_cache_key_.reserve(capacity);
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> LruCachePolicy::Query(
-    torch::Tensor keys) {
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+LruCachePolicy::Query(torch::Tensor keys) {
   return QueryImpl(*this, keys);
 }
 
 torch::Tensor LruCachePolicy::Replace(torch::Tensor keys) {
   return ReplaceImpl(*this, keys);
 }
 
+void LruCachePolicy::ReadingCompleted(torch::Tensor keys) {
+  ReadingCompletedImpl(*this, keys);
+}
+
 ClockCachePolicy::ClockCachePolicy(int64_t capacity)
     : queue_(capacity), capacity_(capacity), cache_usage_(0) {
   TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
   key_to_cache_key_.reserve(capacity);
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> ClockCachePolicy::Query(
-    torch::Tensor keys) {
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+ClockCachePolicy::Query(torch::Tensor keys) {
   return QueryImpl(*this, keys);
 }
 
 torch::Tensor ClockCachePolicy::Replace(torch::Tensor keys) {
   return ReplaceImpl(*this, keys);
 }
 
+void ClockCachePolicy::ReadingCompleted(torch::Tensor keys) {
+  ReadingCompletedImpl(*this, keys);
+}
+
 } // namespace storage
 } // namespace graphbolt
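
Taken together, the changes suggest a Query, copy, ReadingCompleted lifecycle for every batch of reads. A hedged end-to-end sketch follows; ReadFeatures, cache_storage, and the header path are hypothetical stand-ins, not code from this PR.

#include <torch/torch.h>

#include "cache_policy.h"  // assumed header for the policies above

// Hypothetical end-to-end usage: every found key stays pinned between
// Query() and ReadingCompleted(), so a concurrent Replace() cannot evict
// an entry whose feature rows are still being copied out.
torch::Tensor ReadFeatures(
    graphbolt::storage::S3FifoCachePolicy& policy, torch::Tensor keys,
    torch::Tensor cache_storage) {
  auto [positions, indices, missing_keys, found_keys] = policy.Query(keys);
  auto hit_rows = cache_storage.index_select(0, positions);
  // ... fetch missing_keys from the backing store and insert via Replace() ...
  policy.ReadingCompleted(found_keys);  // unpin the keys read above
  return hit_rows;
}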