[BUG] MoE load balancing loss is accumulated twice when using activation checkpointing #1330
Description
Describe the bug
Load balancing loss is accumulated twice when using activation checkpointing
To Reproduce
Train from scratch with / without `--moe-layer-recompute`, setting `--moe-router-load-balancing-type aux_loss`.
Expected behavior
Load balancing loss should be the same in both settings (and should be only slightly above 1, where 1 corresponds to fully balanced routing).
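For context, the `aux_loss` load balancing objective follows the Switch Transformer formulation (stated here from that paper, not quoted from the Megatron source):

$$
\mathcal{L}_{\text{aux}} = N \sum_{i=1}^{N} f_i \, P_i
$$

where $N$ is the number of experts, $f_i$ is the fraction of tokens routed to expert $i$, and $P_i$ is the mean router probability for expert $i$. Under perfectly uniform routing, $f_i = P_i = 1/N$, so the loss evaluates to 1, which is why values close to 1 indicate good balance.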
Stack trace/logs
- without `--moe-layer-recompute`:
  - iteration 10: load_balancing_loss: 1.091395E+00
  - iteration 20: load_balancing_loss: 1.096082E+00
  - iteration 30: load_balancing_loss: 1.037049E+00
- with `--moe-layer-recompute`:
  - iteration 10: load_balancing_loss: 2.202137E+00
  - iteration 20: load_balancing_loss: 2.298303E+00
  - iteration 30: load_balancing_loss: 2.120842E+00
Environment (please complete the following information):
- Megatron-LM d4e72c0
- PyTorch 2.4.1
- CUDA 12.1
- NCCL 2.20.5
Proposed fix
Replace `if self.training:` with `if self.training and torch.is_grad_enabled():`.
Reason: With activation checkpointing enabled via `--moe-layer-recompute`, the router forward function is executed twice (once during the checkpointed forward pass and once during recomputation in the backward pass). This causes the load balancing loss to be accumulated twice in `TopKRouter.aux_loss_load_balancing` in `megatron/core/transformer/moe/router.py` when the condition is only `if self.training:`. Changing the condition to `if self.training and torch.is_grad_enabled():` prevents accumulation during the first pass (where gradients are not enabled), while standard training without `--moe-layer-recompute` remains unaffected.
A similar issue occurs with z_loss.
The fix is included in PR #1331.
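To illustrate the mechanism, here is a minimal standalone sketch (not the Megatron code; `ToyRouter` and its counter are made up) using a reentrant-style checkpoint to mimic layer recompute. It shows that the forward runs twice, and that gating the accumulation on `torch.is_grad_enabled()` counts the auxiliary loss only once:

```python
import torch
from torch.utils.checkpoint import checkpoint


class ToyRouter(torch.nn.Module):
    """Hypothetical stand-in for the router; counts aux-loss accumulations."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)
        self.aux_loss_calls = 0  # how many times the aux loss would be accumulated

    def forward(self, x):
        logits = self.linear(x)
        # Original condition `if self.training:` would increment on both passes.
        # Proposed condition skips the no-grad forward pass of the checkpoint.
        if self.training and torch.is_grad_enabled():
            self.aux_loss_calls += 1
        return logits


router = ToyRouter().train()
x = torch.randn(2, 8, requires_grad=True)

# Reentrant checkpointing runs the forward under no_grad first,
# then reruns it with gradients enabled during backward.
out = checkpoint(router, x, use_reentrant=True)
out.sum().backward()

print(router.aux_loss_calls)  # 1 with the proposed condition; 2 with `if self.training:` alone
```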
Additional context
N/A