[utils] add synchronized cuda memory monitor (hpcaitech#740)
1SAA authored Apr 13, 2022
1 parent e6212f5 commit 340e59f
Showing 4 changed files with 149 additions and 110 deletions.
8 changes: 4 additions & 4 deletions colossalai/trainer/hooks/_mem_tracer_hook.py
@@ -1,16 +1,16 @@
from cgitb import Hook
from colossalai.registry import HOOKS
from torch import Tensor
from colossalai.trainer.hooks import BaseHook
from colossalai.utils.memory_tracer import AsyncMemoryMonitor
from ._metric_hook import LearningRateMetric, MetricHook


@HOOKS.register_module
class MemTraceHook(BaseHook):
    """Save memory stats and pass them to trainer states.
    This hook is used to record memory usage info and pass it to trainer.states.
    You can use it like other trainer hooks and fetch data from trainer.states['metrics'][mode].
    """

    def __init__(
        self,
        priority: int = 0,
@@ -36,9 +36,9 @@ def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor)
    def before_test_iter(self, trainer):
        self._memory_monitor.start()
        return super().before_test(trainer)

    def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
        self._memory_monitor.finish()
        trainer.states['metrics']['train'] = self._memory_monitor.state_dict
        trainer.states['metrics']['test'] = self._memory_monitor.state_dict
        return super().after_test_iter(trainer, output, label, loss)
        return super().after_test_iter(trainer, output, label, loss)
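
For context, here is a minimal sketch (not part of this commit) of the bookkeeping the hook performs around a test iteration: start the monitor before the forward pass, finish it afterwards, and store the monitor's stats under trainer.states['metrics'][mode]. The FakeTrainer class, the model, and the tensor shapes below are hypothetical stand-ins used only to illustrate the data flow; a CUDA device is assumed.

import torch
from colossalai.utils.memory_tracer import AsyncMemoryMonitor


class FakeTrainer:
    # Hypothetical stand-in for colossalai's Trainer: only the `states` dict
    # that MemTraceHook writes into is modelled here.
    def __init__(self):
        self.states = {'metrics': {'train': {}, 'test': {}}}


trainer = FakeTrainer()
monitor = AsyncMemoryMonitor()

model = torch.nn.Linear(1024, 1024).cuda()    # arbitrary example workload
data = torch.randn(64, 1024, device='cuda')

monitor.start()                               # what before_test_iter does
output = model(data)                          # the monitored forward pass
monitor.finish()                              # what after_test_iter does
trainer.states['metrics']['test'] = monitor.state_dict()
print(trainer.states['metrics']['test'])      # {'time_stamps': [...], 'mem_stats': [...]}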
4 changes: 2 additions & 2 deletions colossalai/utils/memory_tracer/__init__.py
@@ -1,4 +1,4 @@
from .async_memtracer import AsyncMemoryMonitor
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
from .memstats_collector import MemStatsCollector

__all__ = ['AsyncMemoryMonitor', 'MemStatsCollector']
__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector']
245 changes: 142 additions & 103 deletions colossalai/utils/memory_tracer/{async_memtracer.py → memory_monitor.py}
@@ -1,103 +1,142 @@
from concurrent.futures import ThreadPoolExecutor
from time import sleep, time
import pickle

import torch

from colossalai.utils.memory import colo_device_memory_used
from colossalai.utils import get_current_device


class AsyncMemoryMonitor:
"""
An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
at interval of `1/(10**power)` sec.
The idea comes from Runtime Memory Tracer of PatrickStar
`PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
Usage::
async_mem_monitor = AsyncMemoryMonitor()
input = torch.randn(2, 20).cuda()
OP1 = torch.nn.Linear(20, 30).cuda()
OP2 = torch.nn.Linear(30, 40).cuda()
async_mem_monitor.start()
output = OP1(input)
async_mem_monitor.finish()
async_mem_monitor.start()
output = OP2(output)
async_mem_monitor.finish()
async_mem_monitor.save('log.pkl')
Args:
power (int, optional): the power of time interva. Defaults to 10.
.. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
https://arxiv.org/abs/2108.05818
"""

    def __init__(self, power: int = 10):
        self.keep_measuring = False

        current_device = get_current_device()

        def _set_cuda_device():
            torch.cuda.set_device(current_device)

        self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
        self.monitor_thread = None
        self.interval = 1 / (10**power)
        self.time_stamps = []
        self.mem_stats = []

    def __len__(self):
        return len(self.mem_stats)

    def set_interval(self, power: int):
        self.clear()
        self.interval = 1 / (10**power)

    def is_measuring(self):
        return self.keep_measuring

    def start(self):
        self.keep_measuring = True
        self.monitor_thread = self.executor.submit(self._measure_usage)

    def finish(self):
        if self.keep_measuring is False:
            return 0
        self.keep_measuring = False
        max_usage = self.monitor_thread.result()
        self.monitor_thread = None
        self.time_stamps.append(time())
        self.mem_stats.append(max_usage)
        return max_usage

    def _measure_usage(self):
        max_usage = 0
        while self.keep_measuring:
            max_usage = max(
                max_usage,
                colo_device_memory_used(get_current_device()),
            )
            sleep(self.interval)
        return max_usage

    @property
    def state_dict(self):
        return {
            "time_stamps": self.time_stamps,
            "mem_stats": self.mem_stats,
        }

    def save(self, filename):
        with open(filename, "wb") as f:
            pickle.dump(self.state_dict(), f)

    def clear(self):
        self.mem_stats.clear()
        self.time_stamps.clear()
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from time import sleep, time
import json

import torch

from colossalai.utils.memory import colo_device_memory_used
from colossalai.utils import get_current_device


class MemoryMonitor:
"""Base class for all types of memory monitor.
All monitors should have a list called `time_stamps` and a list called `mem_stats`.
"""

def __init__(self):
self.time_stamps = []
self.mem_stats = []

def __len__(self):
return len(self.mem_stats)

@abstractmethod
def start(self):
pass

@abstractmethod
def finish(self):
pass

def state_dict(self):
return {
"time_stamps": self.time_stamps,
"mem_stats": self.mem_stats,
}

def save(self, filename):
with open(filename, "w") as f:
json.dump(self.state_dict(), f)

def clear(self):
self.mem_stats.clear()
self.time_stamps.clear()


class AsyncMemoryMonitor(MemoryMonitor):
"""
An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
at interval of `1/(10**power)` sec.
The idea comes from Runtime Memory Tracer of PatrickStar
`PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
Usage::
async_mem_monitor = AsyncMemoryMonitor()
input = torch.randn(2, 20).cuda()
OP1 = torch.nn.Linear(20, 30).cuda()
OP2 = torch.nn.Linear(30, 40).cuda()
async_mem_monitor.start()
output = OP1(input)
async_mem_monitor.finish()
async_mem_monitor.start()
output = OP2(output)
async_mem_monitor.finish()
async_mem_monitor.save('log.pkl')
Args:
power (int, optional): the power of time interva. Defaults to 10.
.. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
https://arxiv.org/abs/2108.05818
"""

def __init__(self, power: int = 10):
super().__init__()
self.keep_measuring = False

current_device = get_current_device()

def _set_cuda_device():
torch.cuda.set_device(current_device)

self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
self.monitor_thread = None
self.interval = 1 / (10**power)

def set_interval(self, power: int):
self.clear()
self.interval = 1 / (10**power)

def is_measuring(self):
return self.keep_measuring

def start(self):
self.keep_measuring = True
self.monitor_thread = self.executor.submit(self._measure_usage)

def finish(self):
if self.keep_measuring is False:
return 0

self.keep_measuring = False
max_usage = self.monitor_thread.result()

self.monitor_thread = None
self.time_stamps.append(time())
self.mem_stats.append(max_usage)
return max_usage

def _measure_usage(self):
max_usage = 0
while self.keep_measuring:
max_usage = max(
max_usage,
colo_device_memory_used(get_current_device()),
)
sleep(self.interval)
return max_usage


class SyncCudaMemoryMonitor(MemoryMonitor):
"""
A synchronized cuda memory monitor.
It only record the maximum allocated cuda memory from start point to finish point.
"""

def __init__(self, power: int = 10):
super().__init__()

def start(self):
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()

def finish(self):
torch.cuda.synchronize()
self.time_stamps.append(time())
max_usage = torch.cuda.max_memory_allocated()
self.mem_stats.append(max_usage)
return max_usage
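
A minimal usage sketch for the new SyncCudaMemoryMonitor follows (not part of this commit; the model, tensor shapes, and file name are arbitrary examples and a CUDA device is assumed). start() synchronizes and resets the allocator's peak-memory counter, and finish() synchronizes again and records torch.cuda.max_memory_allocated(), so the returned value is the peak number of bytes allocated between the two calls.

import torch
from colossalai.utils.memory_tracer import SyncCudaMemoryMonitor

monitor = SyncCudaMemoryMonitor()
model = torch.nn.Linear(2048, 2048).cuda()    # arbitrary example workload
x = torch.randn(32, 2048, device='cuda')

monitor.start()                               # synchronize + reset peak stats
y = model(x)
peak_bytes = monitor.finish()                 # synchronize + max_memory_allocated()
print(f"peak allocated: {peak_bytes / 1024**2:.1f} MiB")
monitor.save('sync_mem_log.json')             # state_dict() is JSON-serializable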
2 changes: 1 addition & 1 deletion colossalai/utils/memory_tracer/memstats_collector.py
@@ -1,6 +1,6 @@
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_device_memory_used
from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
from colossalai.utils.memory_tracer import AsyncMemoryMonitor
import torch
import time
from typing import List
