Commit

[doc] improved assertion messages in trainer (hpcaitech#873)
FrankLeeeee authored Apr 26, 2022
1 parent 7a64fae commit 1c34382
Showing 1 changed file with 46 additions and 69 deletions.
115 changes: 46 additions & 69 deletions colossalai/trainer/_trainer.py
@@ -1,13 +1,9 @@
-from typing import Union, List
-from colossalai.context.parallel_mode import ParallelMode
+from typing import Union, List, Any
 
 import torch
-from torch import Tensor
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
-from colossalai.core import global_context as gpc
-
 from colossalai.engine import Engine
 from colossalai.logging import DistributedLogger
 from colossalai.utils import MultiTimer
@@ -53,11 +49,12 @@ class Trainer:
     `Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_
     and `ColossalAI-Examples <https://github.com/hpcaitech/ColossalAI-Examples/tree/main>`_.
     """
+
     def __init__(
-            self,
-            engine: Engine,
-            timer: MultiTimer = None,
-            logger: DistributedLogger = None,
+        self,
+        engine: Engine,
+        timer: MultiTimer = None,
+        logger: DistributedLogger = None,
     ):
         # training-ralated params
         self._engine = engine
@@ -154,15 +151,14 @@ def _call_hooks(self, func, output=None):
     @staticmethod
     def _should_display_progress(display_progress: bool):
         """Only display progress on DP rank 0, TP rank 0 and PP last rank"""
-        return (display_progress and is_dp_rank_0() and is_tp_rank_0()
-                and is_no_pp_or_last_stage())
+        return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage())
 
     def _train_epoch(
-            self,
-            train_dataloader: DataLoader,
-            epoch: int = None,
-            display_progress: bool = False,
-            return_output_label: bool = True,
+        self,
+        train_dataloader: DataLoader,
+        epoch: int = None,
+        display_progress: bool = False,
+        return_output_label: bool = True,
     ):
         # set training state
         self._engine.train()
@@ -189,9 +185,7 @@ def _train_epoch(
                 return_output_label=return_output_label,
             )
             self.engine.step()
-            self._call_timer(action="stop",
-                             item="Train-step",
-                             keep_in_history=True)
+            self._call_timer(action="stop", item="Train-step", keep_in_history=True)
             self._call_hooks("after_train_iter", output=(logits, label, loss))
 
             self._cur_step += 1
@@ -204,18 +198,16 @@ def _train_epoch(
             if self._exceed_max_step():
                 break
 
-        self._call_timer(action="stop",
-                         item="Train-epoch",
-                         keep_in_history=True)
+        self._call_timer(action="stop", item="Train-epoch", keep_in_history=True)
         self._call_hooks("after_train_epoch")
         self._call_timer(action="reset", item="Train-epoch")
 
     def _eval(
-            self,
-            test_dataloader: DataLoader,
-            epoch: int = None,
-            display_progress: bool = False,
-            return_output_label: bool = True,
+        self,
+        test_dataloader: DataLoader,
+        epoch: int = None,
+        display_progress: bool = False,
+        return_output_label: bool = True,
     ):
         # switch engine status
         self._engine.eval()
@@ -244,19 +236,14 @@ def _eval(
                     return_loss=True,
                     return_output_label=return_output_label,
                 )
-                self._call_timer(action="stop",
-                                 item="Test-step",
-                                 keep_in_history=True)
-                self._call_hooks("after_test_iter",
-                                 output=(logits, label, loss))
+                self._call_timer(action="stop", item="Test-step", keep_in_history=True)
+                self._call_hooks("after_test_iter", output=(logits, label, loss))
 
                 if display_progress:
                     if "step_metrics" in self.states:
                         progress.set_postfix(**self.states["step_metrics"])
 
-        self._call_timer(action="stop",
-                         item="Test-epoch",
-                         keep_in_history=True)
+        self._call_timer(action="stop", item="Test-epoch", keep_in_history=True)
         self._call_hooks("after_test_epoch")
         self._call_hooks("after_test")
         self._call_timer(action="reset", item="Test-step")
@@ -266,15 +253,15 @@ def _exceed_max_step(self):
         return self._max_steps is not None and self._cur_step >= self._max_steps
 
     def fit(
-            self,
-            train_dataloader: DataLoader,
-            epochs: int,
-            max_steps: int = None,
-            test_dataloader: DataLoader = None,
-            test_interval: int = 1,
-            hooks: List[BaseHook] = None,
-            display_progress: bool = False,
-            return_output_label: bool = True,
+        self,
+        train_dataloader: DataLoader,
+        epochs: int,
+        max_steps: int = None,
+        test_dataloader: DataLoader = None,
+        test_interval: int = 1,
+        hooks: List[BaseHook] = None,
+        display_progress: bool = False,
+        return_output_label: bool = True,
     ):
         r"""Trains the model to fit training data.
 
@@ -303,9 +290,11 @@ def fit(
         # reset hooks
         self._reset_states()
         if hooks is not None:
-            assert isinstance(
-                hooks, list
-            ), f"expected argument hooks be to list, but got {type(hooks)}"
+            assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
+
+            for hook in hooks:
+                assert isinstance(hook, BaseHook), \
+                    f'expected the hook to be of type BaseHook, but got {type(hook)}'
         else:
             hooks = []
         self.hooks = hooks
@@ -316,9 +305,7 @@ def fit(
                 f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
                 ranks=[0],
             )
-        self._logger.info(
-            "Lower value means higher priority for calling hook function",
-            ranks=[0])
+        self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
         self._call_hooks("after_hook_is_attached")
 
         self._engine.train()
@@ -360,11 +347,11 @@ def fit(
         self._call_timer("reset", "Train-epoch")
 
     def evaluate(
-            self,
-            test_dataloader: DataLoader,
-            hooks: List[BaseHook] = None,
-            display_progress: bool = False,
-            return_output_label: bool = True,
+        self,
+        test_dataloader: DataLoader,
+        hooks: List[BaseHook] = None,
+        display_progress: bool = False,
+        return_output_label: bool = True,
     ):
         """Evaluates the model with testing data.
 
@@ -381,9 +368,7 @@ def evaluate(
         # reset hooks
         self._reset_states()
         if hooks is not None:
-            assert isinstance(
-                hooks, list
-            ), f"expected argument hooks be to list, but got {type(hooks)}"
+            assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
         else:
             hooks = []
         self.hooks = hooks
@@ -394,9 +379,7 @@ def evaluate(
                 f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
                 ranks=[0],
            )
-        self._logger.info(
-            "Lower value means higher priority for calling hook function",
-            ranks=[0])
+        self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
         self._call_hooks("after_hook_is_attached")
 
         # eval
@@ -406,7 +389,7 @@ def evaluate(
             return_output_label=return_output_label,
         )
 
-    def predict(self, data: Union[Tensor, List[Tensor]]):
+    def predict(self, data: Union[Any, List[Any]]):
         """Uses trained model to make a prediction for a tensor or a tensor list.
 
         Args:
@@ -416,17 +399,11 @@ def predict(self, data: Union[Tensor, List[Tensor]]):
             :class:`torch.tensor`: The output of model as the prediction
         """
         # predict without labels
-        if isinstance(data, (list, tuple)):
-            assert isinstance(data[0], Tensor)
-        else:
-            assert isinstance(data, Tensor)
         self._engine.eval()
 
         # prepare a list of (data, label) to make it iterable
         # for compatibility with schedule
         simple_dataloader = [(data, None)]
         data_iter = iter(simple_dataloader)
-        output, _, _ = self.engine.execute_schedule(data_iter,
-                                                    forward_only=True,
-                                                    return_loss=False)
+        output, _, _ = self.engine.execute_schedule(data_iter, forward_only=True, return_loss=False)
         return output
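
For readers who want to see the new checks in action without configuring an Engine, the following standalone sketch (not part of the commit) re-implements the hooks validation that fit() performs after this change. The names validate_hooks, MyLoggingHook, and the local BaseHook stand-in are hypothetical and used only for illustration; the assertion messages are copied verbatim from the diff above.

# Standalone Python sketch mirroring the hooks validation added to Trainer.fit().
# BaseHook below is a stand-in for colossalai.trainer.hooks.BaseHook.


class BaseHook:
    pass


class MyLoggingHook(BaseHook):
    """A hypothetical user-defined hook."""


def validate_hooks(hooks):
    # Same checks and messages as the updated Trainer.fit()
    if hooks is not None:
        assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"

        for hook in hooks:
            assert isinstance(hook, BaseHook), \
                f'expected the hook to be of type BaseHook, but got {type(hook)}'
    else:
        hooks = []
    return hooks


validate_hooks([MyLoggingHook()])  # passes silently

try:
    validate_hooks("not-a-list")
except AssertionError as err:
    print(err)  # expected argument hooks be to list, but got <class 'str'>

try:
    validate_hooks([MyLoggingHook(), 42])
except AssertionError as err:
    print(err)  # expected the hook to be of type BaseHook, but got <class 'int'>

Note that only fit() gains the per-hook type check in this commit; evaluate() keeps just the list check, and predict() drops its old Tensor-only assertions now that its signature accepts Union[Any, List[Any]].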
