
Commit a566154
Update F5-TTS-ONNX-Inference.py
DakeQQ authored Jan 9, 2025
1 parent 156cb1a commit a566154
Showing 1 changed file with 43 additions and 28 deletions.
71 changes: 43 additions & 28 deletions F5-TTS-ONNX-Inference.py
@@ -1,5 +1,4 @@
 import re
-import sys
 import time
 import jieba
 import numpy as np
@@ -9,31 +8,19 @@
 from pydub import AudioSegment
 from pypinyin import lazy_pinyin, Style
 
-F5_project_path = "/home/dake/Downloads/F5-TTS-main"       # The F5-TTS GitHub project download path. URL: https://github.com/SWivid/F5-TTS
-onnx_model_A = "/home/dake/Downloads/F5_Preprocess.onnx"   # The exported ONNX model path.
-onnx_model_B = "/home/dake/Downloads/F5_Transformer.onnx"  # The exported ONNX model path.
-onnx_model_C = "/home/dake/Downloads/F5_Decode.onnx"       # The exported ONNX model path.
+F5_project_path = "/home/DakeQQ/Downloads/F5-TTS-main"                   # The F5-TTS GitHub project download path. URL: https://github.com/SWivid/F5-TTS
+onnx_model_A = "/home/DakeQQ/Downloads/F5_Optimized/F5_Preprocess.ort"   # The exported ONNX model path.
+onnx_model_B = "/home/DakeQQ/Downloads/F5_Optimized/F5_Transformer.ort"  # The exported ONNX model path.
+onnx_model_C = "/home/DakeQQ/Downloads/F5_Optimized/F5_Decode.ort"       # The exported ONNX model path.
 
-reference_audio = "/home/dake/Downloads/F5-TTS-main/src/f5_tts/infer/examples/basic/basic_ref_zh.wav"  # The reference audio path.
-generated_audio = "/home/dake/Downloads/F5-TTS-main/src/f5_tts/infer/examples/basic/generated.wav"     # The generated audio path.
-ref_text = "对,这就是我,万人敬仰的太乙真人。"  # The ASR result of the reference audio.
-gen_text = "对,这就是我,万人敬仰的大可奇奇。"  # The target TTS text.
+reference_audio = "/home/DakeQQ/Downloads/F5-TTS-main/src/f5_tts/infer/examples/basic/basic_ref_zh.wav"  # The reference audio path.
+generated_audio = "/home/DakeQQ/Downloads/F5-TTS-main/src/f5_tts/infer/examples/basic/generated.wav"     # The generated audio path.
+ref_text = "对,这就是我,万人敬仰的太乙真人。"  # The ASR result of the reference audio.
+gen_text = "对,这就是我,万人敬仰的大可奇奇。"  # The target TTS text.
 
-
 ORT_Accelerate_Providers = []  # If you have accelerator devices, choose from: ['CUDAExecutionProvider', 'TensorrtExecutionProvider', 'CoreMLExecutionProvider', 'DmlExecutionProvider', 'OpenVINOExecutionProvider', 'ROCMExecutionProvider', 'MIGraphXExecutionProvider', 'AzureExecutionProvider'];
                                # else keep empty.
-provider_options = []
-# For OpenVINOExecutionProvider
-# provider_options =
-# [{
-#     'device_type': 'CPU',
-#     'precision': 'ACCURACY',
-#     'num_of_threads': 8,
-#     'num_streams': 1,
-#     'enable_opencl_throttling': True,
-#     'enable_qdq_optimizer': True
-# }]
 
 HOP_LENGTH = 256     # Number of samples between successive frames in the STFT
 SAMPLE_RATE = 24000  # The generated audio sample rate
 RANDOM_SEED = 9527   # Set seed to reproduce the generated audio
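
Note: the model paths above move from loose .onnx files to .ort files in an F5_Optimized directory, which points to models pre-converted into ONNX Runtime's serialized ORT format. A minimal conversion sketch, assuming stock onnxruntime tooling and the directory referenced above (the exact invocation is illustrative, not part of this commit):

    # Convert every .onnx model in the directory to an optimized .ort file.
    python -m onnxruntime.tools.convert_onnx_models_to_ort /home/DakeQQ/Downloads/F5_Optimized/
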
@@ -48,6 +35,34 @@
 vocab_size = len(vocab_char_map)
 
 
+if "OpenVINOExecutionProvider" in ORT_Accelerate_Providers:
+    provider_options = [
+        {
+            'device_type': 'CPU',
+            'precision': 'ACCURACY',
+            'num_of_threads': 8,
+            'num_streams': 1,
+            'enable_opencl_throttling': True,
+            'enable_qdq_optimizer': True
+        }
+    ]
+elif "CUDAExecutionProvider" in ORT_Accelerate_Providers:
+    provider_options = [
+        {
+            'device_id': 0,
+            'gpu_mem_limit': 8 * 1024 * 1024 * 1024,  # 8 GB
+            'arena_extend_strategy': 'kNextPowerOfTwo',
+            'cudnn_conv_algo_search': 'EXHAUSTIVE',
+            'cudnn_conv_use_max_workspace': '1',
+            'do_copy_in_default_stream': '1',
+            'cudnn_conv1d_pad_to_nc1d': '1',
+            'enable_cuda_graph': '0'  # Set to '0' to avoid potential errors when enabled.
+        }
+    ]
+else:
+    provider_options = None
+
+
 def is_chinese_char(c):
     cp = ord(c)
     return (
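
For reference, a minimal sketch of how a provider list and the provider_options built above pair up at session-creation time; when provider_options is given, onnxruntime expects one options dict per entry in providers. The model path and the CUDA choice are placeholders, not from this commit:

    import onnxruntime

    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']  # hypothetical: CUDA with CPU fallback
    provider_options = [
        {'device_id': 0, 'enable_cuda_graph': '0'},  # options for CUDAExecutionProvider
        {},                                          # defaults for CPUExecutionProvider
    ]
    session = onnxruntime.InferenceSession(
        "/path/to/model.ort",  # placeholder path
        providers=providers,
        provider_options=provider_options,
    )
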
@@ -111,18 +126,18 @@ def list_str_to_idx(
 # ONNX Runtime settings
 onnxruntime.set_seed(RANDOM_SEED)
 session_opts = onnxruntime.SessionOptions()
-session_opts.log_severity_level = 3         # error level, it a adjustable value.
-session_opts.inter_op_num_threads = 0       # Run different nodes with num_threads. Set 0 for auto.
-session_opts.intra_op_num_threads = 0       # Under the node, execute the operators with num_threads. Set 0 for auto.
+session_opts.log_severity_level = 3         # Error level; an adjustable value.
+session_opts.inter_op_num_threads = 0       # Run different nodes with num_threads. Set 0 for auto.
+session_opts.intra_op_num_threads = 0       # Under the node, execute the operators with num_threads. Set 0 for auto.
 session_opts.enable_cpu_mem_arena = True    # True for execution speed; False for less memory usage.
 session_opts.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
-session_opts.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+session_opts.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
 session_opts.add_session_config_entry("session.intra_op.allow_spinning", "1")
 session_opts.add_session_config_entry("session.inter_op.allow_spinning", "1")
 session_opts.add_session_config_entry("session.set_denormal_as_zero", "1")
 
 
-ort_session_A = onnxruntime.InferenceSession(onnx_model_A, sess_options=session_opts, providers=['CPUExecutionProvider'])
+ort_session_A = onnxruntime.InferenceSession(onnx_model_A, sess_options=session_opts, providers=['CPUExecutionProvider'], provider_options=None)
 model_type = ort_session_A._inputs_meta[0].type
 in_name_A = ort_session_A.get_inputs()
 out_name_A = ort_session_A.get_outputs()
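
The optimization level drops from ORT_ENABLE_ALL to ORT_ENABLE_BASIC here, consistent with graphs that were already optimized offline before being saved as .ort. A sketch of producing such a pre-optimized model with standard SessionOptions (paths are placeholders):

    # Serialize a fully optimized graph once; later sessions can then load it
    # with a lighter runtime optimization level.
    opts = onnxruntime.SessionOptions()
    opts.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    opts.optimized_model_filepath = "/path/to/F5_Preprocess_optimized.onnx"  # placeholder
    onnxruntime.InferenceSession("/path/to/F5_Preprocess.onnx", sess_options=opts,
                                 providers=['CPUExecutionProvider'])
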
@@ -142,7 +157,7 @@ def list_str_to_idx(
 # For DirectML + AMD GPU,
 # pip install onnxruntime-directml --upgrade
 # ort_session_B = onnxruntime.InferenceSession(onnx_model_B, sess_options=session_opts, providers=['DmlExecutionProvider'])
-
+print(f"\nUsable Providers: {ort_session_B.get_providers()}")
 in_name_B = ort_session_B.get_inputs()
 out_name_B = ort_session_B.get_outputs()
 in_name_B0 = in_name_B[0].name
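
The new print call reports which providers the session actually bound. A complementary pre-flight check with the standard onnxruntime API (a sketch, not part of this commit):

    # If a requested provider is missing from the installed build, onnxruntime
    # falls back to CPUExecutionProvider.
    available = onnxruntime.get_available_providers()
    for p in ORT_Accelerate_Providers:
        if p not in available:
            print(f"Warning: {p} is not available in this onnxruntime build.")
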
@@ -155,7 +170,7 @@ def list_str_to_idx(
 out_name_B0 = out_name_B[0].name
 
 
-ort_session_C = onnxruntime.InferenceSession(onnx_model_C, sess_options=session_opts, providers=['CPUExecutionProvider'])
+ort_session_C = onnxruntime.InferenceSession(onnx_model_C, sess_options=session_opts, providers=['CPUExecutionProvider'], provider_options=None)
 in_name_C = ort_session_C.get_inputs()
 out_name_C = ort_session_C.get_outputs()
 in_name_C0 = in_name_C[0].name
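
Outside the hunks shown, the pydub import at the top implies the generated waveform is written out along these lines (a sketch assuming `audio` is a hypothetical mono int16 numpy array at SAMPLE_RATE; the actual save logic lives in the unchanged remainder of the script):

    # Save int16 PCM as WAV with pydub; `audio` is a hypothetical
    # (num_samples,) int16 numpy array from the decode session.
    AudioSegment(
        audio.tobytes(),
        frame_rate=SAMPLE_RATE,
        sample_width=2,  # bytes per sample for int16
        channels=1,
    ).export(generated_audio, format="wav")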
