Skip to content

Commit

Permalink
fix evaluator compile before each epoch (mindspore-lab#191)
Browse files Browse the repository at this point in the history
  • Loading branch information
lvyufeng authored Dec 30, 2022
1 parent 7b7148a commit cd86d63
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 197 deletions.
56 changes: 23 additions & 33 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,36 +70,11 @@ from mindspore import ops
from mindspore.common.initializer import Uniform, HeUniform
from mindnlp.abc import Seq2vecModel

class Head(nn.Cell):
    """
    Classification head for the sentiment model.

    Concatenates the last forward/backward hidden states of a bidirectional
    RNN, applies dropout, and projects to a sigmoid probability.
    """
    def __init__(self, hidden_dim, output_dim, dropout):
        super().__init__()
        in_features = hidden_dim * 2  # forward + backward hidden states are concatenated
        # NOTE: this MindSpore API takes keep_prob, hence 1 - dropout — TODO confirm version
        self.dropout = nn.Dropout(1 - dropout)
        self.fc = nn.Dense(
            in_features,
            output_dim,
            weight_init=HeUniform(math.sqrt(5)),
            bias_init=Uniform(1 / math.sqrt(in_features)),
        )
        self.sigmoid = nn.Sigmoid()

    def construct(self, context):
        """Map stacked hidden states `context` to a class probability."""
        # context[-2] / context[-1]: presumably last layer's forward and
        # backward directions — verify against the encoder's output layout.
        pooled = ops.concat((context[-2, :, :], context[-1, :, :]), axis=1)
        return self.sigmoid(self.fc(self.dropout(pooled)))


class SentimentClassification(Seq2vecModel):
    """
    Sentiment Classification model.

    Wraps an RNN encoder and a classification head: the encoder's final
    forward/backward hidden states are concatenated and fed to the head.
    """
    def __init__(self, encoder, head):
        super().__init__(encoder, head)
        self.encoder = encoder
        self.head = head

    def construct(self, text):
        """Encode `text` and classify from the last hidden states."""
        # Encoder output unpacking assumes an LSTM-style (output, (hidden, cell), ...)
        # return — only the hidden state is needed here.
        _, (hidden, _), _ = self.encoder(text)
        # hidden[-2] / hidden[-1]: last layer's forward and backward directions.
        context = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)
        output = self.head(context)
        return output
```

Expand Down Expand Up @@ -140,18 +115,33 @@ from mindnlp.dataset import process

imdb_train = process('imdb', imdb_train, tokenizer=tokenizer, vocab=vocab, \
bucket_boundaries=[400, 500], max_len=600, drop_remainder=True)
imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])
imdb_test = process('imdb', imdb_test, tokenizer=tokenizer, vocab=vocab, \
bucket_boundaries=[400, 500], max_len=600, drop_remainder=False)
```

### Instantiate Model
```python
from mindnlp.modules import RNNEncoder

# build encoder
lstm_layer = nn.LSTM(100, hidden_size, num_layers=num_layers, batch_first=True,
dropout=drop, bidirectional=bidirectional)
sentiment_encoder = RNNEncoder(embedding, lstm_layer)
sentiment_head = Head(hidden_size, output_size, drop)
net = SentimentClassification(sentiment_encoder, sentiment_head)
dropout=dropout, bidirectional=bidirectional)
encoder = RNNEncoder(embedding, lstm_layer)

# build head
head = nn.SequentialCell([
nn.Dropout(1 - dropout),
nn.Sigmoid(),
nn.Dense(hidden_size * 2, output_size,
weight_init=HeUniform(math.sqrt(5)),
bias_init=Uniform(1 / math.sqrt(hidden_size * 2)))

])

# build network
network = SentimentClassification(encoder, head)
loss = nn.BCELoss(reduction='mean')
optimizer = nn.Adam(network.trainable_params(), learning_rate=lr)
```

### Training Process
Expand All @@ -166,7 +156,7 @@ metric = Accuracy()
# define trainer
trainer = Trainer(network=net, train_dataset=imdb_train, eval_dataset=imdb_valid, metrics=metric,
epochs=5, loss_fn=loss, optimizer=optimizer)
trainer.run(tgt_columns="label", jit=False)
trainer.run(tgt_columns="label")
print("end train")
```

Expand Down
34 changes: 17 additions & 17 deletions mindnlp/engine/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ class Evaluator:
while evaluating. Default:None.
callbacks (Optional[list[Callback], Callback]): List of callback objects which should be executed
while training. Default: None.
jit (bool): Whether use Just-In-Time compile.
"""

def __init__(self, network, eval_dataset=None, metrics=None, callbacks=None):
def __init__(self, network, eval_dataset=None, metrics=None, callbacks=None, jit=False):
self.network = network
self.callbacks = callbacks
self.earlystop = False
Expand All @@ -48,6 +49,16 @@ def __init__(self, network, eval_dataset=None, metrics=None, callbacks=None):
self.total = eval_dataset.get_dataset_size()

self.callback_manager = CallbackManager(callbacks=self.callbacks)
self.eval_func = self._prepare_eval_func(network, jit)

def _prepare_eval_func(self, network, jit):
def _run_step(inputs):
"""Core process of each step."""
outputs = network(*inputs)
return outputs
if jit:
return ms_jit(_run_step)
return _run_step

def _check_metric_type(self, metrics):
"""Check metrics type."""
Expand Down Expand Up @@ -86,41 +97,30 @@ def _check_reuse_dataset(self, dataset):
raise RuntimeError("The dataset object had been used in other model by model.train(...), "
"please create a new dataset.")

def run(self, tgt_columns=None):
    """
    Evaluating function entry.

    Args:
        tgt_columns (Optional[list[str], str]): Target label column names for loss function.
    """
    args_dict = vars(self)
    run_context = RunContext(args_dict)
    self.callback_manager.evaluate_begin(run_context)
    self.clear_metrics()
    # JIT choice is baked into self.eval_func at construction time,
    # so no jit argument is accepted here anymore.
    _ = self._run(tgt_columns)
    self.callback_manager.evaluate_end(run_context)
    # Callbacks may set `earlystop` on the context during evaluation.
    self.earlystop = getattr(run_context, 'earlystop', False)

def _run(self, tgt_columns=None, jit=False):
def _run(self, tgt_columns=None):
"""Evaluating process for non-data sinking mode. The data would be passed to network directly."""
net = self.network

def _run_step(inputs):
"""Core process of each step."""
outputs = net(*inputs)
return outputs

if jit:
_run_step = ms_jit(_run_step)

net.set_train(False)
self.network.set_train(False)
with tqdm(total=self.total) as progress:
progress.set_description('Evaluate')
for data in self.eval_dataset.create_dict_iterator():
inputs, tgts = self._data_process(data, tgt_columns)
outputs = _run_step(inputs)
outputs = self.eval_func(inputs)
self._update_metrics(outputs, *tgts)
progress.update(1)

Expand Down
144 changes: 64 additions & 80 deletions mindnlp/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
"""
from inspect import signature
from tqdm import tqdm
from mindspore import ops
from mindspore import log, mutable
from mindspore.ops import value_and_grad
from mindnlp import ms_jit
Expand Down Expand Up @@ -48,7 +47,7 @@ class Trainer:
to None and implement calculation of loss in `network`,
then a tuple (data1, data2, data3, ...) with all data returned from dataset will be
passed to the `network`.
metrcis (Optional[list[Metrics], Metrics]): List of metrics objects which should be used
metrics (Optional[list[Metrics], Metrics]): List of metrics objects which should be used
while evaluating. Default:None.
epochs (int): Total number of iterations on the data. Default: 10.
optimizer (Cell): Optimizer for updating the weights. If `optimizer` is None, the `network` needs to
Expand All @@ -57,55 +56,82 @@ class Trainer:
and parallel if needed. Default: None.
callbacks (Optional[list[Callback], Callback]): List of callback objects which should be executed
while training. Default: None.
jit (bool): Whether use Just-In-Time compile.
"""

def __init__(self, network=None, train_dataset=None, eval_dataset=None, metrics=None, epochs=10,
loss_fn=None, optimizer=None, callbacks=None):
loss_fn=None, optimizer=None, callbacks=None, jit=False):
self.network = network
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.metrics = metrics
self.epochs = epochs
self.loss_fn = loss_fn
self.optimizer = optimizer
self.callbacks = callbacks

self.cur_epoch_nums = 0
self.cur_step_nums = 0
self.earlystop = False
self.grad_fn = None
if callbacks:
self._prepare_callbacks(callbacks)
self._prepare_eval()
self.callback_manager = CallbackManager(callbacks=self.callbacks)
callbacks = self._prepare_callbacks(callbacks)
self._prepare_eval(eval_dataset, metrics, callbacks, jit)

self.callback_manager = CallbackManager(callbacks)
self.train_fn = self._prepare_train_func(network, loss_fn, optimizer, jit)

def _prepare_train_func(self, network, loss_fn, optimizer, jit):
    """Build the per-step training function (forward, backward, parameter update).

    Args:
        network: The model cell to train.
        loss_fn: Loss function; if None the network is assumed to compute
            its own loss from (inputs, labels).
        optimizer: Optimizer whose `parameters` are differentiated and updated.
        jit (bool): Whether to JIT-compile the step function once up front.

    Returns:
        A callable `step(inputs, labels) -> loss`.
    """
    def forward_fn(inputs, labels):
        """Forward pass through `network`, then loss via `loss_fn`."""
        logits = network(*inputs)
        # Normalize the network output to a tuple so it can be splatted.
        logits_list = logits if isinstance(logits, tuple) else (logits,)
        loss = loss_fn(*logits_list, *labels)
        # has_aux=True expects (loss, aux...) — logits are the auxiliary outputs.
        return (loss,) + logits_list

    def forward_without_loss_fn(inputs, labels):
        """Used when the network computes its own loss internally."""
        return network(*inputs, *labels)

    if loss_fn is None:
        grad_fn = value_and_grad(forward_without_loss_fn, None, optimizer.parameters, has_aux=True)
    else:
        grad_fn = value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    def _run_step(inputs, labels):
        """Core process of each step: forward propagation and back propagation."""
        # Use (loss, *_) — the aux tuple holds one OR MORE logits tensors, so a
        # fixed two-element unpack `(loss, _)` would raise for tuple outputs.
        (loss, *_), grads = grad_fn(inputs, labels)
        optimizer(grads)
        return loss

    return ms_jit(_run_step) if jit else _run_step

def _prepare_callbacks(self, callbacks):
self.callbacks = []
if isinstance(callbacks, Callback):
self.callbacks.append(callbacks)
elif isinstance(callbacks, list):
return [callbacks]
if isinstance(callbacks, list):
if all(isinstance(cb, Callback) for cb in callbacks) is True:
self.callbacks = callbacks
else:
obj = [not isinstance(cb, Callback) for cb in callbacks][0]
raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
else:
raise TypeError(f"Expect callbacks to be list or Callback. Got {type(callbacks)}.")
return callbacks
obj = [not isinstance(cb, Callback) for cb in callbacks][0]
raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
raise TypeError(f"Expect callbacks to be list or Callback. Got {type(callbacks)}.")

def _check_callbacks_type(self):
for callback in self.callbacks:
def _check_callbacks_type(self, callbacks):
for callback in callbacks:
if isinstance(callback, EarlyStopCallback):
raise ValueError("EarlyStopCallback is not effective when eval_dataset is None.")
if isinstance(callback, BestModelCallback):
raise ValueError("BestModelCallback is not effective when eval_dataset is None.")

def _prepare_eval(self):
if self.eval_dataset is not None and self.metrics is not None:
self.evaluator = Evaluator(network=self.network, eval_dataset=self.eval_dataset, metrics=self.metrics,
callbacks=self.callbacks)
elif self.eval_dataset is None and self.metrics is None:
if self.callbacks:
self._check_callbacks_type()
def _prepare_eval(self, eval_dataset, metrics, callbacks, jit):
if eval_dataset is not None and metrics is not None:
self.evaluator = Evaluator(network=self.network, eval_dataset=eval_dataset, metrics=metrics,
callbacks=callbacks, jit=jit)
elif eval_dataset is None and metrics is None:
if callbacks:
self._check_callbacks_type(callbacks)
self.evaluator = None
else:
raise ValueError("For evaluation in training process, both eval dataset and metrics should be not None.")
Expand All @@ -130,69 +156,30 @@ def _check_reuse_dataset(self, dataset):
raise RuntimeError("The dataset object had been used in other model by model.train(...), "
"please create a new dataset.")

def run(self, tgt_columns=None):
    """
    Training process entry.

    Args:
        tgt_columns (Optional[list[str], str]): Target label column names for loss function.
    """
    args_dict = vars(self)
    run_context = RunContext(args_dict)
    self.callback_manager.train_begin(run_context)
    # JIT choice is baked into self.train_fn at construction time,
    # so no jit argument is accepted here anymore.
    self._run(run_context, tgt_columns)
    self.callback_manager.train_end(run_context)

def _run(self, run_context, tgt_columns=None, jit=False):
def _run(self, run_context, tgt_columns=None):
"""
Training process for non-data sinking mode. The data would be passed to network directly.
"""
# forward function
net = self.network

loss_fn = self.loss_fn
optimizer = self.optimizer
def forward_fn(inputs, labels):
logits_list = ()
logits = net(*inputs)
if isinstance(logits, tuple):
logits_list += logits
else:
logits_list += (logits,)

loss = loss_fn(*logits_list, *labels)
return_list = (loss,) + logits_list
return return_list

def forward_without_loss_fn(inputs, labels):
loss_and_logits = net(*inputs, *labels)
return loss_and_logits

if self.loss_fn is None:
grad_fn = value_and_grad(forward_without_loss_fn, None, optimizer.parameters, has_aux=True)
else:
grad_fn = value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

def _run_step(inputs, labels):
"""Core process of each step, including the forward propagation process and back propagation of data."""
(loss, *_), grads = grad_fn(inputs, labels)
optimizer(grads)
return loss

@ms_jit
def _run_step_graph(inputs, labels):
"""Core process of each step, including the forward propagation process and back propagation of data."""
(loss, _), grads = grad_fn(inputs, labels)
loss = ops.depend(loss, optimizer(grads))
return loss

total = self.train_dataset.get_dataset_size()
# train epoch begin
for epoch in range(0, self.epochs):
net.set_train()
self.network.set_train()
self.cur_epoch_nums = epoch + 1
self.cur_step_nums = 0
run_context.cur_epoch_nums = self.cur_epoch_nums
Expand All @@ -209,10 +196,7 @@ def _run_step_graph(inputs, labels):
run_context.cur_step_nums += 1
self.cur_step_nums += 1
self.callback_manager.train_step_begin(run_context)
if jit:
loss = _run_step_graph(inputs, tgts)
else:
loss = _run_step(inputs, tgts)
loss = self.train_fn(inputs, tgts)
loss_total += loss
progress.set_postfix(loss=loss_total/self.cur_step_nums)
progress.update(1)
Expand All @@ -223,7 +207,7 @@ def _run_step_graph(inputs, labels):
self.callback_manager.train_epoch_end(run_context)
# do epoch evaluation
if self.evaluator is not None:
self._do_eval_epoch(run_context, tgt_columns, jit)
self._do_eval_epoch(run_context, tgt_columns)

def _run_ds_sink(self, train_dataset, eval_dataset, list_callback,
cb_params, print_steps, eval_steps):
Expand All @@ -242,11 +226,11 @@ def _do_eval_steps(self, steps, eval_dataset):
"""Evaluate the model after n steps."""
raise NotImplementedError

def _do_eval_epoch(self, run_context, tgt_columns=None, jit=False):
def _do_eval_epoch(self, run_context, tgt_columns=None):
"""Evaluate the model after an epoch."""
self.callback_manager.evaluate_begin(run_context)
self.evaluator.clear_metrics()
metrics_result, metrics_names, metrics_values = self.evaluator._run(tgt_columns, jit)
metrics_result, metrics_names, metrics_values = self.evaluator._run(tgt_columns)
setattr(run_context, "metrics_values", metrics_values)
setattr(run_context, "metrics_result", metrics_result)
setattr(run_context, "metrics_names", metrics_names)
Expand Down
Loading

0 comments on commit cd86d63

Please sign in to comment.