Skip to content

Commit

Permalink
[Inductor] Support parallel reduction for GroupNorm
Browse files Browse the repository at this point in the history
ghstack-source-id: 7df55c4453e85b8e735c214c13717459fec484e8
Pull Request resolved: #144020
  • Loading branch information
jiayisunx committed Jan 2, 2025
1 parent 92d8965 commit 9e260b2
Showing 1 changed file with 36 additions and 16 deletions.
52 changes: 36 additions & 16 deletions torch/_inductor/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2134,18 +2134,18 @@ def codegen_loops_impl(self, loop_nest, code, worksharing):
threads = parallel_num_threads()
assert self.call_ranges is not None
if isinstance(loop_nest.kernel, OuterLoopFusedKernel):
par_depth = loop_nest.kernel.decide_parallel_depth(
par_depth, start_depth = loop_nest.kernel.decide_parallel_depth(
loop_nest.max_parallel_depth(), threads
)
else:
par_depth = self.decide_parallel_depth(
par_depth, start_depth = self.decide_parallel_depth(
loop_nest.max_parallel_depth(), threads
)

is_reduction_only = loop_nest.is_reduction_only()
is_reduction_only = loop_nest.is_reduction_only(start_depth)
with contextlib.ExitStack() as stack:
if par_depth:
if loop_nest.is_reduction_only():
if is_reduction_only:
# need to close the worksharing scope to define reduction vars outside it
worksharing.close()
else:
Expand All @@ -2158,7 +2158,7 @@ def codegen_loops_impl(self, loop_nest, code, worksharing):
def gen_kernel(_loop_nest: LoopNest):
def is_parallel_reduction():
assert _loop_nest.loops
root = _loop_nest.loops[0]
root = _loop_nest.loops[start_depth]
return root.is_reduction and root.parallel

kernel = _loop_nest.get_kernel()
Expand Down Expand Up @@ -2279,7 +2279,9 @@ def assert_function(self) -> str:

def decide_parallel_depth(self, max_parallel_depth, threads):
assert self.call_ranges is not None
ranges = self.call_ranges[:max_parallel_depth]
ranges = self.call_ranges[
max_parallel_depth[1] : (max_parallel_depth[0] + max_parallel_depth[1])
]
seq = self.size_hint()
par = 1
depth = 0
Expand All @@ -2298,7 +2300,7 @@ def decide_parallel_depth(self, max_parallel_depth, threads):
# to manage the serial vs. parallel.
if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0:
depth = 1
return depth
return depth, max_parallel_depth[1]

@contextlib.contextmanager
def write_to_suffix(self):
Expand Down Expand Up @@ -4193,11 +4195,17 @@ def decide_parallel_depth(self, max_parallel_depth, threads) -> int:
call_ranges = kernel.call_ranges
assert call_ranges is not None
kernels_parallel_depth.append(
kernel.decide_parallel_depth(len(call_ranges), threads)
kernel.decide_parallel_depth(
(len(call_ranges) - max_parallel_depth[1], max_parallel_depth[1]),
threads,
)[0]
)
return min(
max_parallel_depth,
max(kernels_parallel_depth),
return (
min(
max_parallel_depth[0],
max(kernels_parallel_depth),
),
max_parallel_depth[1],
)


Expand Down Expand Up @@ -5187,27 +5195,39 @@ def max_parallel_depth(self):
return 0

max_depth = 0
start_depth = 0
is_reduction = self.loops[0].is_reduction
for loop in self.loops:
if loop.is_reduction != is_reduction:
break
max_depth += 1
return max_depth
if (
len(self.loops) > 2
and not self.loops[0].is_reduction
and self.loops[1].is_reduction
):
if (
isinstance(self.loops[0].size, sympy.Integer)
and isinstance(self.loops[1].size, sympy.Integer)
and self.loops[0].size * 100 < self.loops[1].size
):
start_depth = 1
return max_depth, start_depth

def is_reduction_only(self):
def is_reduction_only(self, start_depth):
"""
Whether all the loops are for reduction. Reduction loops
are always the inner most ones.
"""
return self.loops is not None and self.loops[0].is_reduction
return self.loops is not None and self.loops[start_depth].is_reduction

def mark_parallel(self, par_depth):
assert (
par_depth <= self.max_parallel_depth()
par_depth <= self.max_parallel_depth()[0]
), "Parallel depth cannot exceed the maximal allowed parallel depth"
assert self.loops is not None
assert len(self.loops) >= par_depth
loop = self.loops[0]
loop = self.loops[self.max_parallel_depth()[1]]
loop.parallel = par_depth
for i in range(1, par_depth):
self.loops[i].collapsed = True
Expand Down

0 comments on commit 9e260b2

Please sign in to comment.