-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
100 lines (88 loc) · 3.28 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import typer
from typing import Optional
from torchvision.transforms import ToTensor, Compose, Lambda, RandomHorizontalFlip
from torch.utils.data import DataLoader
from dataset import CelebADataset, CIFAR10Dataset
from diffusion_utils import DiffusionUtils
from pl_utils import PLModel, ImageGenerationCallback
from lightning.pytorch.loggers import TensorBoardLogger
from config import Config
import lightning as L
def _to_symmetric_range(x):
    """Rescale a tensor from ``[0, 1]`` to ``[-1, 1]``.

    Defined at module level (instead of an inline ``lambda``) so the transform
    pipeline stays picklable. ``DataLoader`` workers started with the
    ``spawn`` method (the default on Windows/macOS) must pickle the dataset,
    including its transform, and lambdas cannot be pickled.
    """
    return (x * 2) - 1


def _build_transforms():
    """Return ``(train_transform, val_transform)``.

    Both convert images to tensors and rescale them to ``[-1, 1]``. Only the
    training transform applies ``RandomHorizontalFlip``, so validation data is
    not randomly augmented.
    """
    # ToTensor scales 8-bit images to [0, 1]; _to_symmetric_range maps that to [-1, 1].
    to_model_input = [ToTensor(), Lambda(_to_symmetric_range)]
    train_transform = Compose([RandomHorizontalFlip(), *to_model_input])
    val_transform = Compose(list(to_model_input))
    return train_transform, val_transform


def _build_datasets(config, train_transform, val_transform):
    """Construct ``(train_dataset, val_dataset)`` for ``config.dataset.name``.

    Raises:
        ValueError: If the dataset name is neither ``celeba`` nor ``cifar10``
            (compared case-insensitively).
    """
    dataset_name = config.dataset.name.lower()
    if dataset_name == 'celeba':
        return (
            CelebADataset(timesteps=config.model.timesteps, transform=train_transform, train=True),
            CelebADataset(timesteps=config.model.timesteps, transform=val_transform, train=False),
        )
    if dataset_name == 'cifar10':
        return (
            CIFAR10Dataset(transform=train_transform, train=True),
            CIFAR10Dataset(transform=val_transform, train=False),
        )
    raise ValueError(f"Unknown dataset: {config.dataset.name}")


def _build_dataloaders(config, train_dataset, val_dataset):
    """Wrap the datasets in ``DataLoader``s; the train loader shuffles, the val loader does not.

    Returns ``(train_dataloader, val_dataloader)``. ``val_dataloader`` is
    ``None`` when ``val_dataset`` is falsy (e.g. its ``__len__`` returns 0).
    """
    loader_kwargs = {
        'batch_size': config.training.batch_size,
        'num_workers': config.training.num_workers,
    }
    train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_dataloader = (
        DataLoader(val_dataset, shuffle=False, **loader_kwargs) if val_dataset else None
    )
    return train_dataloader, val_dataloader


def main(
    config_path: str,
    continue_training: bool = False,
    checkpoint_path: Optional[str] = None,
):
    """Train a diffusion model configured by a YAML file.

    Args:
        config_path: Path to the YAML config, loaded with ``Config.from_yaml``.
        continue_training: Resume training from ``checkpoint_path``.
        checkpoint_path: Checkpoint passed to ``Trainer.fit(ckpt_path=...)``.
            Only used when ``continue_training`` is set.

    Raises:
        ValueError: If ``continue_training`` is set without a
            ``checkpoint_path``, or if the configured dataset is unknown.
    """
    # Validate CLI arguments before any expensive setup. Previously a missing
    # checkpoint silently started a fresh run instead of resuming.
    if continue_training and not checkpoint_path:
        raise ValueError("continue_training requires checkpoint_path to be set")

    config = Config.from_yaml(config_path)

    # Create the results folder, including any missing parent directories.
    config.training.results_folder.mkdir(parents=True, exist_ok=True)

    train_transform, val_transform = _build_transforms()
    train_dataset, val_dataset = _build_datasets(config, train_transform, val_transform)
    train_dataloader, val_dataloader = _build_dataloaders(config, train_dataset, val_dataset)

    # The same DiffusionUtils instance is shared by the model and the sampling callback.
    diffusion_utils = DiffusionUtils(config.model.timesteps)
    model = PLModel(
        image_size=config.model.image_size,
        first_layer_channels=config.model.first_layer_channels,
        channels_multiplier=config.model.channels_multiplier,
        num_res_blocks=config.model.num_res_blocks,
        attn_resolutions=config.model.attn_resolutions,
        dropout=config.model.dropout,
        learning_rate=config.training.learning_rate,
        warmup_steps=config.training.warmup_steps,
        diffusion_utils=diffusion_utils,
    )

    results_dir = str(config.training.results_folder)
    logger = TensorBoardLogger(save_dir=results_dir)
    image_generation_callback = ImageGenerationCallback(
        config.training.view_sample_size,
        config.model.image_size,
        config.training.log_interval,
        diffusion_utils,
    )

    trainer = L.Trainer(
        max_steps=config.training.max_steps,
        accelerator=config.training.accelerator,
        devices=config.training.num_gpus,
        default_root_dir=results_dir,
        log_every_n_steps=config.training.log_interval,
        logger=logger,
        callbacks=[image_generation_callback],
        gradient_clip_val=config.training.gradient_clip_val,
    )

    trainer.fit(
        model,
        train_dataloader,
        val_dataloader,
        ckpt_path=checkpoint_path if continue_training else None,
    )


if __name__ == "__main__":
    # Typer builds the CLI from main's signature and docstring.
    typer.run(main)