enable fp16 for all inference scripts by default
Plachtaa committed Nov 28, 2024
1 parent c83ae7a commit ec0b27f
Showing 7 changed files with 124 additions and 87 deletions.
7 changes: 4 additions & 3 deletions README.md
@@ -45,6 +45,7 @@ python inference.py --source <source-wav>
--semi-tone-shift 0 # pitch shift in semitones for singing voice conversion
--checkpoint <path-to-checkpoint>
--config <path-to-config>
+--fp16 True
```
where:
- `source` is the path to the speech file to convert to the reference voice
@@ -58,11 +59,11 @@ where:
- `semi-tone-shift` is the pitch shift in semitones for singing voice conversion, default is 0
- `checkpoint` is the path to the model checkpoint if you have trained or fine-tuned your own model, leave blank to auto-download the default model from huggingface (`seed-uvit-whisper-small-wavenet` if `f0-condition` is `False`, else `seed-uvit-whisper-base`)
- `config` is the path to the model config if you have trained or fine-tuned your own model, leave blank to auto-download the default config from huggingface
-
+- `fp16` is the flag to use float16 inference, default is True

Voice Conversion Web UI:
```bash
-python app_vc.py --checkpoint <path-to-checkpoint> --config <path-to-config>
+python app_vc.py --checkpoint <path-to-checkpoint> --config <path-to-config> --fp16 True
```
- `checkpoint` is the path to the model checkpoint if you have trained or fine-tuned your own model, leave blank to auto-download the default model from huggingface (`seed-uvit-whisper-small-wavenet`)
- `config` is the path to the model config if you have trained or fine-tuned your own model, leave blank to auto-download the default config from huggingface
@@ -71,7 +72,7 @@ Then open the browser and go to `http://localhost:7860/` to use the web interface

Singing Voice Conversion Web UI:
```bash
-python app_svc.py --checkpoint <path-to-checkpoint> --config <path-to-config>
+python app_svc.py --checkpoint <path-to-checkpoint> --config <path-to-config> --fp16 True
```
- `checkpoint` is the path to the model checkpoint if you have trained or fine-tuned your own model, leave blank to auto-download the default model from huggingface (`seed-uvit-whisper-base`)
- `config` is the path to the model config if you have trained or fine-tuned your own model, leave blank to auto-download the default config from huggingface
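
All of the boolean flags above go through the `str2bool` converter this commit adds to modules/commons.py (last file in the diff), so `--fp16 True`, `--fp16 1`, or, for the web UIs, a bare `--fp16` all enable half precision. A minimal sketch of the parsing behavior; the parser line mirrors the diff, the invocations are illustrative:

```python
import argparse

def str2bool(v):
    # Same converter the commit adds to modules/commons.py.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser()
# Mirrors app_vc.py/app_svc.py: nargs="?" with const=True lets a bare
# `--fp16` mean True, and default=True keeps fp16 on when the flag is absent.
parser.add_argument("--fp16", type=str2bool, nargs="?", const=True, default=True)

print(parser.parse_args([]).fp16)                   # True  (default)
print(parser.parse_args(["--fp16"]).fp16)           # True  (bare flag -> const)
print(parser.parse_args(["--fp16", "false"]).fp16)  # False
```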
15 changes: 8 additions & 7 deletions app.py
@@ -269,13 +269,14 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
is_last_chunk = processed_frames + max_source_window >= cond.size(1)
cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
-        # Voice Conversion
-        vc_target = inference_module.cfm.inference(cat_condition,
-                                                   torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
-                                                   mel2, style2, None, diffusion_steps,
-                                                   inference_cfg_rate=inference_cfg_rate)
-        vc_target = vc_target[:, :, mel2.size(-1):]
-        vc_wave = bigvgan_fn(vc_target)[0]
+        with torch.autocast(device_type=device.type, dtype=torch.float16):
+            # Voice Conversion
+            vc_target = inference_module.cfm.inference(cat_condition,
+                                                       torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
+                                                       mel2, style2, None, diffusion_steps,
+                                                       inference_cfg_rate=inference_cfg_rate)
+            vc_target = vc_target[:, :, mel2.size(-1):]
+            vc_wave = bigvgan_fn(vc_target)[0]
if processed_frames == 0:
if is_last_chunk:
output_wave = vc_wave[0].cpu().numpy()
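
The block above is the heart of the commit: the CFM diffusion call and the BigVGAN vocoder now run inside a `torch.autocast` region, so supported ops execute in float16. A self-contained sketch of the same pattern, with two `nn.Linear` modules as stand-ins for `inference_module.cfm` and `bigvgan_fn` and made-up shapes:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-ins for the CFM model and the vocoder (shapes are illustrative only).
cfm = torch.nn.Linear(80, 80).to(device)
vocoder = torch.nn.Linear(80, 256).to(device)

mel = torch.randn(1, 200, 80, device=device)

# Same pattern as app.py: matmul-heavy ops inside the region run in half
# precision on CUDA; on CPU, older PyTorch builds may warn and disable it.
with torch.autocast(device_type=device.type, dtype=torch.float16):
    hidden = cfm(mel)
    wave = vocoder(hidden)

# float16 when the region was active, float32 if autocast was disabled.
print(wave.dtype)
```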
39 changes: 17 additions & 22 deletions app_svc.py
@@ -4,7 +4,7 @@
import torch
import torchaudio
import librosa
-from modules.commons import build_model, load_checkpoint, recursive_munch
+from modules.commons import build_model, load_checkpoint, recursive_munch, str2bool
import yaml
from hf_utils import load_custom_model_from_hf
import numpy as np
@@ -13,28 +13,21 @@
# Load model and configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-# Load additional modules
-from modules.campplus.DTDNN import CAMPPlus
-
-campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
-campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
-campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
-campplus_model.eval()
-campplus_model.to(device)
-
-from modules.audio import mel_spectrogram
-
+fp16 = False
def load_models(args):
-    global sr, hop_length
+    global sr, hop_length, fp16
+    fp16 = args.fp16
+    print(f"Using device: {device}")
+    print(f"Using fp16: {fp16}")
# f0 conditioned model
if args.checkpoint_path is None or args.checkpoint_path == "":
dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
"DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema_v2.pth",
"config_dit_mel_seed_uvit_whisper_base_f0_44k.yml")
else:
print(f"Using custom checkpoint: {args.checkpoint_path}")
dit_checkpoint_path = args.checkpoint_path
dit_config_path = args.config_path

config = yaml.safe_load(open(dit_config_path, "r"))
model_params = recursive_munch(config["model_params"])
model_params.dit_type = 'DiT'
@@ -336,13 +329,14 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
chunk_f0 = interpolated_shifted_f0_alt[:, processed_frames:processed_frames + max_source_window]
is_last_chunk = processed_frames + max_source_window >= cond.size(1)
cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
-        # Voice Conversion
-        vc_target = inference_module.cfm.inference(cat_condition,
-                                                   torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
-                                                   mel2, style2, None, diffusion_steps,
-                                                   inference_cfg_rate=inference_cfg_rate)
-        vc_target = vc_target[:, :, mel2.size(-1):]
-        vc_wave = vocoder_fn(vc_target).squeeze().cpu()
+        with torch.autocast(device_type=device.type, dtype=torch.float16 if fp16 else torch.float32):
+            # Voice Conversion
+            vc_target = inference_module.cfm.inference(cat_condition,
+                                                       torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
+                                                       mel2, style2, None, diffusion_steps,
+                                                       inference_cfg_rate=inference_cfg_rate)
+            vc_target = vc_target[:, :, mel2.size(-1):]
+            vc_wave = vocoder_fn(vc_target).squeeze().cpu()
if vc_wave.ndim == 1:
vc_wave = vc_wave.unsqueeze(0)
if processed_frames == 0:
@@ -437,6 +431,7 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint-path", type=str, help="Path to the checkpoint file", default=None)
parser.add_argument("--config-path", type=str, help="Path to the config file", default=None)
parser.add_argument("--share", type=bool, help="Whether to share url link", default=False)
parser.add_argument("--share", type=str2bool, nargs="?", const=True, default=False, help="Whether to share the app")
parser.add_argument("--fp16", type=str2bool, nargs="?", const=True, help="Whether to use fp16", default=True)
args = parser.parse_args()
main(args)
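
app_svc.py picks the autocast dtype with `torch.float16 if fp16 else torch.float32`. On recent PyTorch builds float32 is not an accepted autocast dtype, so the fp32 branch warns and disables the region, which still matches the intent (full-precision inference). A sketch of an equivalent, arguably more explicit spelling via the `enabled=` argument; this is a stylistic assumption, not a patch from the commit:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fp16 = True  # in the app this comes from args.fp16 via str2bool

model = torch.nn.Linear(80, 80).to(device)
x = torch.randn(1, 80, device=device)

# Keep the dtype fixed and switch the whole region on or off, instead of
# passing float32 as an autocast dtype.
with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=fp16):
    y = model(x)

print(y.dtype)
```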
27 changes: 16 additions & 11 deletions app_vc.py
@@ -4,7 +4,7 @@
import torch
import torchaudio
import librosa
-from modules.commons import build_model, load_checkpoint, recursive_munch
+from modules.commons import build_model, load_checkpoint, recursive_munch, str2bool
import yaml
from hf_utils import load_custom_model_from_hf
import numpy as np
@@ -13,9 +13,12 @@

# Load model and configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+fp16 = False
def load_models(args):
-    global sr, hop_length
+    global sr, hop_length, fp16
+    fp16 = args.fp16
+    print(f"Using device: {device}")
+    print(f"Using fp16: {fp16}")
if args.checkpoint_path is None or args.checkpoint_path == "":
dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
"DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
@@ -285,13 +288,14 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
is_last_chunk = processed_frames + max_source_window >= cond.size(1)
cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
-        # Voice Conversion
-        vc_target = inference_module.cfm.inference(cat_condition,
-                                                   torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
-                                                   mel2, style2, None, diffusion_steps,
-                                                   inference_cfg_rate=inference_cfg_rate)
-        vc_target = vc_target[:, :, mel2.size(-1):]
-        vc_wave = vocoder_fn(vc_target)[0]
+        with torch.autocast(device_type=device.type, dtype=torch.float16 if fp16 else torch.float32):
+            # Voice Conversion
+            vc_target = inference_module.cfm.inference(cat_condition,
+                                                       torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
+                                                       mel2, style2, None, diffusion_steps,
+                                                       inference_cfg_rate=inference_cfg_rate)
+            vc_target = vc_target[:, :, mel2.size(-1):]
+            vc_wave = vocoder_fn(vc_target)[0]
if vc_wave.ndim == 1:
vc_wave = vc_wave.unsqueeze(0)
if processed_frames == 0:
@@ -380,6 +384,7 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint-path", type=str, help="Path to the checkpoint file", default=None)
parser.add_argument("--config-path", type=str, help="Path to the config file", default=None)
parser.add_argument("--share", type=bool, help="Whether to share url link", default=False)
parser.add_argument("--share", type=str2bool, nargs="?", const=True, default=False, help="Whether to share the app")
parser.add_argument("--fp16", type=str2bool, nargs="?", const=True, help="Whether to use fp16", default=True)
args = parser.parse_args()
main(args)
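
Both web UIs convert long inputs in windows: each source chunk is concatenated behind the prompt condition, converted, and the prompt-length prefix is stripped (`vc_target[:, :, mel2.size(-1):]`) before the pieces are stitched together. A toy version of that loop with plain tensors; the window size, shapes, and the identity "conversion" are placeholders:

```python
import torch

prompt = torch.randn(1, 50, 80)   # stand-in for the prompt/reference condition
cond = torch.randn(1, 230, 80)    # stand-in for the full source condition
max_source_window = 100

outputs = []
processed_frames = 0
while processed_frames < cond.size(1):
    chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
    is_last_chunk = processed_frames + max_source_window >= cond.size(1)
    cat_condition = torch.cat([prompt, chunk_cond], dim=1)

    converted = cat_condition  # identity stands in for cfm.inference + vocoder

    # Drop the prompt prefix, as vc_target[:, :, mel2.size(-1):] does.
    outputs.append(converted[:, prompt.size(1):])
    processed_frames += chunk_cond.size(1)

result = torch.cat(outputs, dim=1)
print(result.shape, is_last_chunk)  # torch.Size([1, 230, 80]) True
```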
33 changes: 18 additions & 15 deletions inference.py
@@ -16,15 +16,17 @@

import torchaudio
import librosa
-import torchaudio.compliance.kaldi as kaldi
+from modules.commons import str2bool

from hf_utils import load_custom_model_from_hf


# Load model and configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+fp16 = False
def load_models(args):
+    global fp16
+    fp16 = args.fp16
if not args.f0_condition:
dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
"DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
@@ -304,17 +306,17 @@ def main(args):
cat_condition = torch.cat([prompt_condition, cond], dim=1)

time_vc_start = time.time()
-    vc_target = model.cfm.inference(
-        cat_condition,
-        torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
-        mel2, style2, None, diffusion_steps,
-        inference_cfg_rate=inference_cfg_rate)
-    vc_target = vc_target[:, :, mel2.size(-1):]
-
-
-    # Convert to waveform
-    vc_wave = vocoder_fn(vc_target).squeeze()  # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
-    vc_wave = vc_wave[None, :]
+    with torch.autocast(device_type=device.type, dtype=torch.float16 if fp16 else torch.float32):
+        vc_target = model.cfm.inference(
+            cat_condition,
+            torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
+            mel2, style2, None, diffusion_steps,
+            inference_cfg_rate=inference_cfg_rate)
+        vc_target = vc_target[:, :, mel2.size(-1):]
+
+        # Convert to waveform
+        vc_wave = vocoder_fn(vc_target).squeeze()  # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
+        vc_wave = vc_wave[None, :].float()
time_vc_end = time.time()
print(f"RTF: {(time_vc_end - time_vc_start) / vc_wave.size(-1) * sr}")

Expand All @@ -332,10 +334,11 @@ def main(args):
parser.add_argument("--diffusion-steps", type=int, default=30)
parser.add_argument("--length-adjust", type=float, default=1.0)
parser.add_argument("--inference-cfg-rate", type=float, default=0.7)
parser.add_argument("--f0-condition", type=bool, default=False)
parser.add_argument("--auto-f0-adjust", type=bool, default=True)
parser.add_argument("--f0-condition", type=str2bool, default=True)
parser.add_argument("--auto-f0-adjust", type=str2bool, default=True)
parser.add_argument("--semi-tone-shift", type=int, default=0)
parser.add_argument("--checkpoint-path", type=str, help="Path to the checkpoint file", default=None)
parser.add_argument("--config-path", type=str, help="Path to the config file", default=None)
parser.add_argument("--fp16", type=str2bool, default=True)
args = parser.parse_args()
main(args)
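
The RTF that inference.py prints is a real-time factor: seconds of compute per second of generated audio, so values below 1.0 mean faster than real time. (The new `.float()` cast above also matters: it brings a possibly-float16 waveform back to float32 before saving.) The same formula with stand-in numbers:

```python
import time
import torch

sr = 22050                        # sample rate; the real value comes from the model config
time_vc_start = time.time()
time.sleep(0.1)                   # pretend synthesis took 0.1 s
vc_wave = torch.zeros(1, sr * 5)  # pretend we generated 5 s of audio
time_vc_end = time.time()

# elapsed seconds / number of samples * samples per second
# = elapsed seconds / seconds of audio produced.
rtf = (time_vc_end - time_vc_start) / vc_wave.size(-1) * sr
print(f"RTF: {rtf:.3f}")  # ~0.020: 0.1 s of compute for 5 s of audio
```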
12 changes: 11 additions & 1 deletion modules/commons.py
@@ -5,7 +5,17 @@
from torch.nn import functional as F
from munch import Munch
import json
-
+import argparse
+
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ("yes", "true", "t", "y", "1"):
+        return True
+    elif v.lower() in ("no", "false", "f", "n", "0"):
+        return False
+    else:
+        raise argparse.ArgumentTypeError("Boolean value expected.")

class AttrDict(dict):
def __init__(self, *args, **kwargs):
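
The move from `type=bool` to `str2bool` in the argparse lines above is a quiet bugfix, not just a convenience: argparse passes the raw string to `bool()`, and every non-empty string, including "False", is truthy. A quick demonstration of the old behavior:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--share", type=bool, default=False)  # the old, broken spelling

# bool("False") is True, so the flag could never be turned off explicitly:
print(parser.parse_args(["--share", "False"]).share)  # True (!)
print(bool("False"))                                  # True: the root cause
```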
(The diff for the seventh changed file did not load and is not shown.)
