[refactor] moving memtracer to gemini #801

Merged: 17 commits, merged Apr 19, 2022
Changes from 1 commit
[log] add tflops logs
feifeibear committed Apr 18, 2022
commit 2b23a3ea5f3167c00e0ac674edb1455dbd04a5f3
9 changes: 9 additions & 0 deletions colossalai/trainer/hooks/_commons_.py
@@ -0,0 +1,9 @@
import torch


def _format_number(val, prec=5):
if isinstance(val, float):
return f'{val:.{prec}g}'
elif torch.is_tensor(val) and torch.is_floating_point(val):
return f'{val.item():.{prec}g}'
return val
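For context, here is a minimal usage sketch of the shared helper this commit factors out into colossalai/trainer/hooks/_commons_.py. The snippet is illustrative only and not part of the diff; it assumes colossalai is installed so the new module path is importable.

# Minimal sketch (assumption: the new _commons_ module is importable after this commit).
import torch

from colossalai.trainer.hooks._commons_ import _format_number

print(_format_number(0.123456789))               # '0.12346'  -- floats trimmed to 5 significant digits
print(_format_number(torch.tensor(3.14159265)))  # '3.1416'   -- floating-point tensors unwrapped via .item()
print(_format_number(42))                        # 42         -- everything else passes through unchanged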
58 changes: 22 additions & 36 deletions colossalai/trainer/hooks/_log_hook.py
@@ -14,14 +14,7 @@
from colossalai.utils import report_memory_usage, is_dp_rank_0, \
is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
from ._base_hook import BaseHook


def _format_number(val, prec=5):
if isinstance(val, float):
return f'{val:.{prec}g}'
elif torch.is_tensor(val) and torch.is_floating_point(val):
return f'{val.item():.{prec}g}'
return val
from ._commons_ import _format_number


class LogByEpochHook(BaseHook):
@@ -35,10 +28,7 @@ class LogByEpochHook(BaseHook):
depend on the hooks order in the hook list.
"""

def __init__(self,
logger,
interval: int = 1,
priority: int = 1):
def __init__(self, logger, interval: int = 1, priority: int = 1):
super().__init__(priority)
self.logger = logger
self._interval = interval
@@ -63,14 +53,12 @@ def __init__(self, priority: int = 10):
def after_train_iter(self, trainer, *args):
trainer.states['step_metrics'] = dict()
for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
trainer.states['step_metrics'][metric_name.lower()] = \
f'{_format_number(metric_calculator.get_last_step_value())}'
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()

def after_test_iter(self, trainer, *args):
trainer.states['step_metrics'] = dict()
for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
trainer.states['step_metrics'][metric_name.lower()] = \
f'{_format_number(metric_calculator.get_last_step_value())}'
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()


@HOOKS.register_module
@@ -85,18 +73,14 @@ class LogMetricByEpochHook(LogByEpochHook):
depend on the hooks order in the hook list.
"""

def __init__(self,
logger,
interval: int = 1,
priority: int = 10) -> None:
def __init__(self, logger, interval: int = 1, priority: int = 10) -> None:
super().__init__(logger, interval, priority)
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()

def _get_str(self, trainer, mode):
msg = []
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
msg.append(
f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
msg.append(f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
msg = ' | '.join(msg)
return msg

@@ -130,12 +114,13 @@ class TensorboardHook(BaseHook):
depend on the hooks order in the hook list.
"""

def __init__(self,
log_dir: str,
ranks: List = None,
parallel_mode: ParallelMode = ParallelMode.GLOBAL,
priority: int = 10,
) -> None:
def __init__(
self,
log_dir: str,
ranks: List = None,
parallel_mode: ParallelMode = ParallelMode.GLOBAL,
priority: int = 10,
) -> None:
super().__init__(priority=priority)
from torch.utils.tensorboard import SummaryWriter

@@ -183,7 +168,7 @@ def _log_by_iter(self, trainer, mode: str):
def _log_by_epoch(self, trainer, mode: str):
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
if metric_calculator.epoch_only:
val = metric_calculator.get_accumulated_value()
val, tflops_value = metric_calculator.get_accumulated_value()
if self._is_valid_rank_to_log:
self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step)

@@ -280,13 +265,14 @@ class LogMemoryByEpochHook(LogByEpochHook):
log_eval (bool, optional): Whether writes in evaluation, defaults to True.
"""

def __init__(self,
logger: DistributedLogger,
interval: int = 1,
priority: int = 10,
log_eval: bool = True,
report_cpu: bool = False, # no reference
) -> None:
def __init__(
self,
logger: DistributedLogger,
interval: int = 1,
priority: int = 10,
log_eval: bool = True,
report_cpu: bool = False, # no reference
) -> None:
super().__init__(logger=logger, interval=interval, priority=priority)
self._log_eval = log_eval
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
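As a brief, hypothetical illustration of the change above (not part of the commit): per-step formatting now lives inside each metric's get_last_step_value(), which returns a ready-to-print string, so the hook stores that return value in trainer.states['step_metrics'] directly. The metric names and numbers below are assumptions.

# Hypothetical contents of trainer.states['step_metrics'] after one training step.
step_metrics = {
    'loss': '2.3456',                                    # from LossMetric.get_last_step_value()
    'throughput': '1234.5 samplePerSec, 56.789 Tflops',  # from ThroughputMetric.get_last_step_value()
}
print(step_metrics)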
39 changes: 25 additions & 14 deletions colossalai/trainer/hooks/_metric_hook.py
@@ -13,6 +13,7 @@
from colossalai.utils import get_current_device, is_no_pp_or_last_stage

from ._base_hook import BaseHook
from ._commons_ import _format_number


class Metric(ABC):
@@ -51,7 +52,7 @@ def update(self, *args, **kwargs) -> None:
pass

@abstractmethod
def get_last_step_value(self):
def get_last_step_value(self) -> str:
"""Returns the metric value in the last iteration.
"""
pass
@@ -120,10 +121,10 @@ def get_accumulated_value(self):
self.accum_loss.div_(self.count)
return self.accum_loss.item()

def get_last_step_value(self):
def get_last_step_value(self) -> str:
"""Returns :attr:`last_step_loss`.
"""
return self.last_step_loss
return str(_format_number(self.last_step_loss.item()))

@staticmethod
def is_better(a, b):
@@ -148,8 +149,8 @@ def reset(self) -> None:
def update(self, lr) -> None:
self.lr = lr

def get_last_step_value(self):
return self.lr
def get_last_step_value(self) -> str:
return str(self.lr)

def get_accumulated_value(self):
return self.lr
@@ -203,10 +204,10 @@ def update(self, logits, targets, batch_size) -> None:
self.accumulated_sum += self.last_step_sum
self.accumulated_correct += self.last_step_correct

def get_last_step_value(self):
def get_last_step_value(self) -> str:
self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA)
self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA)
return (self.last_step_correct / self.last_step_sum).item()
return str(_format_number((self.last_step_correct / self.last_step_sum).item()))

def get_accumulated_value(self):
self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA)
@@ -322,14 +323,16 @@ class ThroughputMetric(Metric):
Args:
epoch_only (bool): Whether the metric only read for the full epoch.
"""
def __init__(self, epoch_only: bool, ignored_steps: int = 0):

def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflops_per_step=0):
super().__init__(epoch_only=epoch_only)
self.ignored_steps = ignored_steps
self.cur_steps = 0
self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
self.accumulated_used_time = torch.zeros(1, device=get_current_device())
self.last_step_num_samples = torch.zeros(1, device=get_current_device())
self.last_step_used_time = torch.zeros(1, device=get_current_device())
self._tflops_per_step = tflops_per_step

def reset(self) -> None:
# self.cur_steps = 0
@@ -346,13 +349,18 @@ def update(self, num_samples, time) -> None:
self.accumulated_num_samples += self.last_step_num_samples
self.accumulated_used_time += self.last_step_used_time

def get_last_step_value(self):
def get_last_step_value(self) -> str:
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
gpc.get_world_size(ParallelMode.DATA)
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
return (self.last_step_num_samples / (self.last_step_used_time + 1e-12)).item()

def get_accumulated_value(self):
samplePerSec = _format_number((self.last_step_num_samples / (self.last_step_used_time + 1e-12)).item())
if self._tflops_per_step > 0:
tflops = _format_number(self._tflops_per_step / (self.last_step_used_time.item() + 1e-12))
return f"{samplePerSec} samplePerSec, {tflops} Tflops"
else:
return f"{samplePerSec} samplePerSec"

def get_accumulated_value(self) -> float:
self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \
gpc.get_world_size(ParallelMode.DATA)
self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
@@ -373,6 +381,7 @@ class ThroughputHook(MetricHook):
defaults to 10. If different hooks share same priority, the order of printing would
depend on the hooks order in the hook list.
"""

def __init__(self, ignored_steps: int = 0, priority: int = 10):
super().__init__(priority)
self.ignored_steps = ignored_steps
@@ -392,12 +401,14 @@ def before_train_epoch(self, trainer):

def after_train_iter(self, trainer, *args):
if self._is_stage_to_compute:
self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
self.metric.update(trainer.engine.schedule.batch_size,
trainer._timer.get_timer('Train-step').get_elapsed_time())

def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()

def after_test_iter(self, trainer, *args):
if self._is_stage_to_compute:
self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
self.metric.update(trainer.engine.schedule.batch_size,
trainer._timer.get_timer('Test-step').get_elapsed_time())
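To show how the new tflops_per_step argument could be used, here is a hedged sketch, not part of this commit (ThroughputHook here still constructs the metric without it, so the default of 0 disables the Tflops field). The estimate_tflops_per_step helper and the 6 * parameters * tokens rule of thumb for a GPT-style decoder are assumptions, not APIs from this PR.

# Hedged sketch: estimating the TFLOPs executed per training step.
def estimate_tflops_per_step(num_params: int, batch_size: int, seq_len: int) -> float:
    # Rough forward + backward cost for a GPT-style decoder: ~6 FLOPs per parameter per token.
    flops_per_step = 6 * num_params * batch_size * seq_len
    return flops_per_step / 1e12  # FLOPs -> TFLOPs

# Example: a 1.3B-parameter model, batch size 32, sequence length 1024.
tflops_per_step = estimate_tflops_per_step(1_300_000_000, 32, 1024)
# Passed as ThroughputMetric(epoch_only=False, tflops_per_step=tflops_per_step),
# get_last_step_value() would then report "<samples/s> samplePerSec, <tflops> Tflops";
# with the default of 0 it reports only the sample throughput.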