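"""
Autologging integration for PaddlePaddle's high-level ``paddle.Model`` API.

Defines the callback and the patched ``fit`` entry point that ``mlflow.paddle.autolog()``
uses to log metrics, parameters, early-stopping information, the model summary, and
(optionally) the trained model to MLflow.
"""
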
import paddle

import mlflow
from mlflow.utils.autologging_utils import (
    BatchMetricsLogger,
    ExceptionSafeAbstractClass,
    MlflowAutologgingQueueingClient,
    get_autologging_config,
)


class __MLflowPaddleCallback(paddle.callbacks.Callback, metaclass=ExceptionSafeAbstractClass):
    """
    Callback for auto-logging metrics and parameters.
    """

    def __init__(self, client, metrics_logger, run_id, log_models, log_every_n_epoch):
        super().__init__()
        self.early_stopping = False
        self.client = client
        self.metrics_logger = metrics_logger
        self.run_id = run_id
        self.log_models = log_models
        self.log_every_n_epoch = log_every_n_epoch
        self.epoch = 0

    def _log_metrics(self, logs, current_epoch):
        # Paddle reports some metrics as single-element lists; unwrap them so plain
        # numeric values are recorded.
        metrics = {
            key: (metric[0] if isinstance(metric, list) else metric) for key, metric in logs.items()
        }
        self.metrics_logger.record_metrics(metrics, current_epoch)

    def on_epoch_end(self, epoch, logs=None):
        if self.model is not None and epoch % self.log_every_n_epoch == 0:
            self._log_metrics(logs, epoch)
            self.epoch = epoch

    def on_train_begin(self, logs=None):
        params = {
            "optimizer_name": self.model._optimizer.__class__.__name__,
            "learning_rate": self.model._optimizer._learning_rate,
        }
        self.client.log_params(self.run_id, params)
        self.client.flush(synchronous=True)

    def on_train_end(self, logs=None):
        self.metrics_logger.flush()
        self.client.flush(synchronous=True)

    def on_eval_end(self, logs=None):
        # Prefix evaluation metrics with "eval_" so they do not collide with the
        # training metrics logged in on_epoch_end.
        eval_logs = {
            "eval_" + key: (metric[0] if isinstance(metric, list) else metric)
            for key, metric in logs.items()
        }
        self._log_metrics(eval_logs, self.epoch)


def _log_early_stop_params(early_stop_callback, client, run_id):
    """
    Logs early stopping configuration parameters to MLflow.

    :param early_stop_callback: The early stopping callback instance used during training.
    :param client: An `MlflowAutologgingQueueingClient` instance used for MLflow logging.
    :param run_id: The ID of the MLflow Run to which to log configuration parameters.
    """
    client.log_params(
        run_id,
        {
            p: getattr(early_stop_callback, p)
            for p in ["monitor", "patience", "min_delta", "baseline"]
            if hasattr(early_stop_callback, p)
        },
    )


def _log_early_stop_metrics(early_stop_callback, client, run_id):
    """
    Logs early stopping behavior results (e.g. stopped epoch) as metrics to MLflow.

    :param early_stop_callback: The early stopping callback instance used during training.
    :param client: An `MlflowAutologgingQueueingClient` instance used for MLflow logging.
    :param run_id: The ID of the MLflow Run to which to log early stopping metrics.
    """
    if early_stop_callback.stopped_epoch == 0:
        # A stopped_epoch of 0 indicates early stopping never triggered, so there is
        # nothing to log.
        return
    metrics = {
        "stopped_epoch": early_stop_callback.stopped_epoch,
        "best_value": early_stop_callback.best_value,
    }
    client.log_metrics(run_id, metrics)
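

# `patched_fit` wraps `paddle.Model.fit` when autologging is enabled: it queues
# parameter/metric logging through MlflowAutologgingQueueingClient and BatchMetricsLogger,
# injects __MLflowPaddleCallback into the caller's callback list, records early-stopping
# details when an EarlyStopping callback is present, and logs the model summary (and,
# if enabled, the model itself) after training completes.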
def patched_fit(original, self, *args, **kwargs):
    run_id = mlflow.active_run().info.run_id
    tracking_uri = mlflow.get_tracking_uri()
    client = MlflowAutologgingQueueingClient(tracking_uri)
    metrics_logger = BatchMetricsLogger(run_id, tracking_uri)

    log_models = get_autologging_config(mlflow.paddle.FLAVOR_NAME, "log_models", True)
    log_every_n_epoch = get_autologging_config(mlflow.paddle.FLAVOR_NAME, "log_every_n_epoch", 1)

    early_stop_callback = None
    mlflow_callback = __MLflowPaddleCallback(
        client, metrics_logger, run_id, log_models, log_every_n_epoch
    )

    # Inject the MLflow callback into the user-supplied callback list, capturing an
    # EarlyStopping callback if one is present so its configuration can be logged.
    if "callbacks" in kwargs:
        callbacks = kwargs["callbacks"]
        for callback in callbacks:
            if isinstance(callback, paddle.callbacks.EarlyStopping):
                early_stop_callback = callback
                _log_early_stop_params(early_stop_callback, client, run_id)
                break
        kwargs["callbacks"].append(mlflow_callback)
    else:
        kwargs["callbacks"] = [mlflow_callback]

    # Flush queued parameter logging without blocking the start of training.
    client.flush(synchronous=False)

    result = original(self, *args, **kwargs)

    if early_stop_callback is not None:
        _log_early_stop_metrics(early_stop_callback, client, run_id)

    # Log the model summary as a text artifact of the active run.
    mlflow.log_text(str(self.summary()), "model_summary.txt")

    if log_models:
        registered_model_name = get_autologging_config(
            mlflow.paddle.FLAVOR_NAME, "registered_model_name", None
        )
        mlflow.paddle.log_model(
            pd_model=self, artifact_path="model", registered_model_name=registered_model_name
        )

    client.flush(synchronous=True)

    return result
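

# Illustrative usage (a sketch, not part of this module): with autologging enabled,
# `paddle.Model.fit` is routed through `patched_fit` above. The network, optimizer, and
# datasets below are placeholders assumed for demonstration only.
#
#   import paddle
#   import mlflow
#
#   mlflow.paddle.autolog(log_every_n_epoch=1, log_models=True)
#
#   model = paddle.Model(network)  # `network` is any paddle.nn.Layer
#   model.prepare(optimizer=optimizer, loss=paddle.nn.CrossEntropyLoss(),
#                 metrics=paddle.metric.Accuracy())
#   with mlflow.start_run():
#       model.fit(train_dataset, eval_dataset, epochs=5, batch_size=64)
#   # Metrics, parameters, the model summary, and the model are logged automatically.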