From 03cc9eae9db1c0ac6eee5f45fbb0dd6e5b76db9f Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 5 May 2022 12:26:11 +0800
Subject: [PATCH 1/4] Save batch to disk on OOM.

---
 .../ASR/pruned_transducer_stateless2/train.py | 77 ++++++++++++++-----
 1 file changed, 58 insertions(+), 19 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index d15c443882..d3ea042561 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -670,25 +670,30 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
-            loss, loss_info = compute_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                batch=batch,
-                is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step),
-            )
-        # summary stats
-        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
-        # NOTE: We use reduction==sum and loss is computed over utterances
-        # in the batch and there is no normalization to it so far.
-        scaler.scale(loss).backward()
-        scheduler.step_batch(params.batch_idx_train)
-        scaler.step(optimizer)
-        scaler.update()
-        optimizer.zero_grad()
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=(params.batch_idx_train / params.model_warm_step),
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                display_and_save_batch(batch, params=params, sp=sp)
+            raise
 
         if params.print_diagnostics and batch_idx == 5:
             return
@@ -933,6 +938,39 @@ def remove_short_and_long_utt(c: Cut):
         cleanup_dist()
 
 
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+    cuts = supervisions["cut"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
 def scan_pessimistic_batches_for_oom(
     model: nn.Module,
     train_dl: torch.utils.data.DataLoader,
@@ -973,6 +1011,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                )
+                display_and_save_batch(batch, params=params, sp=sp)
                raise


From ce885c6a679cd2a88b84d603c1c0bc42110085a3 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 5 May 2022 12:35:41 +0800
Subject: [PATCH 2/4] minor fixes

---
 egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index d3ea042561..bde297ae43 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -156,15 +156,16 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need to be changed.",
+        help="The initial learning rate. This value should not need to "
+        "be changed.",
     )
 
     parser.add_argument(
         "--lr-batches",
         type=float,
         default=5000,
-        help="""Number of steps that affects how rapidly the learning rate decreases.
-        We suggest not to change this.""",
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
     )
 
     parser.add_argument(
@@ -962,7 +963,6 @@ def display_and_save_batch(
 
     supervisions = batch["supervisions"]
     features = batch["inputs"]
-    cuts = supervisions["cut"]
 
     logging.info(f"features shape: {features.shape}")
 

From a0dbfba77dd64bfb4c2980ffa2283a6eb1050e79 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 5 May 2022 12:47:24 +0800
Subject: [PATCH 3/4] Fixes after review.

---
 egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index bde297ae43..58604ab43a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -691,9 +691,8 @@ def train_one_epoch(
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
-        except RuntimeError as e:
-            if "CUDA out of memory" in str(e):
-                display_and_save_batch(batch, params=params, sp=sp)
+        except:
+            display_and_save_batch(batch, params=params, sp=sp)
             raise
 
         if params.print_diagnostics and batch_idx == 5:

From c0ea5808bbaf88055d42c4bc13da1901b48ca89b Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 5 May 2022 12:53:52 +0800
Subject: [PATCH 4/4] Fix style issues.

---
 egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 58604ab43a..7a4b03cce7 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -691,7 +691,7 @@ def train_one_epoch(
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
-        except:
+        except:  # noqa
             display_and_save_batch(batch, params=params, sp=sp)
             raise
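Note on using the dump: once one of these handlers fires, the offending batch is written by display_and_save_batch to <exp_dir>/batch-<uuid>.pt. A minimal sketch of how such a file could be inspected offline is shown below; the file name is a hypothetical placeholder, and the key layout (inputs plus collated supervisions fields such as text and num_frames) is assumed to follow lhotse.dataset.K2SpeechRecognitionDataset, which produces the batches that are saved here.

    import torch

    # Hypothetical path: substitute the file name logged by
    # display_and_save_batch, i.e. "<exp_dir>/batch-<uuid>.pt".
    filename = "pruned_transducer_stateless2/exp/batch-xxxx.pt"

    # Load on CPU so no GPU is needed to examine the batch that caused the OOM.
    batch = torch.load(filename, map_location="cpu")

    features = batch["inputs"]            # padded features, shape (N, T, C)
    supervisions = batch["supervisions"]  # collated supervision fields

    print("features shape:", tuple(features.shape))
    print("num utterances:", len(supervisions["text"]))
    # num_frames is assumed to be a 1-D tensor of per-utterance frame counts.
    print("max frames per utterance:", int(supervisions["num_frames"].max()))
    print("longest transcript (chars):", max(len(t) for t in supervisions["text"]))

In practice, the longest utterance in the dump is often the one to exclude, e.g. by tightening remove_short_and_long_utt or lowering the sampler's max_duration.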