From 569357fea0ae91e0c886f177d1b91dd16277e1c9 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Tue, 25 Jan 2022 22:20:54 +0800 Subject: [PATCH] add pytorch hooks (#179) * add pytorch hooks fix #175 * remove licenses in src code * add gpu memory tracer * replacing print with logger in ophooks. --- colossalai/builder/__init__.py | 12 +- colossalai/builder/builder.py | 32 ++- colossalai/engine/_base_engine.py | 22 +- colossalai/engine/ophooks/__init__.py | 115 +++++++++ colossalai/engine/ophooks/_base_ophook.py | 29 +++ .../engine/ophooks/_memtracer_ophook.py | 131 ++++++++++ colossalai/registry/__init__.py | 27 +- colossalai/trainer/_trainer.py | 239 ++++++++++-------- tests/test_config/test_load_config.py | 8 + 9 files changed, 480 insertions(+), 135 deletions(-) create mode 100644 colossalai/engine/ophooks/__init__.py create mode 100644 colossalai/engine/ophooks/_base_ophook.py create mode 100644 colossalai/engine/ophooks/_memtracer_ophook.py diff --git a/colossalai/builder/__init__.py b/colossalai/builder/__init__.py index 4caef0c96b8e..c4840c24a530 100644 --- a/colossalai/builder/__init__.py +++ b/colossalai/builder/__init__.py @@ -1,10 +1,12 @@ -from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer, - build_loss, build_hooks, build_dataset, build_transform, build_data_sampler, - build_gradient_handler) +from .builder import (build_schedule, build_lr_scheduler, build_model, + build_optimizer, build_layer, build_loss, build_hooks, + build_dataset, build_transform, build_data_sampler, + build_gradient_handler, build_ophooks) from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg __all__ = [ 'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', - 'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler', - 'build_gradient_handler', 'build_pipeline_model', 'build_pipeline_model_from_cfg' + 'build_layer', 'build_loss', 'build_hooks', 
'build_dataset', + 'build_transform', 'build_data_sampler', 'build_gradient_handler', + 'build_pipeline_model', 'build_pipeline_model_from_cfg', 'build_ophooks' ] diff --git a/colossalai/builder/builder.py b/colossalai/builder/builder.py index 71971321aefd..2c7eea999d2b 100644 --- a/colossalai/builder/builder.py +++ b/colossalai/builder/builder.py @@ -27,7 +27,7 @@ def build_from_registry(config, registry: Registry): """Returns an object constructed from `config`, the type of the object is specified by `registry`. - :param config: A python dict or a :class:`colossalai.context.Config` object + :param config: A python dict or a :class:`colossalai.context.Config` object containing information used in the construction of the return object :type config: dict or :class:`colossalai.context.colossalai.context.Config` :param registry: A registry specifying the type of the return object @@ -50,7 +50,8 @@ def build_from_registry(config, registry: Registry): obj = registry.get_module(mod_type)(**config_) except Exception as e: print( - f'An error occurred when building {mod_type} from registry {registry.name}', flush=True) + f'An error occurred when building {mod_type} from registry {registry.name}', + flush=True) raise e return obj @@ -69,7 +70,7 @@ def build_layer(config): def build_loss(config): - """Returns a loss function object of :class:`torch.autograd.Function` constructed + """Returns a loss function object of :class:`torch.autograd.Function` constructed from `config`. :param config: A python dict or a :class:`colossalai.context.Config` object @@ -94,7 +95,7 @@ def build_model(config): def build_dataset(config): - """Returns a dataset object of :class:`torch.utils.data.Dataset` constructed + """Returns a dataset object of :class:`torch.utils.data.Dataset` constructed from `config`. 
:param config: A python dict or a :class:`colossalai.context.Config` object @@ -107,13 +108,13 @@ def build_dataset(config): def build_optimizer(config, model): - """Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`, + """Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`, 'model' and 'params'. - :param config: A python dict or a :class:`colossalai.context.Config` object + :param config: A python dict or a :class:`colossalai.context.Config` object containing information used in the construction of the return object :type config: dict or :class:`colossalai.context.Config` - :param model: A model containing parameters for the optimizer + :param model: A model containing parameters for the optimizer :type model: :class:`nn.Module` :return: An object of :class:`torch.optim.Optimizer` :rtype: :class:`torch.optim.Optimizer` @@ -159,6 +160,19 @@ def build_hooks(config, trainer): return build_from_registry(config_, HOOKS) +def build_ophooks(config): + """Returns a hook object of :class:`BaseOpHook` constructed from `config`. + + :param config: A python dict or a :class:`colossalai.context.Config` object + containing information used in the construction of the return object + :type config: dict or :class:`colossalai.context.Config` + :return: An object of :class:`colossalai.engine.ophooks.BaseOpHook` + :rtype: :class:`colossalai.engine.ophooks.BaseOpHook` + """ + config_ = config.copy() + return build_from_registry(config_, OPHOOKS) + + def build_transform(config): """Returns a transformation object of :class:`torchvision.transforms` constructed from `config`. 
@@ -191,10 +205,10 @@ def build_data_sampler(config, dataset): def build_lr_scheduler(config, optimizer): - """Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler` + """Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler` constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`. - :param config: A python dict or a :class:`colossalai.context.Config` object + :param config: A python dict or a :class:`colossalai.context.Config` object containing information used in the construction of the return object :type config: dict or :class:`colossalai.context.Config` :param optimizer: An optimizer object containing parameters for the learning rate diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py index 136726211525..df201e6af555 100644 --- a/colossalai/engine/_base_engine.py +++ b/colossalai/engine/_base_engine.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- - from typing import List from torch.nn import Module from torch.nn.modules.loss import _Loss @@ -9,10 +8,11 @@ from colossalai.logging import get_dist_logger from torch import Tensor +from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook class Engine: - """Basic engine class for training and evaluation. It runs a specific process method + """Basic engine class for training and evaluation. It runs a specific process method :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset. It controls a iteration in training. 
@@ -29,15 +29,14 @@ class Engine: :param verbose: whether to display log info :type verbose: bool """ - def __init__(self, model: Module, optimizer: Optimizer, criterion: _Loss, gradient_handlers: List = None, clip_grad_norm: float = 0.0, - verbose: bool = True - ): + ophook_list: List[BaseOpHook] = [], + verbose: bool = True): self._model = model self._optimizer = optimizer self._criterion = criterion @@ -54,6 +53,9 @@ def __init__(self, else: self._gradient_handlers = [] + self._ophook_list = ophook_list + register_ophooks_recursively(self._model, self._ophook_list) + @property def model(self): """Model attached to the engine""" @@ -87,7 +89,10 @@ def backward(self, loss: Tensor): :param loss: Loss value computed by a loss function :type loss: :class:`torch.Tensor` """ - return self.optimizer.backward(loss) + ret = self.optimizer.backward(loss) + for ophook in self._ophook_list: + ophook.post_iter() + return ret def backward_by_grad(self, tensor, grad): """Start backward propagation given the gradient of the output tensor @@ -97,7 +102,10 @@ def backward_by_grad(self, tensor, grad): :param grad: Gradient passed back to the output :type grad: :class:`torch.Tensor` """ - return self.optimizer.backward_by_grad(tensor, grad) + ret = self.optimizer.backward_by_grad(tensor, grad) + for ophook in self._ophook_list: + ophook.post_iter() + return ret def calc_loss(self, *args, **kwargs): """Compute the loss value diff --git a/colossalai/engine/ophooks/__init__.py b/colossalai/engine/ophooks/__init__.py new file mode 100644 index 000000000000..abfe0a5819a0 --- /dev/null +++ b/colossalai/engine/ophooks/__init__.py @@ -0,0 +1,115 @@ +from ._base_ophook import BaseOpHook +from ._memtracer_ophook import MemTracerOpHook +import torch +from typing import List + +all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"] + + +# apply torch.autograd.Function that calls a backward_function to tensors in output +def _apply_to_tensors_only(module, functional, 
backward_function, outputs): + if type(outputs) is tuple: + touched_outputs = [] + for output in outputs: + touched_output = _apply_to_tensors_only(module, functional, + backward_function, output) + touched_outputs.append(touched_output) + return tuple(touched_outputs) + elif type(outputs) is torch.Tensor: + return functional.apply(module, backward_function, outputs) + else: + return outputs + + +class PreBackwardFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, pre_backward_function, outputs): + ctx.module = module + ctx.pre_backward_function = pre_backward_function + module.applied_pre_backward = False + outputs = outputs.detach() + return outputs + + @staticmethod + def backward(ctx, *args): + ctx.pre_backward_function(ctx.module) + return (None, None) + args + + +class PostBackwardFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, pre_backward_function, output): + ctx.module = module + output = output.detach() + ctx.pre_backward_function = pre_backward_function + return output + + @staticmethod + def backward(ctx, *args): + """ + Args: + activation_grad of the next layer. + Returns: + grad of the input activation. + """ + ctx.pre_backward_function(ctx.module) + return (None, None) + args + + +def register_ophooks_recursively(module: torch.nn.Module, + ophook_list: List[BaseOpHook] = None, + name: str = ""): + r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD.""" + assert isinstance(module, torch.nn.Module) + has_children = False + for child_name, child in module.named_children(): + register_ophooks_recursively(child, ophook_list, name + child_name) + has_children = True + + # Early return on modules with no parameters or buffers that + # are not in their children. + if (len(list(module.named_parameters(recurse=False))) == 0 + and len(list(module.named_buffers(recurse=False))) == 0): + return + + # return if the module has children: hooks are attached to leaf modules only. 
+ if has_children: + return + + if ophook_list is not None: + for hook in ophook_list: + assert (isinstance(hook, BaseOpHook)) + + def _pre_forward_module_hook(submodule, *args): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.pre_fwd_exec(submodule, *args) + + def _post_forward_module_hook(submodule, *args): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.post_fwd_exec(submodule, *args) + + def _pre_backward_module_hook(submodule, inputs, output): + def _run_before_backward_function(submodule): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.pre_bwd_exec(submodule, inputs, output) + + return _apply_to_tensors_only(submodule, PreBackwardFunction, + _run_before_backward_function, output) + + def _post_backward_module_hook(submodule, inputs): + def _run_after_backward_function(submodule): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.post_bwd_exec(submodule, inputs) + + return _apply_to_tensors_only(submodule, PostBackwardFunction, + _run_after_backward_function, inputs) + + module.register_forward_pre_hook(_pre_forward_module_hook) + module.register_forward_hook(_post_forward_module_hook) + + module.register_forward_hook(_pre_backward_module_hook) + module.register_forward_pre_hook(_post_backward_module_hook) diff --git a/colossalai/engine/ophooks/_base_ophook.py b/colossalai/engine/ophooks/_base_ophook.py new file mode 100644 index 000000000000..e948a8cfbcc1 --- /dev/null +++ b/colossalai/engine/ophooks/_base_ophook.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod +import torch + + +class BaseOpHook(ABC): + """This class allows users to add customized operations + before and after the execution of a PyTorch submodule""" + def __init__(self): + pass + + @abstractmethod + def pre_fwd_exec(self, module: torch.nn.Module, *args): + pass + + @abstractmethod + def post_fwd_exec(self, module: torch.nn.Module, *args): + pass + + 
+ @abstractmethod + def pre_bwd_exec(self, module: torch.nn.Module, input, output): + pass + + @abstractmethod + def post_bwd_exec(self, module: torch.nn.Module, input): + pass + + @abstractmethod + def post_iter(self): + pass diff --git a/colossalai/engine/ophooks/_memtracer_ophook.py b/colossalai/engine/ophooks/_memtracer_ophook.py new file mode 100644 index 000000000000..3f5671230351 --- /dev/null +++ b/colossalai/engine/ophooks/_memtracer_ophook.py @@ -0,0 +1,131 @@ +import torch +from . import BaseOpHook +from concurrent.futures import ThreadPoolExecutor +from colossalai.registry import OPHOOKS +from colossalai.logging import get_dist_logger +from time import sleep, time +import psutil +import pickle + + +def get_cuda_memory_used(device): + """ + Get the CUDA memory currently allocated by tensors, in bytes. + Returns torch.cuda.memory_allocated() for the current CUDA device; + the `device` argument is accepted but not used. + """ + ret = torch.cuda.memory_allocated() + # get the peak memory to report correct data, so reset the counter for the next call + if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ + torch.cuda.reset_peak_memory_stats() + return ret + + +class AsyncMemoryMonitor: + def __init__(self, power=10): + """ + An Async Mem Monitor running during computing. + Sampling GPU memory usage of the current GPU dev + at interval of 1/(10**power) sec. 
+ """ + self.keep_measuring = False + self.executor = ThreadPoolExecutor(max_workers=1) + self.monitor_thread = None + self.interval = 1 / (10**power) + self.time_stamps = [] + self.mem_stats = [] + + def set_interval(self, power: int): + self.interval = 1 / (10**power) + + def is_measuring(self): + return self.keep_measuring + + def start(self): + self.keep_measuring = True + self.monitor_thread = self.executor.submit(self._measure_usage) + + def finish(self): + if self.keep_measuring is False: + return 0 + self.keep_measuring = False + max_usage = self.monitor_thread.result() + self.monitor_thread = None + self.time_stamps.append(time()) + self.mem_stats.append(max_usage) + return max_usage + + def _measure_usage(self): + max_usage = 0 + dev = torch.device(f"cuda:{torch.cuda.current_device()}") + while self.keep_measuring: + max_usage = max( + max_usage, + get_cuda_memory_used(dev), + ) + sleep(self.interval) + return max_usage + + def state_dict(self): + return { + "time_stamps": self.time_stamps, + "mem_stats": self.mem_stats, + } + + def save(self, filename): + with open(filename, "wb") as f: + pickle.dump(self.state_dict(), f) + + +@OPHOOKS.register_module +class MemTracerOpHook(BaseOpHook): + def __init__(self, niter=5): + super().__init__() + self.async_mem_monitor = AsyncMemoryMonitor() + self._niter = niter + self._curiter = 0 + self._logger = get_dist_logger() + + def _isvalid(self, module): + return module.training and self._curiter < self._niter + + def niter(self): + return self._niter + + def pre_fwd_exec(self, module: torch.nn.Module, *args): + if self._isvalid(module): + self.async_mem_monitor.finish() + self.async_mem_monitor.start() + self._logger.debug(f'FWD PRE {module.__class__.__name__}') + + def post_fwd_exec(self, module: torch.nn.Module, *args): + if self._isvalid(module): + self.async_mem_monitor.finish() + self._logger.debug(f'FWD POST {module.__class__.__name__}') + + def pre_bwd_exec(self, module: torch.nn.Module, input, output): + 
assert isinstance(module, torch.nn.Module) + if self._isvalid(module): + self.async_mem_monitor.finish() + self.async_mem_monitor.start() + self._logger.debug(f'BWD PRE {module.__class__.__name__}') + + def post_bwd_exec(self, module: torch.nn.Module, input): + assert isinstance(module, torch.nn.Module) + if self._isvalid(module): + self.async_mem_monitor.finish() + self._logger.debug(f'BWD POST {module.__class__.__name__}') + + def pre_iter(self): + pass + + def post_iter(self): + self.async_mem_monitor.finish() + if self._curiter == self._niter: + self._logger.info( + f'dump a memory statistics as pickle to ./memstats.pkl') + self.save_results("memstats.pkl") + self._curiter += 1 + + def save_results(self, filename): + self.async_mem_monitor.save(filename) diff --git a/colossalai/registry/__init__.py b/colossalai/registry/__init__.py index 492b278a40f2..62b0bb08fae3 100644 --- a/colossalai/registry/__init__.py +++ b/colossalai/registry/__init__.py @@ -7,16 +7,17 @@ from .registry import Registry -LAYERS = Registry('layers', third_party_library=[nn]) -LOSSES = Registry('losses') -MODELS = Registry('models', third_party_library=[tv_models]) -OPTIMIZERS = Registry('optimizers', third_party_library=[optim, dist_optim]) -DATASETS = Registry('datasets', third_party_library=[tv_datasets]) -DIST_GROUP_INITIALIZER = Registry('dist_group_initializer') -GRADIENT_HANDLER = Registry('gradient_handler') -LOSSES = Registry('losses', third_party_library=[nn]) -HOOKS = Registry('hooks') -TRANSFORMS = Registry('transforms', third_party_library=[transforms]) -DATA_SAMPLERS = Registry('data_samplers') -LR_SCHEDULERS = Registry('lr_schedulers') -SCHEDULE = Registry('schedules') +LAYERS = Registry("layers", third_party_library=[nn]) +LOSSES = Registry("losses") +MODELS = Registry("models", third_party_library=[tv_models]) +OPTIMIZERS = Registry("optimizers", third_party_library=[optim, dist_optim]) +DATASETS = Registry("datasets", third_party_library=[tv_datasets]) 
+DIST_GROUP_INITIALIZER = Registry("dist_group_initializer") +GRADIENT_HANDLER = Registry("gradient_handler") +LOSSES = Registry("losses", third_party_library=[nn]) +HOOKS = Registry("hooks") +TRANSFORMS = Registry("transforms", third_party_library=[transforms]) +DATA_SAMPLERS = Registry("data_samplers") +LR_SCHEDULERS = Registry("lr_schedulers") +SCHEDULE = Registry("schedules") +OPHOOKS = Registry("ophooks") diff --git a/colossalai/trainer/_trainer.py b/colossalai/trainer/_trainer.py index 13a56eaf8a41..ebb3ac893884 100644 --- a/colossalai/trainer/_trainer.py +++ b/colossalai/trainer/_trainer.py @@ -1,8 +1,4 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - from typing import Union, List -from colossalai import engine from colossalai.context.parallel_mode import ParallelMode import torch @@ -11,17 +7,18 @@ from tqdm import tqdm from colossalai.core import global_context as gpc + from colossalai.engine import Engine from colossalai.engine.schedule import NonPipelineSchedule, BaseSchedule from colossalai.logging import DistributedLogger from colossalai.utils import MultiTimer from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage -from .hooks import BaseHook +from colossalai.trainer.hooks import BaseHook class Trainer: - """This a class tending for easy deployments of users' training and evaluation instead of - writing their own scripts. It is similar with ``ignite.engine`` and ``keras.engine``, but is + """This a class tending for easy deployments of users' training and evaluation instead of + writing their own scripts. It is similar with ``ignite.engine`` and ``keras.engine``, but is called `Trainer`. 
:param engine: Engine responsible for the process function @@ -33,12 +30,13 @@ class Trainer: :param logger: Logger used to record the whole training :type logger: :class:`colossalai.logging.DistributedLogger`, optional """ - - def __init__(self, - engine: Engine, - schedule: BaseSchedule = None, - timer: MultiTimer = None, - logger: DistributedLogger = None): + def __init__( + self, + engine: Engine, + schedule: BaseSchedule = None, + timer: MultiTimer = None, + logger: DistributedLogger = None, + ): # training-ralated params self._engine = engine self._max_epochs = 0 @@ -63,29 +61,28 @@ def __init__(self, # set schedule which specifies the training iteration for the engine if schedule is None: schedule = NonPipelineSchedule() - if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: - assert not isinstance(schedule, NonPipelineSchedule), \ - 'NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead.' + if (gpc.is_initialized(ParallelMode.PIPELINE) + and gpc.get_world_size(ParallelMode.PIPELINE) > 1): + assert not isinstance( + schedule, NonPipelineSchedule + ), "NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead." self._schedule = schedule self._schedule.pre_processing(engine) @property def cur_epoch(self): - """Returns the index of the current epoch. - """ + """Returns the index of the current epoch.""" return self._cur_epoch @cur_epoch.setter def cur_epoch(self, epoch: int): - """Set how many epochs have been processed. - """ + """Set how many epochs have been processed.""" # allow setter for training resumption self._cur_epoch = epoch @property def cur_step(self): - """Returns how many iteration steps have been processed. 
- """ + """Returns how many iteration steps have been processed.""" return self._cur_step @property @@ -131,8 +128,7 @@ def _call_timer(self, action: str, item: str, *args, **kwargs) -> None: getattr(self._timer, action)(item, *args, **kwargs) def _reset_states(self) -> None: - """Clear trainer states - """ + """Clear trainer states""" self.states = dict() def _call_hooks(self, func, output=None): @@ -152,99 +148,122 @@ def _call_hooks(self, func, output=None): @staticmethod def _should_display_progress(display_progress: bool): - """ Only display progress on DP rank 0, TP rank 0 and PP last rank - """ - return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() + """Only display progress on DP rank 0, TP rank 0 and PP last rank""" + return (display_progress and is_dp_rank_0() and is_tp_rank_0() + and is_no_pp_or_last_stage()) - def _train_epoch(self, - train_dataloader: DataLoader, - epoch: int = None, - display_progress: bool = False, - return_output_label: bool = True): + def _train_epoch( + self, + train_dataloader: DataLoader, + epoch: int = None, + display_progress: bool = False, + return_output_label: bool = True, + ): # set training state self._engine.train() data_iter = iter(train_dataloader) progress = range(self._steps_per_epoch) if display_progress: if epoch is None: - progress = tqdm(progress, desc='[Train]') + progress = tqdm(progress, desc="[Train]") else: - progress = tqdm(progress, desc=f'[Epoch {epoch} / Train]') + progress = tqdm(progress, desc=f"[Epoch {epoch} / Train]") - self._call_hooks('before_train_epoch') - self._call_timer(action='start', item='Train-epoch') + self._call_hooks("before_train_epoch") + self._call_timer(action="start", item="Train-epoch") for i in progress: - self._call_hooks('before_train_iter') - self._call_timer(action='start', item='Train-step') + self._call_hooks("before_train_iter") + self._call_timer(action="start", item="Train-step") # run 1 training step self.engine.zero_grad() logits, 
label, loss = self.schedule.forward_backward_step( - self.engine, data_iter, forward_only=False, return_loss=True, return_output_label=return_output_label) + self.engine, + data_iter, + forward_only=False, + return_loss=True, + return_output_label=return_output_label, + ) self.engine.step() - self._call_timer(action='stop', item='Train-step', keep_in_history=True) - self._call_hooks('after_train_iter', output=(logits, label, loss)) + self._call_timer(action="stop", + item="Train-step", + keep_in_history=True) + self._call_hooks("after_train_iter", output=(logits, label, loss)) self._cur_step += 1 if display_progress: - if 'step_metrics' in self.states: - progress.set_postfix(**self.states['step_metrics']) + if "step_metrics" in self.states: + progress.set_postfix(**self.states["step_metrics"]) # stop when max iter is reached if self._exceed_max_step(): break - self._call_timer(action='stop', item='Train-epoch', keep_in_history=True) - self._call_hooks('after_train_epoch') - self._call_timer(action='reset', item='Train-epoch') + self._call_timer(action="stop", + item="Train-epoch", + keep_in_history=True) + self._call_hooks("after_train_epoch") + self._call_timer(action="reset", item="Train-epoch") - def _eval(self, - test_dataloader: DataLoader, - epoch: int = None, - display_progress: bool = False, - return_output_label: bool = True): + def _eval( + self, + test_dataloader: DataLoader, + epoch: int = None, + display_progress: bool = False, + return_output_label: bool = True, + ): # switch engine status self._engine.eval() data_iter = iter(test_dataloader) num_steps = len(test_dataloader) - self._call_hooks('before_test') + self._call_hooks("before_test") # prepare progress bar progress = range(num_steps) if display_progress: - desc = 'Evaluation' + desc = "Evaluation" if epoch is not None: - desc = '[Epoch %d / Test]' % epoch + desc = "[Epoch %d / Test]" % epoch progress = tqdm(progress, desc=desc) - self._call_hooks('before_test_epoch') - 
self._call_timer(action='start', item='Test-epoch') + self._call_hooks("before_test_epoch") + self._call_timer(action="start", item="Test-epoch") with torch.no_grad(): for _ in progress: - self._call_hooks('before_test_iter') - self._call_timer(action='start', item='Test-step') + self._call_hooks("before_test_iter") + self._call_timer(action="start", item="Test-step") logits, label, loss = self.schedule.forward_backward_step( - self.engine, data_iter, forward_only=True, return_loss=True, return_output_label=return_output_label) - self._call_timer(action='stop', item='Test-step', keep_in_history=True) - self._call_hooks('after_test_iter', + self.engine, + data_iter, + forward_only=True, + return_loss=True, + return_output_label=return_output_label, + ) + self._call_timer(action="stop", + item="Test-step", + keep_in_history=True) + self._call_hooks("after_test_iter", output=(logits, label, loss)) if display_progress: - if 'step_metrics' in self.states: - progress.set_postfix(**self.states['step_metrics']) + if "step_metrics" in self.states: + progress.set_postfix(**self.states["step_metrics"]) - self._call_timer(action='stop', item='Test-epoch', keep_in_history=True) - self._call_hooks('after_test_epoch') - self._call_hooks('after_test') - self._call_timer(action='reset', item='Test-step') - self._call_timer(action='reset', item='Test-epoch') + self._call_timer(action="stop", + item="Test-epoch", + keep_in_history=True) + self._call_hooks("after_test_epoch") + self._call_hooks("after_test") + self._call_timer(action="reset", item="Test-step") + self._call_timer(action="reset", item="Test-epoch") def _exceed_max_step(self): return self._max_steps is not None and self._cur_step >= self._max_steps - def fit(self, + def fit( + self, train_dataloader: DataLoader, epochs: int, max_steps: int = None, @@ -253,7 +272,7 @@ def fit(self, hooks: List[BaseHook] = None, display_progress: bool = False, return_output_label: bool = True, - ): + ): """Trains the model to fit training 
data. :param train_dataloader: DataLoader in training @@ -290,7 +309,9 @@ def fit(self, # reset hooks self._reset_states() if hooks is not None: - assert isinstance(hooks, list), f'expected argument hooks be to list, but got {type(hooks)}' + assert isinstance( + hooks, list + ), f"expected argument hooks be to list, but got {type(hooks)}" else: hooks = [] self.hooks = hooks @@ -298,13 +319,16 @@ def fit(self, if self._verbose: for hook in self.hooks: self._logger.info( - f'Using {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0]) - self._logger.info("Lower value means higher priority for calling hook function", ranks=[0]) - self._call_hooks('after_hook_is_attached') + f"Using {hook.__class__.__name__} for training, priority = {hook.priority}", + ranks=[0], + ) + self._logger.info( + "Lower value means higher priority for calling hook function", + ranks=[0]) + self._call_hooks("after_hook_is_attached") - # start train self._engine.train() - self._call_hooks('before_train') + self._call_hooks("before_train") # recover step value if resuming training last_epoch = self._cur_epoch @@ -317,16 +341,17 @@ def fit(self, train_dataloader=train_dataloader, epoch=epoch, display_progress=display_progress, - return_output_label=return_output_label + return_output_label=return_output_label, ) # start eval if should_test and epoch % test_interval == 0: - self._eval(test_dataloader=test_dataloader, - display_progress=display_progress, - epoch=epoch, - return_output_label=return_output_label - ) + self._eval( + test_dataloader=test_dataloader, + display_progress=display_progress, + epoch=epoch, + return_output_label=return_output_label, + ) self._cur_epoch += 1 @@ -334,16 +359,19 @@ def fit(self, if self._exceed_max_step(): self._logger.info( f"Max number of steps {max_steps} has been reached, training is stopped automatically", - ranks=[0]) + ranks=[0], + ) break - self._call_hooks('after_train') - self._call_timer('reset', 'Train-epoch') - - def 
evaluate(self, - test_dataloader: DataLoader, - hooks: List[BaseHook] = None, - display_progress: bool = False, - return_output_label: bool = True): + self._call_hooks("after_train") + self._call_timer("reset", "Train-epoch") + + def evaluate( + self, + test_dataloader: DataLoader, + hooks: List[BaseHook] = None, + display_progress: bool = False, + return_output_label: bool = True, + ): """Evaluates the model with testing data. :param test_dataloader: DataLoader in testing @@ -362,7 +390,9 @@ def evaluate(self, # reset hooks self._reset_states() if hooks is not None: - assert isinstance(hooks, list), f'expected argument hooks be to list, but got {type(hooks)}' + assert isinstance( + hooks, list + ), f"expected argument hooks be to list, but got {type(hooks)}" else: hooks = [] self.hooks = hooks @@ -370,15 +400,20 @@ def evaluate(self, if self._verbose: for hook in self.hooks: self._logger.info( - f'Using {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0]) - self._logger.info("Lower value means higher priority for calling hook function", ranks=[0]) - self._call_hooks('after_hook_is_attached') + f"Using {hook.__class__.__name__} for training, priority = {hook.priority}", + ranks=[0], + ) + self._logger.info( + "Lower value means higher priority for calling hook function", + ranks=[0]) + self._call_hooks("after_hook_is_attached") # eval - self._eval(test_dataloader=test_dataloader, - display_progress=display_progress, - return_output_label=return_output_label - ) + self._eval( + test_dataloader=test_dataloader, + display_progress=display_progress, + return_output_label=return_output_label, + ) def predict(self, data: Union[Tensor, List[Tensor]]): """Uses trained model to make a prediction for a tensor or a tensor list. 
@@ -399,6 +434,8 @@ def predict(self, data: Union[Tensor, List[Tensor]]): # for compatibility with schedule simple_dataloader = [(data, None)] data_iter = iter(simple_dataloader) - output, _, _ = self.schedule.forward_backward_step( - self.engine, data_iter, forward_only=True, return_loss=False) + output, _, _ = self.schedule.forward_backward_step(self.engine, + data_iter, + forward_only=True, + return_loss=False) return output diff --git a/tests/test_config/test_load_config.py b/tests/test_config/test_load_config.py index 550af2a4ae81..2c4543b750d5 100644 --- a/tests/test_config/test_load_config.py +++ b/tests/test_config/test_load_config.py @@ -6,6 +6,7 @@ import pytest from colossalai.context.config import Config +from colossalai.builder import build_ophooks @pytest.mark.cpu @@ -17,3 +18,10 @@ def test_load_config(): assert config.train_data.dataset, 'cannot access grandchild attribute' assert isinstance(config.train_data.dataset.transform_pipeline[0], dict), \ f'expected attribute transform_pipeline elements to be a dict, but found {type(config.train_data.dataset.transform_pipeline)}' + + +@pytest.mark.cpu +def test_load_ophooks(): + dict = {'type': 'MemTracerOpHook', 'niter': 2} + ophook = build_ophooks(dict) + assert ophook.niter() == 2