tensorboard_hook.py (forked from facebookresearch/vissl)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging

import torch
from classy_vision import tasks
from classy_vision.generic.distributed_util import is_primary
from classy_vision.hooks.classy_hook import ClassyHook


try:
    from torch.utils.tensorboard import SummaryWriter  # noqa F401

    tb_available = True
except ImportError:
    # Make sure that the type hint is not blocking
    # on a non-TensorBoard aware platform
    from typing import TypeVar

    SummaryWriter = TypeVar("SummaryWriter")
    tb_available = False
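
# 1 MiB = 2 ** 20 bytes: GPU memory scalars below are reported in mebibytes.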
BYTE_TO_MiB = 2 ** 20


class SSLTensorboardHook(ClassyHook):
    """
    SSL Specific variant of the Classy Vision tensorboard hook
    """

    on_loss_and_meter = ClassyHook._noop
    on_backward = ClassyHook._noop
    on_start = ClassyHook._noop
    on_end = ClassyHook._noop
    on_step = ClassyHook._noop

    def __init__(
        self,
        tb_writer: SummaryWriter,
        log_params: bool = False,
        log_params_every_n_iterations: int = -1,
        log_params_gradients: bool = False,
    ) -> None:
        """The constructor method of SSLTensorboardHook.

        Args:
            tb_writer: `Tensorboard SummaryWriter <https://tensorboardx.
                        readthedocs.io/en/latest/tensorboard.html#tensorboardX.
                        SummaryWriter>`_ instance
            log_params (bool): whether to log model params to tensorboard
            log_params_every_n_iterations (int): frequency at which parameters
                        should be logged to tensorboard
            log_params_gradients (bool): whether to log params gradients as well
                        to tensorboard.
        """
        super().__init__()
        if not tb_available:
            raise RuntimeError(
                "tensorboard not installed, cannot use SSLTensorboardHook"
            )
        logging.info("Setting up SSL Tensorboard Hook...")
        self.tb_writer = tb_writer
        self.log_params = log_params
        self.log_params_every_n_iterations = log_params_every_n_iterations
        self.log_params_gradients = log_params_gradients
        logging.info(
            f"Tensorboard config: log_params: {self.log_params}, "
            f"log_params_freq: {self.log_params_every_n_iterations}, "
            f"log_params_gradients: {self.log_params_gradients}"
        )
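
    # Illustrative construction sketch (hypothetical values, not in the original
    # file): the hook only needs a torch SummaryWriter; in VISSL the writer and
    # the log_params* flags are normally derived from the training config rather
    # than hard-coded as below.
    #
    #   from torch.utils.tensorboard import SummaryWriter
    #   hook = SSLTensorboardHook(
    #       tb_writer=SummaryWriter(log_dir="/tmp/tb_logs"),
    #       log_params=True,
    #       log_params_every_n_iterations=100,
    #       log_params_gradients=True,
    #   )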

    def on_forward(self, task: "tasks.ClassyTask") -> None:
        """
        Called after every forward if tensorboard hook is enabled.
        Logs the model parameters if the training iteration matches the
        logging frequency.
        """
        if not self.log_params:
            return

        if (
            self.log_params_every_n_iterations > 0
            and is_primary()
            and task.train
            and task.iteration % self.log_params_every_n_iterations == 0
        ):
            for name, parameter in task.base_model.named_parameters():
                self.tb_writer.add_histogram(
                    f"Parameters/{name}", parameter, global_step=task.iteration
                )

    def on_phase_start(self, task: "tasks.ClassyTask") -> None:
        """
        Called at the start of every epoch if the tensorboard hook is enabled.
        Logs the model parameters once at the beginning of training only.
        """
        if not self.log_params:
            return

        # log the parameters just once, before training starts
        if is_primary() and task.train and task.train_phase_idx == 0:
            for name, parameter in task.base_model.named_parameters():
                self.tb_writer.add_histogram(
                    f"Parameters/{name}", parameter, global_step=-1
                )

    def on_phase_end(self, task: "tasks.ClassyTask") -> None:
        """
        Called at the end of every epoch if the tensorboard hook is enabled.
        Logs model parameters and/or parameter gradients as configured by the
        user in the tensorboard configuration. Also resets the CUDA memory
        counters.
        """
        # Log train/test accuracy
        if is_primary():
            phase_type = "Training" if task.train else "Testing"
            for meter in task.meters:
                if "accuracy" in meter.name:
                    for top_n, accuracies in meter.value.items():
                        for i, acc in accuracies.items():
                            tag_name = f"{phase_type}/Accuracy_{top_n}_Output_{i}"
                            self.tb_writer.add_scalar(
                                tag=tag_name,
                                scalar_value=round(acc, 5),
                                global_step=task.train_phase_idx,
                            )
        if not (self.log_params or self.log_params_gradients):
            return

        if is_primary() and task.train:
            # Log the weights and biases at the end of the epoch
            if self.log_params:
                for name, parameter in task.base_model.named_parameters():
                    self.tb_writer.add_histogram(
                        f"Parameters/{name}",
                        parameter,
                        global_step=task.train_phase_idx,
                    )

            # Log the parameter gradients at the end of the epoch
            if self.log_params_gradients:
                for name, parameter in task.base_model.named_parameters():
                    if parameter.grad is not None:
                        try:
                            self.tb_writer.add_histogram(
                                f"Gradients/{name}",
                                parameter.grad,
                                global_step=task.train_phase_idx,
                            )
                        except ValueError:
                            logging.info(
                                f"Gradient histogram empty for {name}, "
                                f"iteration {task.iteration}. Unable to "
                                f"log gradient."
                            )

        # Reset the GPU memory counters
        if torch.cuda.is_available():
            torch.cuda.reset_max_memory_allocated()
            torch.cuda.reset_max_memory_cached()
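            # NOTE: reset_max_memory_cached() is deprecated in recent PyTorch
            # releases; torch.cuda.reset_peak_memory_stats() is the usual
            # replacement when this hook is updated for newer versions.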

    def on_update(self, task: "tasks.ClassyTask") -> None:
        """
        Called after every parameter update if the tensorboard hook is enabled.
        Logs the parameter gradients if they are set to be logged, and logs
        scalars such as the training loss, learning rate, average training
        iteration time, batch size per GPU, images per second per GPU, ETA,
        GPU memory used and peak GPU memory used.
        """
        if not is_primary():
            return

        iteration = task.iteration

        if (
            self.log_params_every_n_iterations > 0
            and self.log_params_gradients
            and task.train
            and iteration % self.log_params_every_n_iterations == 0
        ):
            logging.info(f"Logging Parameter gradients. Iteration {iteration}")
            for name, parameter in task.base_model.named_parameters():
                if parameter.grad is not None:
                    try:
                        self.tb_writer.add_histogram(
                            f"Gradients/{name}",
                            parameter.grad,
                            global_step=task.iteration,
                        )
                    except ValueError:
                        logging.info(
                            f"Gradient histogram empty for {name}, "
                            f"iteration {task.iteration}. Unable to "
                            f"log gradient."
                        )
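
        # Scalar metrics are written every LOG_FREQUENCY iterations, and also
        # every 5th iteration during the first 100 to give early feedback.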
        if iteration % task.config["LOG_FREQUENCY"] == 0 or (
            iteration <= 100 and iteration % 5 == 0
        ):
            logging.info(f"Logging metrics. Iteration {iteration}")
            self.tb_writer.add_scalar(
                tag="Training/Loss",
                scalar_value=round(task.last_batch.loss.data.cpu().item(), 5),
                global_step=iteration,
            )

            self.tb_writer.add_scalar(
                tag="Training/Learning_rate",
                scalar_value=round(task.optimizer.options_view.lr, 5),
                global_step=iteration,
            )

            # Batch processing time
            if len(task.batch_time) > 0:
                batch_times = task.batch_time
            else:
                batch_times = [0]
            batch_time_avg_s = sum(batch_times) / max(len(batch_times), 1)
            self.tb_writer.add_scalar(
                tag="Speed/Batch_processing_time_ms",
                scalar_value=int(1000.0 * batch_time_avg_s),
                global_step=iteration,
            )

            # Images per second per replica
            pic_per_batch_per_gpu = task.config["DATA"]["TRAIN"][
                "BATCHSIZE_PER_REPLICA"
            ]
            pic_per_batch_per_gpu_per_sec = (
                int(pic_per_batch_per_gpu / batch_time_avg_s)
                if batch_time_avg_s > 0
                else 0.0
            )
            self.tb_writer.add_scalar(
                tag="Speed/img_per_sec_per_gpu",
                scalar_value=pic_per_batch_per_gpu_per_sec,
                global_step=iteration,
            )

            # ETA
            avg_time = sum(batch_times) / len(batch_times)
            eta_secs = avg_time * (task.max_iteration - iteration)
            self.tb_writer.add_scalar(
                tag="Speed/ETA_hours",
                scalar_value=eta_secs / 3600.0,
                global_step=iteration,
            )

            # GPU Memory
            if torch.cuda.is_available():
                # Memory actually being used
                self.tb_writer.add_scalar(
                    tag="Memory/Peak_GPU_Memory_allocated_MiB",
                    scalar_value=torch.cuda.max_memory_allocated() / BYTE_TO_MiB,
                    global_step=iteration,
                )

                # Memory reserved by PyTorch's memory allocator
                self.tb_writer.add_scalar(
                    tag="Memory/Peak_GPU_Memory_reserved_MiB",
                    scalar_value=torch.cuda.max_memory_reserved()
                    / BYTE_TO_MiB,  # byte to MiB
                    global_step=iteration,
                )

                self.tb_writer.add_scalar(
                    tag="Memory/Current_GPU_Memory_reserved_MiB",
                    scalar_value=torch.cuda.memory_reserved()
                    / BYTE_TO_MiB,  # byte to MiB
                    global_step=iteration,
                )
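

if __name__ == "__main__":
    # Illustrative smoke test (a minimal sketch, not part of the VISSL training
    # flow). It does not build a ClassyVision task; it only writes, with a plain
    # SummaryWriter, the kind of scalars and histograms this hook produces, so
    # the resulting event file can be inspected with
    # `tensorboard --logdir /tmp/ssl_tb_hook_demo`. The log directory and the
    # dummy values below are assumptions made for the demo only.
    from torch.utils.tensorboard import SummaryWriter

    demo_writer = SummaryWriter(log_dir="/tmp/ssl_tb_hook_demo")
    for step in range(10):
        # Scalars analogous to what on_update() logs every LOG_FREQUENCY steps.
        demo_writer.add_scalar(
            tag="Training/Loss", scalar_value=1.0 / (step + 1), global_step=step
        )
        demo_writer.add_scalar(
            tag="Training/Learning_rate", scalar_value=0.1, global_step=step
        )
        # Histograms analogous to what on_forward() / on_phase_end() log for
        # every named parameter of the model.
        demo_writer.add_histogram(
            "Parameters/demo_weight", torch.randn(256), global_step=step
        )
    demo_writer.close()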