Add paraformer npu scripts and tiny fix (wenet-e2e#2563)
* add paraformer npu scripts and tiny fix

* add npu setup and rnnt scripts
MengqingCao authored Aug 6, 2024
1 parent bc58f83 commit b2f59ef
Showing 8 changed files with 436 additions and 9 deletions.
22 changes: 22 additions & 0 deletions README.md
@@ -68,6 +68,28 @@ conda install conda-forge::sox
pip install torch==2.2.2+cu121 torchaudio==2.2.2+cu121 -f https://download.pytorch.org/whl/torch_stable.html
```

<details><summary><b>For Ascend NPU users:</b></summary>

- Install CANN: please follow this [link](https://ascend.github.io/docs/sources/ascend/quick_install.html) to install the CANN toolkit and kernels.

- Install WeNet with torch-npu dependencies:

``` sh
pip install -e .[torch-npu]
```

- Version compatibility table:

| Requirement  | Minimum          | Recommended |
| ------------ | ---------------- | ----------- |
| CANN | 8.0.RC2.alpha003 | latest |
| torch | 2.1.0 | 2.2.0 |
| torch-npu | 2.1.0 | 2.2.0 |
| torchaudio | 2.1.0 | 2.2.0 |
| deepspeed | 0.13.2 | latest |
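
You can sanity-check the setup after installation; the following is a minimal sketch, assuming `torch_npu` registers the `torch.npu` backend on import:

``` sh
# Should print "True" on a machine with a visible Ascend NPU.
python -c "import torch, torch_npu; print(torch.npu.is_available())"
```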

</details>

- Install other python packages

``` sh
1 change: 1 addition & 0 deletions docs/python_package.md
@@ -31,6 +31,7 @@ You can specify the following parameters.
* `--align`: force align the input audio and transcript
* `--label`: the input label to align
* `--paraformer`: use the best Chinese model
* `--device`: specify the backend accelerator (cuda/npu/cpu)
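
For example, transcribing on an Ascend NPU from the command line might look like the following sketch (`audio.wav` is a placeholder input file):

``` sh
wenet --paraformer --device npu audio.wav
```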

## Python Programming Usage

179 changes: 179 additions & 0 deletions examples/aishell/paraformer/run_npu.sh
@@ -0,0 +1,179 @@
#!/bin/bash

# Copyright 2019 Mobvoi Inc. All Rights Reserved.
. ./path.sh || exit 1;

# Automatically detect the number of NPUs
if command -v npu-smi &> /dev/null; then
num_npus=$(npu-smi info -l | grep "Total Count" | awk '{print $4}')
npu_list=$(seq -s, 0 $((num_npus-1)))
else
num_npus=-1
npu_list="-1"
fi

# You can also manually specify ASCEND_RT_VISIBLE_DEVICES
# if you don't want to utilize all available NPU resources.
export ASCEND_RT_VISIBLE_DEVICES="${npu_list}"
echo "ASCEND_RT_VISIBLE_DEVICES is ${ASCEND_RT_VISIBLE_DEVICES}"

stage=0
stop_stage=2

# You should change the following two parameters for multi-machine training,
# see https://pytorch.org/docs/stable/elastic/run.html
HOST_NODE_ADDR="localhost:0"
num_nodes=1
job_id=2024
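
# For example, a two-node run might set (the address below is a placeholder):
#   HOST_NODE_ADDR="192.168.0.1:29400"
#   num_nodes=2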

# data_type can be `raw` or `shard`. Typically, `raw` is used for small
# datasets, while `shard` is used for large datasets (over 1k hours); `shard`
# is faster for reading data during training.
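# For example, on a corpus of more than 1k hours you would typically set
# `data_type=shard` instead (assuming the shard lists have been prepared).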
data_type=raw

train_set=train

train_config=conf/train_paraformer_dynamic.yaml
checkpoint=exp/paraformer/large/wenet_paraformer.init-ctc.init-embed.pt
dir=exp/finetune_paraformer_dynamic
tensorboard_dir=tensorboard
num_workers=8
prefetch=500

# Using average_checkpoint will yield a better result
average_checkpoint=true
decode_checkpoint=$dir/final.pt
average_num=5
decode_modes="ctc_greedy_search ctc_prefix_beam_search paraformer_greedy_search"
decode_device=0
decoding_chunk_size=-1
decode_batch=16
ctc_weight=0.3
reverse_weight=0.5
max_epoch=100

train_engine=torch_fsdp

# Save either model+optimizer or model_only: model+optimizer is more
# time-efficient but consumes more space, while model_only is the opposite.
deepspeed_config=../whisper/conf/ds_stage1.json
deepspeed_save_states="model+optimizer"
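# For example, deepspeed_save_states="model_only" trades slower resume for
# smaller checkpoints; stage 1's zero_to_fp32 conversion is then skipped.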

. tools/parse_options.sh || exit 1;

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
mkdir -p $dir
num_npus=$(echo $ASCEND_RT_VISIBLE_DEVICES | awk -F "," '{print NF}')
# Use "hccl" for npu if it works, otherwise use "gloo"
# NOTE(xcsong): deepspeed fails with gloo, see
# https://github.com/microsoft/DeepSpeed/issues/2818
dist_backend="hccl"

# train.py rewrites $train_config to $dir/train.yaml with the model input
# and output dimensions, and $dir/train.yaml will be used for inference
# and export.
echo "$0: using ${train_engine}"

# NOTE(xcsong): Both ddp & deepspeed can be launched by torchrun
# NOTE(xcsong): To unify single-node & multi-node training, we add
# all related args. You should change `nnodes` &
# `rdzv_endpoint` for multi-node, see
# https://pytorch.org/docs/stable/elastic/run.html#usage
# https://github.com/wenet-e2e/wenet/pull/2055#issuecomment-1766055406
# `rdzv_id` - A user-defined id that uniquely identifies the worker group for a job.
# This id is used by each node to join as a member of a particular worker group.
# `rdzv_endpoint` - The rendezvous backend endpoint; usually in form <host>:<port>.
# NOTE(xcsong): In multi-node training, some clusters require special NCCL variables to be set prior to training.
# For example: `NCCL_IB_DISABLE=1` + `NCCL_SOCKET_IFNAME=enp` + `NCCL_DEBUG=INFO`
# without NCCL_IB_DISABLE=1
# RuntimeError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1269, internal error, NCCL Version xxx
# without NCCL_SOCKET_IFNAME=enp (the interface name can be obtained via `ifconfig`)
# RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:xxx
# ref: https://github.com/google/jax/issues/13559#issuecomment-1343573764
echo "$0: num_nodes is $num_nodes, proc_per_node is $num_npus"
torchrun --nnodes=$num_nodes --nproc_per_node=$num_npus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint=$HOST_NODE_ADDR \
wenet/bin/train.py \
--device "npu" \
--train_engine ${train_engine} \
--config $train_config \
--data_type $data_type \
--train_data data/$train_set/data.list \
--cv_data data/dev/data.list \
${checkpoint:+--checkpoint $checkpoint} \
--model_dir $dir \
--tensorboard_dir ${tensorboard_dir} \
--ddp.dist_backend $dist_backend \
--num_workers ${num_workers} \
--prefetch ${prefetch} \
--pin_memory \
--deepspeed_config ${deepspeed_config} \
--deepspeed.save_states ${deepspeed_save_states}
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
if [ "$deepspeed_save_states" = "model+optimizer" ]; then
for subdir in $(find "$dir" -maxdepth 1 -type d | grep -v "^$dir$")
do
# NOTE(xcsong): zero_to_fp32.py is automatically generated by deepspeed
tag=$(basename "$subdir")
echo "$tag"
python3 ${dir}/zero_to_fp32.py \
${dir} ${dir}/${tag}.pt -t ${tag}
rm -rf ${dir}/${tag}
done
fi
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# Test model, please specify the model you want to test by --checkpoint
if [ ${average_checkpoint} == true ]; then
decode_checkpoint=$dir/avg_${average_num}_maxepoch_${max_epoch}.pt
echo "do model average and final checkpoint is $decode_checkpoint"
python wenet/bin/average_model.py \
--dst_model $decode_checkpoint \
--src_path $dir \
--num ${average_num} \
--max_epoch ${max_epoch} \
--val_best
fi
# Please specify decoding_chunk_size for unified streaming and
# non-streaming model. The default value is -1, which is full chunk
# for non-streaming inference.
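# For example, a streaming-style decode might set decoding_chunk_size=16;
# this value is illustrative, not a tuned recommendation.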
base=$(basename $decode_checkpoint)
result_dir=$dir/${base}_chunk${decoding_chunk_size}_ctc${ctc_weight}_reverse${reverse_weight}
mkdir -p ${result_dir}
python wenet/bin/recognize.py --device "npu" \
--modes $decode_modes \
--config $dir/train.yaml \
--data_type $data_type \
--test_data data/test/data.list \
--checkpoint $decode_checkpoint \
--beam_size 10 \
--batch_size ${decode_batch} \
--blank_penalty 0.0 \
--ctc_weight $ctc_weight \
--reverse_weight $reverse_weight \
--result_dir $result_dir \
${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size}
for mode in ${decode_modes}; do
python tools/compute-wer.py --char=1 --v=1 \
data/test/data.list $result_dir/$mode/text > $result_dir/$mode/wer
done
fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# Export the best model you want
# NOTE (MengqingCao): if the RuntimeError "Expected a value of type 'Tuple[Tensor, Tensor]'
# for argument 'hx' but instead found type 'Tensor (inferred)'." occurs,
# modify the function "def lstm_forward(self, input1, hx=None):" to
# "def lstm_forward(self, input1, hx: Optional[tuple[torch.Tensor, torch.Tensor]] = None):"
# in torch-npu/utils/module.py.
# Remove this note once torch-npu fixes it. See: https://gitee.com/ascend/pytorch/pulls/12818
python wenet/bin/export_jit.py \
--config $dir/train.yaml \
--checkpoint $dir/avg_${average_num}_maxepoch_${max_epoch}.pt \
--output_file $dir/final.zip \
--output_quant_file $dir/final_quant.zip
fi
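
Since the script sources `tools/parse_options.sh`, every variable above can be overridden from the command line. A minimal sketch of a full run, assuming the AISHELL data lists under `data/` have already been prepared:

``` sh
# Restrict the job to four NPUs and run training through decoding.
export ASCEND_RT_VISIBLE_DEVICES="0,1,2,3"
bash run_npu.sh --stage 0 --stop_stage 2
```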