Skip to content

Commit

Permalink
fix wenet stateless5 jit export error (k2-fsa#735)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cescfangs authored Dec 5, 2022
1 parent bd7fa22 commit be6e08f
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 0 deletions.
2 changes: 2 additions & 0 deletions egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from pathlib import Path

import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, load_checkpoint
Expand Down Expand Up @@ -184,6 +185,7 @@ def main():
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
convert_scaled_to_non_scaled(model, inplace=True)
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
Expand Down
1 change: 1 addition & 0 deletions egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py

0 comments on commit be6e08f

Please sign in to comment.