From 17ce5f8688a872fba7a5fd0b70fef794d5410d98 Mon Sep 17 00:00:00 2001 From: unknown Date: Sun, 6 Oct 2024 16:53:34 +0800 Subject: [PATCH] add more audio format support --- interface/interface_audio.py | 15 +++++++-------- interface/utils.py | 26 +++++++++++++++++++++----- requirements.txt | 3 ++- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/interface/interface_audio.py b/interface/interface_audio.py index e438ec9..d810cc8 100644 --- a/interface/interface_audio.py +++ b/interface/interface_audio.py @@ -32,7 +32,7 @@ face_pts_mean = adjust_verts(face_pts_mean) teeth_verts_ = render_verts_[478:, :3] head_joint = np.array([out_size * 0.5, out_size * 3 / 4, -0.]) -def run_audio(img_path, wavpath, output_path, template_path = None): +def run_audio(img_path, audio_path, output_path, template_path = None): img_primer_rgba, source_img, source_crop_pts, source_crop_pts_vt, source_crop_coords = face_process(img_path, out_size) # print(source_img.shape) @@ -57,7 +57,7 @@ def run_audio(img_path, wavpath, output_path, template_path = None): tensor_source_prompt = torch.from_numpy(source_prompt / 255.).float().permute(2, 0, 1).unsqueeze(0).to( device) - pts_audio_driving = audio_interface(wavpath) + pts_audio_driving = audio_interface(audio_path) frame_num = len(pts_audio_driving) import uuid task_id = str(uuid.uuid1()) @@ -183,27 +183,26 @@ def run_audio(img_path, wavpath, output_path, template_path = None): videoWriter.write(frame[..., ::-1]) videoWriter.release() val_video = output_path - wav_path = wavpath os.system( - "ffmpeg -i {} -i {} -c:v libx264 -pix_fmt yuv420p {}".format(save_path, wav_path, val_video)) + "ffmpeg -i {} -i {} -c:v libx264 -pix_fmt yuv420p {}".format(save_path, audio_path, val_video)) os.remove(save_path) cv2.destroyAllWindows() def main(): # 检查命令行参数的数量 if len(sys.argv) < 4 or len(sys.argv) > 5: - print("Usage: python interface_audio.py ") + print("Usage: python interface_audio.py ") sys.exit(1) # 参数数量不正确时退出程序 img_path = sys.argv[1] - wav_path = sys.argv[2] + audio_path = sys.argv[2] output_path = sys.argv[3] if len(sys.argv) == 4: template_path = None else: template_path = sys.argv[4] - print(f"img path is set to: {img_path}, wav path is set to: {wav_path}, output path is set to: {output_path}") - run_audio(img_path, wav_path, output_path, template_path) + print(f"img path is set to: {img_path}, wav path is set to: {audio_path}, output path is set to: {output_path}") + run_audio(img_path, audio_path, output_path, template_path) if __name__ == "__main__": main() diff --git a/interface/utils.py b/interface/utils.py index 203c5ab..a27ea4f 100644 --- a/interface/utils.py +++ b/interface/utils.py @@ -112,8 +112,23 @@ def rgb_face_process(img_primer_bgr, out_size): # mat_list, _, face_pts_mean_personal_primer = calc_face_mat(pts_driven, face_pts_mean) # return source_img, source_crop_pts +# 读取音频文件 +def load_audio(file_path): + import librosa + import numpy as np -def audio_interface(wavpath): + # 使用 librosa 读取音频文件 + # sr=None 表示不改变原始采样率,mono=True 表示转换为单声道 + y, sr = librosa.load(file_path, sr=None, mono=True) + + # 将采样率转换为 16kHz + y_16k = librosa.resample(y, orig_sr=sr, target_sr=16000) + + # 确保数据类型为 float32 + y_16k = y_16k.astype(np.float32) + return y_16k + +def audio_interface(audio_path): global Audio2FeatureModel,PcaModel if Audio2FeatureModel is None: current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -124,10 +139,11 @@ def audio_interface(wavpath): Audio2FeatureModel.load_state_dict(torch.load(ckpt_path)) Audio2FeatureModel = Audio2FeatureModel.to(device) Audio2FeatureModel.eval() - rate, wav = wavfile.read(wavpath, mmap=False) - - augmented_samples = wav - augmented_samples2 = augmented_samples.astype(np.float32, order='C') / 32768.0 + # rate, wav = wavfile.read(wavpath, mmap=False) + # + # augmented_samples = wav + # augmented_samples2 = augmented_samples.astype(np.float32, order='C') / 32768.0 + augmented_samples2 = load_audio(audio_path) opts = knf.FbankOptions() opts.frame_opts.dither = 0 diff --git a/requirements.txt b/requirements.txt index 34c40cf..d60a36e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ tqdm scikit-learn glfw PyOpenGL -onnxruntime \ No newline at end of file +onnxruntime +librosa \ No newline at end of file