diff --git a/mindnlp/engine/trainer/base.py b/mindnlp/engine/trainer/base.py
index 76e0d80f7..6bc910ab3 100644
--- a/mindnlp/engine/trainer/base.py
+++ b/mindnlp/engine/trainer/base.py
@@ -21,7 +21,7 @@ from inspect import signature
 from tqdm.autonotebook import tqdm
 
 from mindspore import nn, Tensor
-from mindspore import log, mutable
+from mindspore import log, mutable, context
 from mindspore.dataset.engine import Dataset, TakeDataset
 from mindnlp.abc import Callback, Metric
 from mindnlp.engine.callbacks.callback_manager import CallbackManager, RunContext
@@ -188,11 +188,10 @@ def run(self, tgt_columns=None):
 
         self._prepare_train_func()
 
-        args_dict = vars(self)
-        run_context = RunContext(args_dict)
 
         self.callback_manager.train_begin(run_context)
+        self._run(run_context, tgt_columns)
 
         self.callback_manager.train_end(run_context)
 
 
@@ -201,6 +200,11 @@ def _run(self, run_context, tgt_columns=None):
         Training process for non-data sinking mode. The data would be passed to
         network directly.
         """
+        # set mindspore mode to GRAPH_MODE, since jit mode with
+        # control flow will slow down the training speed.
+        if self.jit:
+            context.set_context(mode=context.GRAPH_MODE)
+
         total = self.train_dataset.get_dataset_size()
         # train epoch begin
         for epoch in range(0, self.epochs):
@@ -238,6 +242,9 @@
             if self.evaluator is not None:
                 self._do_eval_epoch(run_context, tgt_columns)
 
+        # restore PYNATIVE_MODE after training.
+        context.set_context(mode=context.PYNATIVE_MODE)
+
     def _run_ds_sink(self, train_dataset, eval_dataset, list_callback, cb_params,
                      print_steps, eval_steps):
         """Training process for data sinking mode."""