Skip to content

Commit

Permalink
Trainer.run() set GRAPH_MODE to avoid mindspore control flow issue (m…
Browse files Browse the repository at this point in the history
  • Loading branch information
lvyufeng authored May 28, 2023
1 parent def5466 commit 69fc70d
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions mindnlp/engine/trainer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from inspect import signature
from tqdm.autonotebook import tqdm
from mindspore import nn, Tensor
from mindspore import log, mutable
from mindspore import log, mutable, context
from mindspore.dataset.engine import Dataset, TakeDataset
from mindnlp.abc import Callback, Metric
from mindnlp.engine.callbacks.callback_manager import CallbackManager, RunContext
Expand Down Expand Up @@ -188,11 +188,10 @@ def run(self, tgt_columns=None):

self._prepare_train_func()


args_dict = vars(self)

run_context = RunContext(args_dict)
self.callback_manager.train_begin(run_context)

self._run(run_context, tgt_columns)
self.callback_manager.train_end(run_context)

Expand All @@ -201,6 +200,11 @@ def _run(self, run_context, tgt_columns=None):
Training process for non-data sinking mode. The data would be passed to network directly.
"""

# set mindspore mode to GRAPH_MODE, since jit mode with
# control flow will slow down the training speed.
if self.jit:
context.set_context(mode=context.GRAPH_MODE)

total = self.train_dataset.get_dataset_size()
# train epoch begin
for epoch in range(0, self.epochs):
Expand Down Expand Up @@ -238,6 +242,9 @@ def _run(self, run_context, tgt_columns=None):
if self.evaluator is not None:
self._do_eval_epoch(run_context, tgt_columns)

# restore PYNATIVE_MODE after training.
context.set_context(mode=context.PYNATIVE_MODE)

def _run_ds_sink(self, train_dataset, eval_dataset, list_callback,
cb_params, print_steps, eval_steps):
"""Training process for data sinking mode."""
Expand Down

0 comments on commit 69fc70d

Please sign in to comment.