From a2c5472a742f1d1a0eed971f9b3f29b46aed8133 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Mon, 18 Mar 2024 03:08:31 -0400 Subject: [PATCH] [GraphBolt] Labor dependent template specialization. (#7220) --- graphbolt/include/graphbolt/continuous_seed.h | 25 ++++++ .../graphbolt/fused_csc_sampling_graph.h | 31 ++++--- graphbolt/src/fused_csc_sampling_graph.cc | 80 +++++++++++-------- 3 files changed, 92 insertions(+), 44 deletions(-) diff --git a/graphbolt/include/graphbolt/continuous_seed.h b/graphbolt/include/graphbolt/continuous_seed.h index c659b1753cf4..cf31618b1466 100644 --- a/graphbolt/include/graphbolt/continuous_seed.h +++ b/graphbolt/include/graphbolt/continuous_seed.h @@ -92,6 +92,31 @@ class continuous_seed { #endif // __CUDA_ARCH__ }; +class single_seed { + uint64_t seed_; + + public: + /* implicit */ single_seed(const int64_t seed) : seed_(seed) {} // NOLINT + + single_seed(torch::Tensor seed_arr) + : seed_(seed_arr.data_ptr()[0]) {} + +#ifdef __CUDACC__ + __device__ inline float uniform(const uint64_t id) const { + const uint64_t kCurandSeed = 999961; // Could be any random number. + curandStatePhilox4_32_10_t rng; + curand_init(kCurandSeed, seed_, id, &rng); + return curand_uniform(&rng); + } +#else + inline float uniform(const uint64_t id) const { + pcg32 ng0(seed_, id); + std::uniform_real_distribution uni; + return uni(ng0); + } +#endif // __CUDA_ARCH__ +}; + } // namespace graphbolt #endif // GRAPHBOLT_CONTINUOUS_SEED_H_ diff --git a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h index c7777aa65131..1fd4a1066863 100644 --- a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h +++ b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h @@ -17,7 +17,11 @@ namespace graphbolt { namespace sampling { -enum SamplerType { NEIGHBOR, LABOR }; +enum SamplerType { NEIGHBOR, LABOR, LABOR_DEPENDENT }; + +constexpr bool is_labor(SamplerType S) { + return S == SamplerType::LABOR || S == SamplerType::LABOR_DEPENDENT; +} template struct SamplerArgs; @@ -27,6 +31,13 @@ struct SamplerArgs {}; template <> struct SamplerArgs { + const torch::Tensor& indices; + single_seed random_seed; + int64_t num_nodes; +}; + +template <> +struct SamplerArgs { const torch::Tensor& indices; continuous_seed random_seed; int64_t num_nodes; @@ -555,12 +566,12 @@ int64_t Pick( const torch::optional& probs_or_mask, SamplerArgs args, PickedType* picked_data_ptr); -template -int64_t Pick( +template +std::enable_if_t Pick( int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, const torch::TensorOptions& options, - const torch::optional& probs_or_mask, - SamplerArgs args, PickedType* picked_data_ptr); + const torch::optional& probs_or_mask, SamplerArgs args, + PickedType* picked_data_ptr); template int64_t TemporalPick( @@ -619,13 +630,13 @@ int64_t TemporalPickByEtype( PickedType* picked_data_ptr); template < - bool NonUniform, bool Replace, typename ProbsType, typename PickedType, - int StackSize = 1024> -int64_t LaborPick( + bool NonUniform, bool Replace, typename ProbsType, SamplerType S, + typename PickedType, int StackSize = 1024> +std::enable_if_t LaborPick( int64_t offset, int64_t num_neighbors, int64_t fanout, const torch::TensorOptions& options, - const torch::optional& probs_or_mask, - SamplerArgs args, PickedType* picked_data_ptr); + const torch::optional& probs_or_mask, SamplerArgs args, + PickedType* picked_data_ptr); } // namespace sampling } // namespace graphbolt diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc index d2ce58c55f97..48d891d40b12 100644 --- a/graphbolt/src/fused_csc_sampling_graph.cc +++ b/graphbolt/src/fused_csc_sampling_graph.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include #include "./macro.h" @@ -660,26 +661,37 @@ c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighbors( } if (layer) { - SamplerArgs args = [&] { - if (random_seed.has_value()) { - return SamplerArgs{ - indices_, - {random_seed.value(), static_cast(seed2_contribution)}, - NumNodes()}; - } else { - return SamplerArgs{ - indices_, - RandomEngine::ThreadLocal()->RandInt( - static_cast(0), std::numeric_limits::max()), - NumNodes()}; - } - }(); - return SampleNeighborsImpl( - nodes.value(), return_eids, - GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask), - GetPickFn( - fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask, - args)); + if (random_seed.has_value() && random_seed->numel() >= 2) { + SamplerArgs args{ + indices_, + {random_seed.value(), static_cast(seed2_contribution)}, + NumNodes()}; + return SampleNeighborsImpl( + nodes.value(), return_eids, + GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask), + GetPickFn( + fanouts, replace, indptr_.options(), type_per_edge_, + probs_or_mask, args)); + } else { + auto args = [&] { + if (random_seed.has_value() && random_seed->numel() == 1) { + return SamplerArgs{ + indices_, random_seed.value(), NumNodes()}; + } else { + return SamplerArgs{ + indices_, + RandomEngine::ThreadLocal()->RandInt( + static_cast(0), std::numeric_limits::max()), + NumNodes()}; + } + }(); + return SampleNeighborsImpl( + nodes.value(), return_eids, + GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask), + GetPickFn( + fanouts, replace, indptr_.options(), type_per_edge_, + probs_or_mask, args)); + } } else { SamplerArgs args; return SampleNeighborsImpl( @@ -1297,7 +1309,7 @@ int64_t TemporalPick( } return picked_indices.numel(); } - if constexpr (S == SamplerType::LABOR) { + if constexpr (is_labor(S)) { return Pick( offset, num_neighbors, fanout, replace, options, masked_prob, args, picked_data_ptr); @@ -1383,12 +1395,12 @@ int64_t TemporalPickByEtype( return pick_offset; } -template -int64_t Pick( +template +std::enable_if_t Pick( int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, const torch::TensorOptions& options, - const torch::optional& probs_or_mask, - SamplerArgs args, PickedType* picked_data_ptr) { + const torch::optional& probs_or_mask, SamplerArgs args, + PickedType* picked_data_ptr) { if (fanout == 0) return 0; if (probs_or_mask.has_value()) { if (fanout < 0) { @@ -1438,9 +1450,9 @@ inline T invcdf(T u, int64_t n, T rem) { return rem * (one - std::pow(one - u, one / n)); } -template +template inline T jth_sorted_uniform_random( - continuous_seed seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) { + seed_t seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) { const T u = seed.uniform(t + j * c); // https://mathematica.stackexchange.com/a/256707 rem -= invcdf(u, n, rem); @@ -1474,13 +1486,13 @@ inline T jth_sorted_uniform_random( * should be put. Enough memory space should be allocated in advance. */ template < - bool NonUniform, bool Replace, typename ProbsType, typename PickedType, - int StackSize> -inline int64_t LaborPick( + bool NonUniform, bool Replace, typename ProbsType, SamplerType S, + typename PickedType, int StackSize> +inline std::enable_if_t LaborPick( int64_t offset, int64_t num_neighbors, int64_t fanout, const torch::TensorOptions& options, - const torch::optional& probs_or_mask, - SamplerArgs args, PickedType* picked_data_ptr) { + const torch::optional& probs_or_mask, SamplerArgs args, + PickedType* picked_data_ptr) { fanout = Replace ? fanout : std::min(fanout, num_neighbors); if (!NonUniform && !Replace && fanout >= num_neighbors) { std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset); @@ -1504,8 +1516,8 @@ inline int64_t LaborPick( } AT_DISPATCH_INDEX_TYPES( args.indices.scalar_type(), "LaborPickMain", ([&] { - const index_t* local_indices_data = - args.indices.data_ptr() + offset; + const auto local_indices_data = + reinterpret_cast(args.indices.data_ptr()) + offset; if constexpr (Replace) { // [Algorithm] @mfbalin // Use a max-heap to get rid of the big random numbers and filter the