-
Notifications
You must be signed in to change notification settings - Fork 123
/
base_dash_logger.py
89 lines (66 loc) · 2.66 KB
/
base_dash_logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from abc import ABC, abstractmethod
from typing import Dict, Union
from trainer.io import save_fsspec
from trainer.utils.distributed import rank_zero_only
# pylint: disable=too-many-public-methods
class BaseDashboardLogger(ABC):
    """Abstract interface shared by dashboard logging backends.

    Subclasses must implement every ``@abstractmethod`` below (the ``add_*``
    family plus ``flush`` and ``finish``). The concrete ``train_*``,
    ``eval_*`` and ``test_*`` helpers at the bottom only forward to
    ``add_scalars`` / ``add_figures`` / ``add_audios`` with a fixed scope
    name, so a backend gets them for free once the generic methods exist.
    """

    # ------------------------------------------------------------------
    # Backend-specific primitives (must be implemented by subclasses).
    # ------------------------------------------------------------------

    @abstractmethod
    def add_scalar(self, title: str, value: float, step: int) -> None:
        """Log one scalar ``value`` under ``title`` at training ``step``."""

    @abstractmethod
    def add_figure(
        self,
        title: str,
        figure: Union["matplotlib.figure.Figure", "plotly.graph_objects.Figure"],
        step: int,
    ) -> None:
        """Log a matplotlib or plotly ``figure`` under ``title`` at ``step``.

        The figure types are string forward references; neither library is
        imported by this module.
        """

    @abstractmethod
    def add_config(self, config) -> None:
        """Log the run configuration.

        The type of ``config`` is not constrained here; each backend decides
        what it accepts (NOTE(review): confirm expected config type with the
        concrete loggers).
        """

    @abstractmethod
    def add_audio(self, title: str, audio: "np.ndarray", step: int, sample_rate: int) -> None:
        """Log one audio clip under ``title`` at ``step``.

        Args:
            title: Name to log the clip under.
            audio: Audio samples (annotated as a NumPy array).
            step: Training step the clip belongs to.
            sample_rate: Sample rate of ``audio`` in Hz.
        """

    @abstractmethod
    def add_text(self, title: str, text: str, step: int) -> None:
        """Log a text string under ``title`` at ``step``."""

    @abstractmethod
    def add_artifact(self, file_or_dir: str, name: str, artifact_type: str, aliases=None) -> None:
        """Log a file or directory as a named artifact.

        Args:
            file_or_dir: Path of the file or directory to upload/record.
            name: Artifact name.
            artifact_type: Backend-specific artifact category.
            aliases: Optional extra labels for the artifact; ``None`` means
                none were given (presumably a list of strings — confirm
                against the concrete loggers).
        """

    @abstractmethod
    def add_scalars(self, scope_name: str, scalars: Dict, step: int) -> None:
        """Log every entry of ``scalars`` under the ``scope_name`` group.

        ``scalars`` is presumably a ``{name: value}`` mapping — confirm
        against the concrete loggers.
        """

    @abstractmethod
    def add_figures(self, scope_name: str, figures: Dict, step: int) -> None:
        """Log every figure in ``figures`` under the ``scope_name`` group."""

    @abstractmethod
    def add_audios(self, scope_name: str, audios: Dict, step: int, sample_rate: int) -> None:
        """Log every clip in ``audios`` under ``scope_name`` at ``sample_rate``."""

    @abstractmethod
    def flush(self) -> None:
        """Push any buffered log data to the backend."""

    @abstractmethod
    def finish(self) -> None:
        """Finalize the logging session; called when logging is done."""

    # ------------------------------------------------------------------
    # Shared helpers.
    # ------------------------------------------------------------------

    @staticmethod
    @rank_zero_only
    def save_model(state: Dict, path: str) -> None:
        """Save the ``state`` dict to ``path`` via ``save_fsspec``.

        The call is wrapped by ``rank_zero_only`` (from
        ``trainer.utils.distributed``), which presumably makes it a no-op on
        every process except rank 0 in distributed runs, so the checkpoint
        is written once.
        """
        save_fsspec(state, path)

    # Convenience wrappers: each one only chooses the scope name below and
    # forwards to the matching generic ``add_*`` method. Note their argument
    # order is ``(step, data)``, the reverse of the ``add_*`` methods.

    def train_step_stats(self, step: int, stats: Dict) -> None:
        """Log per-iteration training stats under ``TrainIterStats``."""
        self.add_scalars(scope_name="TrainIterStats", scalars=stats, step=step)

    def train_epoch_stats(self, step: int, stats: Dict) -> None:
        """Log per-epoch training stats under ``TrainEpochStats``."""
        self.add_scalars(scope_name="TrainEpochStats", scalars=stats, step=step)

    def train_figures(self, step: int, figures: Dict) -> None:
        """Log training figures under ``TrainFigures``."""
        self.add_figures(scope_name="TrainFigures", figures=figures, step=step)

    def train_audios(self, step: int, audios: Dict, sample_rate: int) -> None:
        """Log training audio clips under ``TrainAudios``."""
        self.add_audios(scope_name="TrainAudios", audios=audios, step=step, sample_rate=sample_rate)

    def eval_stats(self, step: int, stats: Dict) -> None:
        """Log evaluation stats under ``EvalStats``."""
        self.add_scalars(scope_name="EvalStats", scalars=stats, step=step)

    def eval_figures(self, step: int, figures: Dict) -> None:
        """Log evaluation figures under ``EvalFigures``."""
        self.add_figures(scope_name="EvalFigures", figures=figures, step=step)

    def eval_audios(self, step: int, audios: Dict, sample_rate: int) -> None:
        """Log evaluation audio clips under ``EvalAudios``."""
        self.add_audios(scope_name="EvalAudios", audios=audios, step=step, sample_rate=sample_rate)

    def test_audios(self, step: int, audios: Dict, sample_rate: int) -> None:
        """Log test audio clips under ``TestAudios``."""
        self.add_audios(scope_name="TestAudios", audios=audios, step=step, sample_rate=sample_rate)

    def test_figures(self, step: int, figures: Dict) -> None:
        """Log test figures under ``TestFigures``."""
        self.add_figures(scope_name="TestFigures", figures=figures, step=step)