Commit

fixed hard-coded device
mravanelli committed Aug 14, 2023
1 parent 331acdb commit 091b3ce
Showing 2 changed files with 9 additions and 3 deletions.
4 changes: 1 addition & 3 deletions recipes/SLURP/direct/hparams/train.yaml
@@ -60,9 +60,7 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
 # Models
-asr_model: !apply:speechbrain.pretrained.EncoderDecoderASR.from_hparams
-    source: speechbrain/asr-crdnn-rnnlm-librispeech
-    run_opts: {"device":"cuda:0"}
+asr_model_source: speechbrain/asr-crdnn-rnnlm-librispeech
 
 slu_enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <ASR_encoder_dim>]
8 changes: 8 additions & 0 deletions recipes/SLURP/direct/train.py
@@ -340,6 +340,14 @@ def text_pipeline(semantics):
     run_on_main(hparams["pretrainer"].collect_files)
     hparams["pretrainer"].load_collected(device=run_opts["device"])
 
+    # Download pretrained ASR model
+    from speechbrain.pretrained import EncoderDecoderASR
+
+    hparams["asr_model"] = EncoderDecoderASR.from_hparams(
+        source=hparams["asr_model_source"],
+        run_opts={"device": run_opts["device"]},
+    )
+
     # Brain class initialization
     slu_brain = SLU(
         modules=hparams["modules"],
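The point of the change is that the device is no longer pinned to "cuda:0" in the hparams file; instead it is taken from the `run_opts` the recipe was launched with, so the same config runs on CPU-only machines. A minimal sketch of that pattern (the `resolve_device` helper is hypothetical, not part of SpeechBrain; only the `run_opts={"device": ...}` argument to `from_hparams` comes from the commit):

```python
# Hypothetical helper illustrating the fix: read the compute device from
# run_opts instead of hard-coding "cuda:0" as the old hparams file did.
def resolve_device(run_opts, default="cpu"):
    """Return the device string from run_opts, falling back to a default."""
    return run_opts.get("device", default)

# The pretrained ASR model can then be loaded on whichever device the
# recipe was launched with, e.g. (mirroring the added train.py code):
#
#     asr_model = EncoderDecoderASR.from_hparams(
#         source=hparams["asr_model_source"],
#         run_opts={"device": resolve_device(run_opts)},
#     )

print(resolve_device({"device": "cuda:0"}))  # launched with --device=cuda:0
print(resolve_device({}))                    # no device given: fall back to CPU
```

This keeps the YAML purely declarative (just the model source string) and defers the device decision to runtime, which is why the eager `!apply:` load was removed from `train.yaml`.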
