fix pipeline forward return tensors (hpcaitech#176)
ver217 authored Jan 21, 2022
1 parent 6fb550a commit 708404d
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions colossalai/engine/schedule/_pipeline_schedule.py
@@ -1,19 +1,20 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from typing import List, Tuple, Union, Callable
-import torch.cuda
+import inspect
+from typing import Callable, List, Tuple, Union

 import colossalai.communication as comm
+import torch.cuda
+from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.amp.naive_amp import NaiveAMPModel
+from colossalai.logging import get_dist_logger
+from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
                              ZeroRedundancyOptimizer_Level_3)
-from colossalai.utils import switch_virtual_pipeline_parallel_rank
-from colossalai.logging import get_dist_logger

 from ._base_schedule import BaseSchedule
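(This first hunk only touches imports: the block is re-sorted into isort order and `import inspect` is newly added; the functional fix is in the two hunks below.)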
@@ -151,7 +152,7 @@ def forward_step(self, engine, input_tensor, return_tensors, return_output_label

         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_output_label:
-                return_tensors.append(tuple((output_tensor, label)))
+                return_tensors.append((output_tensor, label))
             if accum_loss is not None:
                 loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
                 accum_loss.add_(loss_reduced.detach())
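At this first call site the old code wrapped an already-built tuple in `tuple(...)`, which is redundant: `tuple()` applied to a tuple yields an equal tuple. A minimal sketch, with placeholder values standing in for the real output tensor and label:

    # Placeholders stand in for the real microbatch output and label.
    output_tensor, label = "out", "lbl"

    old_style = tuple((output_tensor, label))  # builds the pair, then re-wraps it
    new_style = (output_tensor, label)         # builds the pair directly

    assert old_style == new_style  # the wrapper changes nothing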
@@ -414,7 +415,7 @@ def forward_step(self, engine, model_chunk_id, input_tensor, return_tensors, ret

         if gpc.is_pipeline_last_stage():
             if return_output_label:
-                return_tensors.append(tuple(output_tensor, label))
+                return_tensors.append((output_tensor, label))
             if accum_loss is not None:
                 loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
                 accum_loss.add_(loss_reduced.detach())
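The second call site, in the `forward_step` that takes a `model_chunk_id`, had a genuine bug rather than a style issue: `tuple()` accepts at most one iterable argument, so `tuple(output_tensor, label)` raises a `TypeError` as soon as the last pipeline stage tries to record its output. A minimal reproduction, again with placeholder values:

    # Placeholders stand in for the real microbatch output and label.
    output_tensor, label = "out", "lbl"
    return_tensors = []

    try:
        return_tensors.append(tuple(output_tensor, label))  # the old, broken call
    except TypeError as exc:
        print(exc)  # tuple expected at most 1 argument, got 2

    return_tensors.append((output_tensor, label))  # the fix: append the pair directly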
