diff --git a/README.md b/README.md index 90783f6c4..18bc90e33 100644 --- a/README.md +++ b/README.md @@ -70,36 +70,11 @@ from mindspore import ops from mindspore.common.initializer import Uniform, HeUniform from mindnlp.abc import Seq2vecModel -class Head(nn.Cell): - """ - Head for Sentiment Classification model - """ - def __init__(self, hidden_dim, output_dim, dropout): - super().__init__() - weight_init = HeUniform(math.sqrt(5)) - bias_init = Uniform(1 / math.sqrt(hidden_dim * 2)) - self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init) - self.sigmoid = nn.Sigmoid() - self.dropout = nn.Dropout(1 - dropout) - - def construct(self, context): - context = ops.concat((context[-2, :, :], context[-1, :, :]), axis=1) - context = self.dropout(context) - return self.sigmoid(self.fc(context)) - - class SentimentClassification(Seq2vecModel): - """ - Sentiment Classification model - """ - def __init__(self, encoder, head): - super().__init__(encoder, head) - self.encoder = encoder - self.head = head - def construct(self, text): _, (hidden, _), _ = self.encoder(text) - output = self.head(hidden) + context = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1) + output = self.head(context) return output ``` @@ -140,18 +115,33 @@ from mindnlp.dataset import process imdb_train = process('imdb', imdb_train, tokenizer=tokenizer, vocab=vocab, \ bucket_boundaries=[400, 500], max_len=600, drop_remainder=True) -imdb_train, imdb_valid = imdb_train.split([0.7, 0.3]) +imdb_test = process('imdb', imdb_test, tokenizer=tokenizer, vocab=vocab, \ + bucket_boundaries=[400, 500], max_len=600, drop_remainder=False) ``` ### Instantiate Model ```python from mindnlp.modules import RNNEncoder +# build encoder lstm_layer = nn.LSTM(100, hidden_size, num_layers=num_layers, batch_first=True, - dropout=drop, bidirectional=bidirectional) -sentiment_encoder = RNNEncoder(embedding, lstm_layer) -sentiment_head = Head(hidden_size, output_size, drop) -net = SentimentClassification(sentiment_encoder, sentiment_head) + dropout=dropout, bidirectional=bidirectional) +encoder = RNNEncoder(embedding, lstm_layer) + +# build head +head = nn.SequentialCell([ + nn.Dropout(1 - dropout), + nn.Sigmoid(), + nn.Dense(hidden_size * 2, output_size, + weight_init=HeUniform(math.sqrt(5)), + bias_init=Uniform(1 / math.sqrt(hidden_size * 2))) + +]) + +# build network +network = SentimentClassification(encoder, head) +loss = nn.BCELoss(reduction='mean') +optimizer = nn.Adam(network.trainable_params(), learning_rate=lr) ``` ### Training Process @@ -166,7 +156,7 @@ metric = Accuracy() # define trainer trainer = Trainer(network=net, train_dataset=imdb_train, eval_dataset=imdb_valid, metrics=metric, epochs=5, loss_fn=loss, optimizer=optimizer) -trainer.run(tgt_columns="label", jit=False) +trainer.run(tgt_columns="label") print("end train") ``` diff --git a/mindnlp/engine/evaluator.py b/mindnlp/engine/evaluator.py index b72d68306..fd6a3e9ea 100644 --- a/mindnlp/engine/evaluator.py +++ b/mindnlp/engine/evaluator.py @@ -36,9 +36,10 @@ class Evaluator: while evaluating. Default:None. callbacks (Optional[list[Callback], Callback]): List of callback objects which should be executed while training. Default: None. + jit (bool): Whether use Just-In-Time compile. """ - def __init__(self, network, eval_dataset=None, metrics=None, callbacks=None): + def __init__(self, network, eval_dataset=None, metrics=None, callbacks=None, jit=False): self.network = network self.callbacks = callbacks self.earlystop = False @@ -48,6 +49,16 @@ def __init__(self, network, eval_dataset=None, metrics=None, callbacks=None): self.total = eval_dataset.get_dataset_size() self.callback_manager = CallbackManager(callbacks=self.callbacks) + self.eval_func = self._prepare_eval_func(network, jit) + + def _prepare_eval_func(self, network, jit): + def _run_step(inputs): + """Core process of each step.""" + outputs = network(*inputs) + return outputs + if jit: + return ms_jit(_run_step) + return _run_step def _check_metric_type(self, metrics): """Check metrics type.""" @@ -86,41 +97,30 @@ def _check_reuse_dataset(self, dataset): raise RuntimeError("The dataset object had been used in other model by model.train(...), " "please create a new dataset.") - def run(self, tgt_columns=None, jit=False): + def run(self, tgt_columns=None): """ Evaluating function entry. Args: tgt_columns (Optional[list[str], str]): Target label column names for loss function. - jit (bool): Whether use Just-In-Time compile. """ args_dict = vars(self) run_context = RunContext(args_dict) self.callback_manager.evaluate_begin(run_context) self.clear_metrics() - _ = self._run(tgt_columns, jit) + _ = self._run(tgt_columns) self.callback_manager.evaluate_end(run_context) self.earlystop = getattr(run_context, 'earlystop', False) - def _run(self, tgt_columns=None, jit=False): + def _run(self, tgt_columns=None): """Evaluating process for non-data sinking mode. The data would be passed to network directly.""" - net = self.network - - def _run_step(inputs): - """Core process of each step.""" - outputs = net(*inputs) - return outputs - - if jit: - _run_step = ms_jit(_run_step) - - net.set_train(False) + self.network.set_train(False) with tqdm(total=self.total) as progress: progress.set_description('Evaluate') for data in self.eval_dataset.create_dict_iterator(): inputs, tgts = self._data_process(data, tgt_columns) - outputs = _run_step(inputs) + outputs = self.eval_func(inputs) self._update_metrics(outputs, *tgts) progress.update(1) diff --git a/mindnlp/engine/trainer.py b/mindnlp/engine/trainer.py index e606a89e2..e4184acae 100644 --- a/mindnlp/engine/trainer.py +++ b/mindnlp/engine/trainer.py @@ -19,7 +19,6 @@ """ from inspect import signature from tqdm import tqdm -from mindspore import ops from mindspore import log, mutable from mindspore.ops import value_and_grad from mindnlp import ms_jit @@ -48,7 +47,7 @@ class Trainer: to None and implement calculation of loss in `network`, then a tuple (data1, data2, data3, ...) with all data returned from dataset will be passed to the `network`. - metrcis (Optional[list[Metrics], Metrics]): List of metrics objects which should be used + metrics (Optional[list[Metrics], Metrics]): List of metrics objects which should be used while evaluating. Default:None. epochs (int): Total number of iterations on the data. Default: 10. optimizer (Cell): Optimizer for updating the weights. If `optimizer` is None, the `network` needs to @@ -57,55 +56,82 @@ class Trainer: and parallel if needed. Default: None. callbacks (Optional[list[Callback], Callback]): List of callback objects which should be executed while training. Default: None. + jit (bool): Whether use Just-In-Time compile. """ def __init__(self, network=None, train_dataset=None, eval_dataset=None, metrics=None, epochs=10, - loss_fn=None, optimizer=None, callbacks=None): + loss_fn=None, optimizer=None, callbacks=None, jit=False): self.network = network self.train_dataset = train_dataset - self.eval_dataset = eval_dataset - self.metrics = metrics self.epochs = epochs - self.loss_fn = loss_fn - self.optimizer = optimizer - self.callbacks = callbacks + self.cur_epoch_nums = 0 self.cur_step_nums = 0 self.earlystop = False - self.grad_fn = None if callbacks: - self._prepare_callbacks(callbacks) - self._prepare_eval() - self.callback_manager = CallbackManager(callbacks=self.callbacks) + callbacks = self._prepare_callbacks(callbacks) + self._prepare_eval(eval_dataset, metrics, callbacks, jit) + + self.callback_manager = CallbackManager(callbacks) + self.train_fn = self._prepare_train_func(network, loss_fn, optimizer, jit) + + def _prepare_train_func(self, network, loss_fn, optimizer, jit): + # forward function + def forward_fn(inputs, labels): + logits_list = () + logits = network(*inputs) + if isinstance(logits, tuple): + logits_list += logits + else: + logits_list += (logits,) + + loss = loss_fn(*logits_list, *labels) + return_list = (loss,) + logits_list + return return_list + + def forward_without_loss_fn(inputs, labels): + loss_and_logits = network(*inputs, *labels) + return loss_and_logits + + if loss_fn is None: + grad_fn = value_and_grad(forward_without_loss_fn, None, optimizer.parameters, has_aux=True) + else: + grad_fn = value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) + + def _run_step(inputs, labels): + """Core process of each step, including the forward propagation process and back propagation of data.""" + (loss, _), grads = grad_fn(inputs, labels) + optimizer(grads) + return loss + if jit: + return ms_jit(_run_step) + return _run_step def _prepare_callbacks(self, callbacks): - self.callbacks = [] if isinstance(callbacks, Callback): - self.callbacks.append(callbacks) - elif isinstance(callbacks, list): + return [callbacks] + if isinstance(callbacks, list): if all(isinstance(cb, Callback) for cb in callbacks) is True: - self.callbacks = callbacks - else: - obj = [not isinstance(cb, Callback) for cb in callbacks][0] - raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}") - else: - raise TypeError(f"Expect callbacks to be list or Callback. Got {type(callbacks)}.") + return callbacks + obj = [not isinstance(cb, Callback) for cb in callbacks][0] + raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}") + raise TypeError(f"Expect callbacks to be list or Callback. Got {type(callbacks)}.") - def _check_callbacks_type(self): - for callback in self.callbacks: + def _check_callbacks_type(self, callbacks): + for callback in callbacks: if isinstance(callback, EarlyStopCallback): raise ValueError("EarlyStopCallback is not effective when eval_dataset is None.") if isinstance(callback, BestModelCallback): raise ValueError("BestModelCallback is not effective when eval_dataset is None.") - def _prepare_eval(self): - if self.eval_dataset is not None and self.metrics is not None: - self.evaluator = Evaluator(network=self.network, eval_dataset=self.eval_dataset, metrics=self.metrics, - callbacks=self.callbacks) - elif self.eval_dataset is None and self.metrics is None: - if self.callbacks: - self._check_callbacks_type() + def _prepare_eval(self, eval_dataset, metrics, callbacks, jit): + if eval_dataset is not None and metrics is not None: + self.evaluator = Evaluator(network=self.network, eval_dataset=eval_dataset, metrics=metrics, + callbacks=callbacks, jit=jit) + elif eval_dataset is None and metrics is None: + if callbacks: + self._check_callbacks_type(callbacks) self.evaluator = None else: raise ValueError("For evaluation in training process, both eval dataset and metrics should be not None.") @@ -130,69 +156,30 @@ def _check_reuse_dataset(self, dataset): raise RuntimeError("The dataset object had been used in other model by model.train(...), " "please create a new dataset.") - def run(self, tgt_columns=None, jit=False): + def run(self, tgt_columns=None): """ Training process entry. Args: tgt_columns (Optional[list[str], str]): Target label column names for loss function. - jit (bool): Whether use Just-In-Time compile. """ args_dict = vars(self) run_context = RunContext(args_dict) self.callback_manager.train_begin(run_context) - self._run(run_context, tgt_columns, jit) + self._run(run_context, tgt_columns) self.callback_manager.train_end(run_context) - def _run(self, run_context, tgt_columns=None, jit=False): + def _run(self, run_context, tgt_columns=None): """ Training process for non-data sinking mode. The data would be passed to network directly. """ - # forward function - net = self.network - - loss_fn = self.loss_fn - optimizer = self.optimizer - def forward_fn(inputs, labels): - logits_list = () - logits = net(*inputs) - if isinstance(logits, tuple): - logits_list += logits - else: - logits_list += (logits,) - - loss = loss_fn(*logits_list, *labels) - return_list = (loss,) + logits_list - return return_list - - def forward_without_loss_fn(inputs, labels): - loss_and_logits = net(*inputs, *labels) - return loss_and_logits - - if self.loss_fn is None: - grad_fn = value_and_grad(forward_without_loss_fn, None, optimizer.parameters, has_aux=True) - else: - grad_fn = value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) - - def _run_step(inputs, labels): - """Core process of each step, including the forward propagation process and back propagation of data.""" - (loss, *_), grads = grad_fn(inputs, labels) - optimizer(grads) - return loss - - @ms_jit - def _run_step_graph(inputs, labels): - """Core process of each step, including the forward propagation process and back propagation of data.""" - (loss, _), grads = grad_fn(inputs, labels) - loss = ops.depend(loss, optimizer(grads)) - return loss total = self.train_dataset.get_dataset_size() # train epoch begin for epoch in range(0, self.epochs): - net.set_train() + self.network.set_train() self.cur_epoch_nums = epoch + 1 self.cur_step_nums = 0 run_context.cur_epoch_nums = self.cur_epoch_nums @@ -209,10 +196,7 @@ def _run_step_graph(inputs, labels): run_context.cur_step_nums += 1 self.cur_step_nums += 1 self.callback_manager.train_step_begin(run_context) - if jit: - loss = _run_step_graph(inputs, tgts) - else: - loss = _run_step(inputs, tgts) + loss = self.train_fn(inputs, tgts) loss_total += loss progress.set_postfix(loss=loss_total/self.cur_step_nums) progress.update(1) @@ -223,7 +207,7 @@ def _run_step_graph(inputs, labels): self.callback_manager.train_epoch_end(run_context) # do epoch evaluation if self.evaluator is not None: - self._do_eval_epoch(run_context, tgt_columns, jit) + self._do_eval_epoch(run_context, tgt_columns) def _run_ds_sink(self, train_dataset, eval_dataset, list_callback, cb_params, print_steps, eval_steps): @@ -242,11 +226,11 @@ def _do_eval_steps(self, steps, eval_dataset): """Evaluate the model after n steps.""" raise NotImplementedError - def _do_eval_epoch(self, run_context, tgt_columns=None, jit=False): + def _do_eval_epoch(self, run_context, tgt_columns=None): """Evaluate the model after an epoch.""" self.callback_manager.evaluate_begin(run_context) self.evaluator.clear_metrics() - metrics_result, metrics_names, metrics_values = self.evaluator._run(tgt_columns, jit) + metrics_result, metrics_names, metrics_values = self.evaluator._run(tgt_columns) setattr(run_context, "metrics_values", metrics_values) setattr(run_context, "metrics_result", metrics_result) setattr(run_context, "metrics_names", metrics_names) diff --git a/mindnlp/modules/crf.py b/mindnlp/modules/crf.py index 7b306daf0..2194fd896 100644 --- a/mindnlp/modules/crf.py +++ b/mindnlp/modules/crf.py @@ -62,9 +62,9 @@ class CRF(nn.Cell): """ def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None: + super().__init__() if num_tags <= 0: raise ValueError(f'invalid number of tags: {num_tags}') - super().__init__() if reduction not in ('none', 'sum', 'mean', 'token_mean'): raise ValueError(f'invalid reduction: {reduction}') self.num_tags = num_tags diff --git a/mindnlp/modules/encoder/cnn_encoder.py b/mindnlp/modules/encoder/cnn_encoder.py index c44c14018..de0a61410 100644 --- a/mindnlp/modules/encoder/cnn_encoder.py +++ b/mindnlp/modules/encoder/cnn_encoder.py @@ -115,6 +115,6 @@ def construct(self, src_token, src_length=None, mask=None): maxpool_out += (self.pool(out).squeeze(axis=2),) result = ops.concat(maxpool_out, axis=1) - if self.projection_layer: + if self.projection_layer is not None: result = self.projection_layer(result) return result diff --git a/tests/st/model/test_sentiment_classification.py b/tests/st/model/test_sentiment_classification.py index 229685909..d33de2b35 100644 --- a/tests/st/model/test_sentiment_classification.py +++ b/tests/st/model/test_sentiment_classification.py @@ -30,8 +30,7 @@ from mindnlp.abc import Seq2vecModel from mindnlp.modules import RNNEncoder -from mindnlp.common.metrics import Accuracy -from mindnlp.engine.trainer import Trainer +from mindnlp.engine import Trainer, Accuracy from mindnlp.dataset import load from mindnlp.modules import Glove diff --git a/tests/ut/engine/test_evaluator.py b/tests/ut/engine/test_evaluator.py index 67dc03f5a..a28246704 100644 --- a/tests/ut/engine/test_evaluator.py +++ b/tests/ut/engine/test_evaluator.py @@ -14,10 +14,11 @@ # ============================================================================ """Test Evaluator with Callback function""" # pylint: disable=C0103 +# pylint: disable=W0621 import unittest import numpy as np - +from ddt import ddt, data from mindspore import nn import mindspore.dataset as ds @@ -46,25 +47,23 @@ def construct(self, data): output = self.fc(data) return output +@ddt class TestEvaluatorRun(unittest.TestCase): r""" Test Evaluator Run """ def setUp(self): self.input = None - net = MyModel() + self.net = MyModel() dataset_generator = MyDataset() - metric = Accuracy() - callbacks = [TimerCallback()] + self.metric = Accuracy() + self.callbacks = [TimerCallback()] eval_dataset = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=False) - eval_dataset = eval_dataset.batch(10) - self.evaluator = Evaluator(network=net, eval_dataset=eval_dataset, metrics=metric, - callbacks=callbacks) + self.eval_dataset = eval_dataset.batch(10) - def test_evaluator_run(self): + @data(True, False) + def test_evaluator_run(self, jit): """test evaluator run pynative""" - self.evaluator.run(tgt_columns='label') - - def test_evaluator_run_jit(self): - """test evaluator run graph""" - self.evaluator.run(tgt_columns='label', jit=False) + evaluator = Evaluator(network=self.net, eval_dataset=self.eval_dataset, metrics=self.metric, + callbacks=self.callbacks, jit=jit) + evaluator.run(tgt_columns='label') diff --git a/tests/ut/engine/test_trainer.py b/tests/ut/engine/test_trainer.py index 4b487342a..eea93b2fd 100644 --- a/tests/ut/engine/test_trainer.py +++ b/tests/ut/engine/test_trainer.py @@ -92,79 +92,59 @@ def setUp(self): # 4. define metrics self.metric = Accuracy() # 5. define trainer - self.pure_trainer = Trainer(network=self.net, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, - metrics=self.metric, epochs=2, optimizer=self.optimizer, - loss_fn=self.loss_fn) @data(True, False) def test_pure_trainer(self, jit): """test_pure_trainer""" # 6. trainer run - self.pure_trainer.run(tgt_columns='label', jit=jit) + pure_trainer = Trainer(network=self.net, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, + metrics=self.metric, epochs=2, optimizer=self.optimizer, + loss_fn=self.loss_fn, jit=jit) + pure_trainer.run(tgt_columns='label') @data(True, False) def test_trainer_timer(self, jit): """test_trainer_timer""" - train_dataset = ds.GeneratorDataset(self.dataset_generator, ["data", "label", "length"], shuffle=False) - eval_dataset = ds.GeneratorDataset(self.dataset_generator, ["data", "label", "length"], shuffle=False) - train_dataset = train_dataset.batch(4) - eval_dataset = eval_dataset.batch(4) - trainer = Trainer(network=self.net, train_dataset=train_dataset, eval_dataset=eval_dataset, metrics=self.metric, - epochs=2, optimizer=self.optimizer, loss_fn=self.loss_fn, - callbacks=self.timer_callback_epochs) - trainer.run(tgt_columns='label', jit=jit) + trainer = Trainer(network=self.net, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, + metrics=self.metric, epochs=2, optimizer=self.optimizer, loss_fn=self.loss_fn, + callbacks=self.timer_callback_epochs, jit=jit) + trainer.run(tgt_columns='label') @data(True, False) def test_trainer_earlystop(self, jit): """test_trainer_earlystop""" - train_dataset = ds.GeneratorDataset(self.dataset_generator, ["data", "label", "length"], shuffle=False) - eval_dataset = ds.GeneratorDataset(self.dataset_generator, ["data", "label", "length"], shuffle=False) - train_dataset = train_dataset.batch(4) - eval_dataset = eval_dataset.batch(4) - trainer = Trainer(network=self.net, train_dataset=train_dataset, eval_dataset=eval_dataset, metrics=self.metric, - epochs=6, optimizer=self.optimizer, loss_fn=self.loss_fn, - callbacks=self.earlystop_callback) - trainer.run(tgt_columns='label', jit=jit) + trainer = Trainer(network=self.net, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, + metrics=self.metric, epochs=6, optimizer=self.optimizer, loss_fn=self.loss_fn, + callbacks=self.earlystop_callback, jit=jit) + trainer.run(tgt_columns='label') @data(True, False) def test_trainer_bestmodel(self, jit): """test_trainer_bestmodel""" - train_dataset = ds.GeneratorDataset(self.dataset_generator, ["data", "label", "length"], shuffle=False) - eval_dataset = ds.GeneratorDataset(self.dataset_generator, ["data", "label", "length"], shuffle=False) - train_dataset = train_dataset.batch(4) - eval_dataset = eval_dataset.batch(4) - trainer = Trainer(network=self.net, train_dataset=train_dataset, eval_dataset=eval_dataset, metrics=self.metric, - epochs=4, optimizer=self.optimizer, loss_fn=self.loss_fn, - callbacks=self.bestmodel_callback) - trainer.run(tgt_columns='label', jit=jit) + trainer = Trainer(network=self.net, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, + metrics=self.metric, epochs=4, optimizer=self.optimizer, loss_fn=self.loss_fn, + callbacks=self.bestmodel_callback, jit=jit) + trainer.run(tgt_columns='label') @data(True, False) def test_trainer_checkpoint(self, jit): """test_trainer_checkpoint""" - train_dataset = ds.GeneratorDataset(self.dataset_generator, ["data", "label", "length"], shuffle=False) - eval_dataset = ds.GeneratorDataset(self.dataset_generator, ["data", "label", "length"], shuffle=False) - train_dataset = train_dataset.batch(4) - eval_dataset = eval_dataset.batch(4) - trainer = Trainer(network=self.net, train_dataset=train_dataset, eval_dataset=eval_dataset, metrics=self.metric, - epochs=7, optimizer=self.optimizer, loss_fn=self.loss_fn, - callbacks=self.checkpoint_callback) - trainer.run(tgt_columns='label', jit=jit) + trainer = Trainer(network=self.net, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, + metrics=self.metric, epochs=7, optimizer=self.optimizer, loss_fn=self.loss_fn, + callbacks=self.checkpoint_callback, jit=jit) + trainer.run(tgt_columns='label') - def test_different_model(self): + @data(True, False) + def test_different_model(self, jit): """test_different_model""" - train_dataset = ds.GeneratorDataset(self.dataset_generator, ["data", "label", "length"], shuffle=False) - eval_dataset = ds.GeneratorDataset(self.dataset_generator, ["data", "label", "length"], shuffle=False) - train_dataset = train_dataset.batch(4) - eval_dataset = eval_dataset.batch(4) - trainer = Trainer(network=self.net_2, train_dataset=train_dataset, eval_dataset=eval_dataset, + trainer = Trainer(network=self.net_2, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, metrics=self.metric, epochs=2, optimizer=self.optimizer, - loss_fn=self.loss_fn) - trainer.run(tgt_columns='length', jit=True) + loss_fn=self.loss_fn, jit=jit) + trainer.run(tgt_columns='length') - def test_no_eval_in_trainer(self): + @data(True, False) + def test_no_eval_in_trainer(self, jit): """test_eval_in_trainer""" - train_dataset = ds.GeneratorDataset(self.dataset_generator, ["data", "label", "length"], shuffle=False) - train_dataset = train_dataset.batch(4) - trainer = Trainer(network=self.net, train_dataset=train_dataset, epochs=2, - optimizer=self.optimizer, loss_fn=self.loss_fn) - trainer.run(tgt_columns='length', jit=True) + trainer = Trainer(network=self.net, train_dataset=self.train_dataset, epochs=2, + optimizer=self.optimizer, loss_fn=self.loss_fn, jit=jit) + trainer.run(tgt_columns='length')