[profiler] add MemProfiler (hpcaitech#356)
* add memory trainer hook

* fix bug

* add memory trainer hook

* fix import bug

* fix import bug

* add trainer hook

* fix hpcaitech#370 git log bug

* modify `to_tensorboard` function to support better output

* remove useless output

* change the name of `MemProfiler`

* complete memory profiler

* replace error with warning

* finish trainer hook

* modify interface of MemProfiler

* modify `__init__.py` in profiler

* remove unnecessary pass statement

* add usage to doc string

* add usage to trainer hook

* new location to store temp data file
Jie Zhu authored Mar 29, 2022
1 parent fb841dd commit 73d3661
Showing 8 changed files with 136 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README-zh-Hans.md
@@ -267,4 +267,4 @@ class MLP_2D(nn.Module):
}
```

<p align="right">(<a href="#top">返回顶端</a>)</p>
<p align="right">(<a href="#top">返回顶端</a>)</p>
2 changes: 1 addition & 1 deletion README.md
@@ -270,4 +270,4 @@ class MLP_2D(nn.Module):
}
```

<p align="right">(<a href="#top">back to top</a>)</p>
<p align="right">(<a href="#top">back to top</a>)</p>
27 changes: 24 additions & 3 deletions colossalai/engine/_base_engine.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from asyncio.log import logger
from typing import List
from torch.nn import Module
from torch.nn.modules.loss import _Loss
@@ -9,9 +10,9 @@
from colossalai.logging import get_dist_logger
from torch import Tensor
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
from typing import Optional
from typing import Optional, Type
from colossalai.engine.gradient_handler import BaseGradientHandler

from colossalai.logging import get_dist_logger

class Engine:
"""Basic engine class for training and evaluation. It runs a specific process method
@@ -64,6 +65,11 @@ def __init__(self,
self._ophook_list = ophook_list
register_ophooks_recursively(self._model, self._ophook_list)

@property
def ophooks(self):
"""show current activated ophooks"""
return self._ophook_list

@property
def model(self):
"""Model attached to the engine"""
@@ -79,6 +85,21 @@ def criterion(self):
"""Criterion attached to the engine"""
return self._criterion

def add_hook(self, ophook: Type[BaseOpHook]) -> None:
"""Add an op hook to the engine and re-register all hooks on the model."""
# check whether a hook of the same type is already attached
for h in self._ophook_list:
if type(h) == type(ophook):
logger = get_dist_logger()
logger.warning(f"duplicate hooks: at least two instances of {type(ophook)}")
self._ophook_list.append(ophook)
register_ophooks_recursively(self._model, self._ophook_list)

def remove_hook(self, ophook: Type[BaseOpHook]) -> None:
"""Remove an op hook from the engine (not supported yet)."""
logger = get_dist_logger()
logger.warning("removing hooks is currently not supported")

def zero_grad(self):
"""Set the gradient of parameters to zero
"""
@@ -150,4 +171,4 @@ def eval(self):
"""Sets the model to evaluation mode.
"""
self.training = False
self._model.eval()
self._model.eval()
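A minimal usage sketch for the new add_hook/ophooks API added above. The engine object is assumed to already exist (e.g. returned by colossalai.initialize) and is not constructed here; only add_hook(), ophooks and MemTracerOpHook come from this repository.

# Minimal sketch: attach a MemTracerOpHook to an already-initialized engine.
# `engine` is assumed to come from colossalai.initialize(...) and is not built here.
from colossalai.engine.ophooks import MemTracerOpHook

mem_hook = MemTracerOpHook(warmup=50, refreshrate=10)
engine.add_hook(mem_hook)   # re-registers all op hooks recursively on the model
print(engine.ophooks)       # lists the currently activated op hooks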
17 changes: 11 additions & 6 deletions colossalai/engine/ophooks/_memtracer_ophook.py
@@ -1,12 +1,15 @@
import json
import pickle
from pathlib import Path
from colossalai.context.parallel_mode import ParallelMode
import torch
from colossalai.engine.ophooks import BaseOpHook
from colossalai.registry import OPHOOKS
from colossalai.logging import get_dist_logger
from colossalai.core import global_context as gpc

from typing import Union
from colossalai.utils.memory_tracer import AsyncMemoryMonitor

import os
import math


@@ -103,12 +106,14 @@ def post_iter(self):
if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
# output file info
self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl")
self.save_results()
home_dir = Path.home()
with open(home_dir.joinpath(f".cache/colossal/mem-{self._rank}.pkl"), "wb") as f:
pickle.dump(self.async_mem_monitor.state_dict, f)
self._count += 1
self._logger.debug(f"data file has been refreshed {self._count} times")
# finish an iteration
self._curiter += 1

def save_results(self):
datafile = f"{self._data_prefix}-{self._rank}.pkl"
self.async_mem_monitor.save(datafile)
def save_results(self, data_file: Union[str, Path]):
with open(data_file, "w") as f:
f.write(json.dumps(self.async_mem_monitor.state_dict))
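A minimal sketch of inspecting the temp data file that post_iter() now refreshes under the user's cache directory; rank 0 is assumed, and the key names follow AsyncMemoryMonitor.state_dict.

# Minimal sketch: load the periodically refreshed memory dump for rank 0.
import pickle
from pathlib import Path

data_file = Path.home().joinpath(".cache/colossal/mem-0.pkl")
with open(data_file, "rb") as f:
    stats = pickle.load(f)

print(stats["time_stamps"][:5])   # sampling timestamps
print(stats["mem_stats"][:5])     # sampled GPU memory usage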
44 changes: 44 additions & 0 deletions colossalai/trainer/hooks/_mem_tracer_hook.py
@@ -0,0 +1,44 @@
from colossalai.registry import HOOKS
from torch import Tensor
from colossalai.trainer.hooks import BaseHook
from colossalai.utils.memory_tracer import AsyncMemoryMonitor
from ._metric_hook import LearningRateMetric, MetricHook

@HOOKS.register_module
class MemTraceHook(BaseHook):
"""Save memory stats and pass it to states
This hook is used to record memory usage info, and pass to trainer.states
You can use it as other trainer hook and fetch data from trainer.states['metrics][mode]
"""
def __init__(
self,
priority: int = 0,
) -> None:
super().__init__(priority=priority)
self._memory_monitor = AsyncMemoryMonitor()

def after_hook_is_attached(self, trainer):
# Initialize the data
trainer.states['metrics']['train'] = self._memory_monitor.state_dict
trainer.states['metrics']['test'] = self._memory_monitor.state_dict

def before_train_iter(self, trainer):
self._memory_monitor.start()
return super().before_train_iter(trainer)

def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
self._memory_monitor.finish()
trainer.states['metrics']['train'] = self._memory_monitor.state_dict
trainer.states['metrics']['test'] = self._memory_monitor.state_dict
return super().after_train_iter(trainer, output, label, loss)

def before_test_iter(self, trainer):
self._memory_monitor.start()
return super().before_test_iter(trainer)

def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
self._memory_monitor.finish()
trainer.states['metrics']['train'] = self._memory_monitor.state_dict
trainer.states['metrics']['test'] = self._memory_monitor.state_dict
return super().after_test_iter(trainer, output, label, loss)
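A minimal sketch of registering MemTraceHook with a trainer. The Trainer construction, fit() arguments and the engine/dataloader/epoch names are assumptions about the usual colossalai trainer API; only MemTraceHook and trainer.states come from this commit.

# Minimal sketch: use MemTraceHook like any other trainer hook (Trainer/fit()
# usage is an assumption; engine, train_dataloader and num_epochs are placeholders).
from colossalai.trainer import Trainer
from colossalai.trainer.hooks._mem_tracer_hook import MemTraceHook

hook_list = [MemTraceHook(priority=0)]
trainer = Trainer(engine=engine)
trainer.fit(train_dataloader=train_dataloader,
            epochs=num_epochs,
            hooks=hook_list,
            display_progress=True)

# memory stats recorded by the hook are exposed through trainer.states
print(trainer.states['metrics']['train'])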
2 changes: 1 addition & 1 deletion colossalai/utils/memory_tracer/async_memtracer.py
@@ -86,6 +86,7 @@ def _measure_usage(self):
sleep(self.interval)
return max_usage

@property
def state_dict(self):
return {
"time_stamps": self.time_stamps,
@@ -94,7 +95,6 @@ def state_dict(self):

def save(self, filename):
with open(filename, "wb") as f:
print(self.state_dict())
pickle.dump(self.state_dict, f)  # state_dict is a property now, not a method

def clear(self):
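Since state_dict becomes a property here, call sites read it as an attribute rather than calling it. A minimal standalone sketch, with start()/finish() taken from the trainer hook above and the default constructor arguments assumed.

# Minimal sketch: drive AsyncMemoryMonitor directly and read the state_dict property.
from colossalai.utils.memory_tracer import AsyncMemoryMonitor

monitor = AsyncMemoryMonitor()
monitor.start()
# ... run one training step here ...
monitor.finish()

stats = monitor.state_dict            # property access, returns a dict
print(stats["time_stamps"], stats["mem_stats"])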
5 changes: 4 additions & 1 deletion colossalai/utils/profiler/__init__.py
@@ -1,3 +1,6 @@
from .comm_profiler import CommProfiler
from .pcie_profiler import PcieProfiler
from .prof_utils import ProfilerContext
from .prof_utils import ProfilerContext, BaseProfiler
from .mem_profiler import MemProfiler

__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext']
50 changes: 50 additions & 0 deletions colossalai/utils/profiler/mem_profiler.py
@@ -0,0 +1,50 @@
from pathlib import Path
from typing import Union
from colossalai.engine import Engine
from torch.utils.tensorboard import SummaryWriter
from colossalai.engine.ophooks import MemTracerOpHook
from colossalai.utils.profiler import BaseProfiler


class MemProfiler(BaseProfiler):
"""Wraper of MemOpHook, used to show GPU memory usage through each iteration
To use this profiler, you need to pass an `engine` instance. And the usage is same like
CommProfiler.
mm_prof = MemProfiler(engine)
with ProfilerContext([mm_prof]) as prof:
writer = SummaryWriter("mem")
engine.train()
...
prof.to_file("./log")
prof.to_tensorboard(writer)
"""

def __init__(self, engine: Engine, warmup: int = 50, refreshrate: int = 10) -> None:
super().__init__(profiler_name="MemoryProfiler", priority=0)
self._mem_tracer = MemTracerOpHook(warmup=warmup, refreshrate=refreshrate)
self._engine = engine

def enable(self) -> None:
self._engine.add_hook(self._mem_tracer)

def disable(self) -> None:
self._engine.remove_hook(self._mem_tracer)

def to_tensorboard(self, writer: SummaryWriter) -> None:
stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
# enumerate yields (index, value): log each memory sample against its step index
for i, mem_usage in enumerate(stats):
writer.add_scalar("memory_usage/GPU", mem_usage, i)

def to_file(self, data_file: Path) -> None:
self._mem_tracer.save_results(data_file)

def show(self) -> None:
stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
print(stats)
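An end-to-end sketch based on the class docstring above; the engine, writer directory and output path are placeholders, and ProfilerContext is assumed to call enable()/disable() around the block as it does for the other profilers.

# Minimal sketch following the MemProfiler docstring; engine and paths are placeholders.
from torch.utils.tensorboard import SummaryWriter
from colossalai.utils.profiler import MemProfiler, ProfilerContext

mm_prof = MemProfiler(engine)
writer = SummaryWriter("mem")

with ProfilerContext([mm_prof]) as prof:
    engine.train()
    # ... run training iterations here ...

prof.to_file("./log")          # delegates to MemTracerOpHook.save_results
prof.to_tensorboard(writer)    # one scalar per recorded memory sample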
