-
Notifications
You must be signed in to change notification settings - Fork 50
/
infer_mos4.py
132 lines (106 loc) · 4.33 KB
/
infer_mos4.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import shutil
import os
import argparse
import yaml
import torch
from audioldm_train.utilities.data.dataset_original_mos4 import AudioDataset as AudioDataset
from torch.utils.data import DataLoader
from pytorch_lightning import seed_everything
from audioldm_train.utilities.tools import get_restore_step
from audioldm_train.utilities.model_util import instantiate_from_config
from audioldm_train.utilities.tools import build_dataset_json_from_list
def _resolve_checkpoint(checkpoint_path, config_reload_from_ckpt):
    """Return the path of the checkpoint file to load for inference.

    If ``checkpoint_path`` contains any entries, the checkpoint chosen by
    ``get_restore_step`` inside it is used; otherwise ``config_reload_from_ckpt``
    is used.

    Args:
        checkpoint_path: The experiment's checkpoint directory (must exist).
        config_reload_from_ckpt: Fallback checkpoint path (from the config or
            the ``--reload_from_ckpt`` CLI flag), or ``None``.

    Returns:
        The checkpoint path to pass to ``torch.load``.

    Raises:
        FileNotFoundError: If neither source yields a checkpoint. Running
            inference with randomly initialised weights is never useful, so
            fail here with a clear message instead of calling
            ``torch.load(None)``.
    """
    # NOTE(review): a non-empty experiment checkpoint directory takes
    # precedence over an explicitly supplied reload_from_ckpt — confirm this
    # is intended.
    if os.listdir(checkpoint_path):
        print("Load checkpoint from path: %s" % checkpoint_path)
        # get_restore_step is project-local; presumably it returns the file
        # name of the checkpoint to restore plus its step count — verify.
        restore_step, _ = get_restore_step(checkpoint_path)
        resume_from_checkpoint = os.path.join(checkpoint_path, restore_step)
        print("Resume from checkpoint", resume_from_checkpoint)
        return resume_from_checkpoint
    if config_reload_from_ckpt is not None:
        print("Reload ckpt specified in the config file %s" % config_reload_from_ckpt)
        return config_reload_from_ckpt
    raise FileNotFoundError(
        "No checkpoint found in %s and no reload_from_ckpt was given "
        "(set it in the config or pass --reload_from_ckpt)." % checkpoint_path
    )


def infer(dataset_key, configs, config_yaml_path, exp_group_name, exp_name):
    """Generate samples for ``dataset_key`` with a trained latent diffusion model.

    Args:
        dataset_key: Passed to ``AudioDataset`` as ``dataset_json``; the
            caller builds it from the inference file list.
        configs: Parsed YAML config dict. Reads ``log_directory``, ``data``,
            ``model`` (including ``model.params.evaluation_params``) and the
            optional ``precision`` and ``reload_from_ckpt`` keys.
        config_yaml_path: Path of the config file; it is copied into the
            experiment directory ``log_directory/exp_group_name/exp_name``.
        exp_group_name: Experiment group (directory under ``log_directory``).
        exp_name: Experiment name (directory under the group).

    Raises:
        FileNotFoundError: If no checkpoint can be located (see
            ``_resolve_checkpoint``).
    """
    seed_everything(0)

    if "precision" in configs:
        torch.set_float32_matmul_precision(configs["precision"])

    log_path = configs["log_directory"]
    dataloader_add_ons = configs["data"].get("dataloader_add_ons", [])

    val_dataset = AudioDataset(
        configs, split="test", add_ons=dataloader_add_ons, dataset_json=dataset_key
    )
    val_loader = DataLoader(val_dataset, batch_size=1)

    config_reload_from_ckpt = configs.get("reload_from_ckpt")

    run_path = os.path.join(log_path, exp_group_name, exp_name)
    checkpoint_path = os.path.join(run_path, "checkpoints")
    # Creating checkpoints/ also creates run_path, so the copy below succeeds.
    os.makedirs(checkpoint_path, exist_ok=True)
    shutil.copy(config_yaml_path, run_path)

    resume_from_checkpoint = _resolve_checkpoint(checkpoint_path, config_reload_from_ckpt)

    latent_diffusion = instantiate_from_config(configs["model"])
    latent_diffusion.set_log_dir(log_path, exp_group_name, exp_name)

    eval_params = configs["model"]["params"]["evaluation_params"]
    guidance_scale = eval_params["unconditional_guidance_scale"]
    ddim_sampling_steps = eval_params["ddim_sampling_steps"]
    n_candidates_per_samples = eval_params["n_candidates_per_samples"]

    # Load tensors onto the CPU: the weights are copied into the model's
    # parameters (presumably still on CPU after instantiate_from_config) and
    # the whole model is moved to the GPU afterwards, so there is no reason to
    # materialise them on whichever device the checkpoint was saved from.
    checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")
    # strict=False: keys missing from / unexpected in the checkpoint are
    # ignored instead of raising.
    latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False)
    del checkpoint  # release the loaded copy of the weights before sampling

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.cuda()
    latent_diffusion.generate_sample(
        val_loader,
        unconditional_guidance_scale=guidance_scale,
        ddim_steps=ddim_sampling_steps,
        n_gen=n_candidates_per_samples,
    )
def main():
    """Command-line entry point: parse arguments, load the YAML config, run ``infer``.

    Raises:
        SystemExit: On invalid arguments (via argparse) or when CUDA is not
            available.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_yaml",
        type=str,
        # The config path is dereferenced unconditionally below, so omitting
        # it must be a clean argparse error rather than an AttributeError.
        required=True,
        help="path to config .yaml file",
    )
    parser.add_argument(
        "-l",
        "--list_inference",
        type=str,
        # NOTE(review): passed straight to build_dataset_json_from_list, which
        # presumably cannot handle None — consider making this required.
        required=False,
        help="The filelist that contain captions (and optionally filenames)",
    )
    parser.add_argument(
        "-reload_from_ckpt",
        "--reload_from_ckpt",
        type=str,
        required=False,
        default=None,
        help="the checkpoint path for the model",
    )
    args = parser.parse_args()

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise SystemExit("CUDA is not available")

    config_yaml_path = args.config_yaml
    dataset_key = build_dataset_json_from_list(args.list_inference)

    # Experiment name: the config file name up to its first dot. Take the
    # basename *before* splitting so dots in the directory part (e.g. a
    # leading "./" or "v1.5/") cannot truncate or empty the name.
    exp_name = os.path.basename(config_yaml_path).split(".")[0]
    # Experiment group: the name of the directory containing the config.
    exp_group_name = os.path.basename(os.path.dirname(config_yaml_path))

    with open(config_yaml_path, "r") as config_file:
        configs = yaml.load(config_file, Loader=yaml.FullLoader)
    # A checkpoint given on the command line overrides the config's value.
    if args.reload_from_ckpt is not None:
        configs["reload_from_ckpt"] = args.reload_from_ckpt

    infer(dataset_key, configs, config_yaml_path, exp_group_name, exp_name)


if __name__ == "__main__":
    main()