feat: ✨ PyTorch Recognition Model Multi-GPU support #1164

Merged: 10 commits, Mar 31, 2023
Update
odulcy-mindee committed Mar 9, 2023
commit f4965929f077a911fece862098d17e781f816943
93 changes: 8 additions & 85 deletions references/recognition/train_pytorch_ddp.py
@@ -35,88 +35,13 @@
from torch.nn.parallel import DistributedDataParallel as DDP


def record_lr(
model: torch.nn.Module,
rank: int,
train_loader: DataLoader,
batch_transforms,
optimizer,
start_lr: float = 1e-7,
end_lr: float = 1,
num_it: int = 100,
amp: bool = False,
):
"""Gridsearch the optimal learning rate for the training.
Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py
"""

if num_it > len(train_loader):
raise ValueError("the value of `num_it` needs to be lower than the number of available batches")

model = model.train()
# Update param groups & LR
optimizer.defaults["lr"] = start_lr
for pgroup in optimizer.param_groups:
pgroup["lr"] = start_lr

gamma = (end_lr / start_lr) ** (1 / (num_it - 1))
scheduler = MultiplicativeLR(optimizer, lambda step: gamma)

lr_recorder = [start_lr * gamma**idx for idx in range(num_it)]
loss_recorder = []

if amp:
scaler = torch.cuda.amp.GradScaler()

for batch_idx, (images, targets) in enumerate(train_loader):
if torch.cuda.is_available():
images = images.to(rank)

images = batch_transforms(images)

# Forward, Backward & update
optimizer.zero_grad()
if amp:
with torch.cuda.amp.autocast():
train_loss = model(images, targets)["loss"]
scaler.scale(train_loss).backward()
# Gradient clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
# Update the params
scaler.step(optimizer)
scaler.update()
else:
train_loss = model(images, targets)["loss"]
train_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
optimizer.step()
# Update LR
scheduler.step()

# Record
if not torch.isfinite(train_loss):
if batch_idx == 0:
raise ValueError("loss value is NaN or inf.")
else:
break
loss_recorder.append(train_loss.item())
# Stop after the number of iterations
if batch_idx + 1 == num_it:
break

return lr_recorder[: len(loss_recorder)], loss_recorder


def fit_one_epoch(model, rank, train_loader, batch_transforms, optimizer, scheduler, mb, amp=False):
if amp:
scaler = torch.cuda.amp.GradScaler()

model.train()
# Iterate over the batches of the dataset
for images, targets in progress_bar(train_loader, parent=mb):
#if torch.cuda.is_available():
# images = images.cuda()
images = images.to(rank)
images = batch_transforms(images)
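
In the hunk above, the old `images.cuda()` calls become `images.to(rank)`, so each spawned worker keeps its batch on the GPU matching its rank. For context, a minimal sketch of the matching model-side setup under DDP; the wrapping happens outside this hunk, so the helper name and exact calls here are assumptions, not the commit's code:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model_for_rank(model: torch.nn.Module, rank: int) -> torch.nn.Module:
    # Bind this process to its own GPU, mirroring the images.to(rank) calls above.
    torch.cuda.set_device(rank)
    model = model.to(rank)
    # DDP keeps the per-rank replicas in sync by all-reducing gradients during backward().
    return DDP(model, device_ids=[rank])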

@@ -153,8 +78,7 @@ def evaluate(model, rank, val_loader, batch_transforms, val_metric, amp=False):
# Validation loop
val_loss, batch_cnt = 0, 0
for images, targets in val_loader:
if torch.cuda.is_available():
images = images.cuda()
images = images.to(rank)
images = batch_transforms(images)
if amp:
with torch.cuda.amp.autocast():
@@ -350,11 +274,6 @@ def main(rank: int, world_size: int, args):
eps=1e-6,
weight_decay=args.weight_decay,
)
# LR Finder
if args.find_lr:
lrs, losses = record_lr(model, rank, train_loader, batch_transforms, optimizer, amp=args.amp)
plot_recorder(lrs, losses)
return
# Scheduler
if args.sched == "cosine":
scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
@@ -398,7 +317,8 @@ def main(rank: int, world_size: int, args):
val_loss, exact_match, partial_match = evaluate(model, rank, val_loader, batch_transforms, val_metric, amp=args.amp)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
torch.save(model.state_dict(), f"./{exp_name}.pt")
# FIXME
#torch.save(model.state_dict(), f"./{exp_name}.pt")
min_loss = val_loss
mb.write(
f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
@@ -473,7 +393,6 @@ def parse_args():
)
parser.add_argument("--sched", type=str, default="cosine", help="scheduler to use")
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR")
args = parser.parse_args()

return args
@@ -482,7 +401,11 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
nprocs = 2
# nprocs = world size here
# Environment variables which need to be
# set when using c10d's default "env"
# initialization mode.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
mp.spawn(main,
args=(nprocs, args),
nprocs=nprocs,
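
The MASTER_ADDR/MASTER_PORT variables set above are what c10d's default "env://" rendezvous reads when the process group is initialized. The init_process_group call itself is not visible in this hunk, so the self-contained sketch below (the worker name, nccl backend, and join=True are assumptions) only illustrates how the spawned workers would typically consume these variables:

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    # c10d's default "env://" rendezvous reads MASTER_ADDR and MASTER_PORT
    # (plus the rank and world size passed explicitly here).
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # ... per-rank training would go here ...
    dist.destroy_process_group()

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    world_size = 2
    # mp.spawn passes the process index as the first argument of worker.
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)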