Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ✨ PyTorch Recognition Model Multi-GPU support #1164

Merged
merged 10 commits into from
Mar 31, 2023
Prev Previous commit
Next Next commit
Update
  • Loading branch information
odulcy-mindee committed Mar 9, 2023
commit 422d6a687ff07e9c41fa151b62b491e5ae3bffae
44 changes: 18 additions & 26 deletions references/recognition/train_pytorch_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@
from torch.nn.parallel import DistributedDataParallel as DDP


def fit_one_epoch(model, rank, train_loader, batch_transforms, optimizer, scheduler, mb, amp=False):
def fit_one_epoch(model, device, train_loader, batch_transforms, optimizer, scheduler, mb, amp=False):
if amp:
scaler = torch.cuda.amp.GradScaler()

model.train()
# Iterate over the batches of the dataset
for images, targets in progress_bar(train_loader, parent=mb):
images = images.to(rank)
images = images.to(device)
images = batch_transforms(images)

train_loss = model(images, targets)["loss"]
Expand Down Expand Up @@ -70,15 +70,15 @@ def fit_one_epoch(model, rank, train_loader, batch_transforms, optimizer, schedu


@torch.no_grad()
def evaluate(model, rank, val_loader, batch_transforms, val_metric, amp=False):
def evaluate(model, device, val_loader, batch_transforms, val_metric, amp=False):
# Model in eval mode
model.eval()
# Reset val metric
val_metric.reset()
# Validation loop
val_loss, batch_cnt = 0, 0
for images, targets in val_loader:
images = images.to(rank)
images = images.to(device)
images = batch_transforms(images)
if amp:
with torch.cuda.amp.autocast():
Expand Down Expand Up @@ -171,33 +171,23 @@ def main(rank: int, world_size: int, args):
model.load_state_dict(checkpoint)

# GPU
#if isinstance(args.device, int):
# if not torch.cuda.is_available():
# raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
# if args.device >= torch.cuda.device_count():
# raise ValueError("Invalid device index")
## Silent default switch to GPU if available
#elif torch.cuda.is_available():
# args.device = 0
#else:
# logging.warning("No accessible GPU, targe device set to CPU.")
#if torch.cuda.is_available():
# torch.cuda.set_device(args.device)
# model = model.cuda()
if not torch.cuda.is_available():
raise AssertionError("PyTorch cannot access your GPU. Please investigate!")

# create default process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
device = torch.cuda.device(args.devices[rank])
dist.init_process_group(args.backend, rank=rank, world_size=world_size)
# create local model
model = model.to(rank)
model = model.to(device)
# construct DDP model
model = DDP(model, device_ids=[rank])
model = DDP(model, device_ids=[device])

# Metrics
val_metric = TextMatch()
aminemindee marked this conversation as resolved.
Show resolved Hide resolved

if args.test_only:
print("Running evaluation")
val_loss, exact_match, partial_match = evaluate(model, rank, val_loader, batch_transforms, val_metric, amp=args.amp)
val_loss, exact_match, partial_match = evaluate(model, device, val_loader, batch_transforms, val_metric, amp=args.amp)
print(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")
return

Expand Down Expand Up @@ -311,10 +301,10 @@ def main(rank: int, world_size: int, args):
# Training loop
mb = master_bar(range(args.epochs))
for epoch in mb:
fit_one_epoch(model, rank, train_loader, batch_transforms, optimizer, scheduler, mb, amp=args.amp)
fit_one_epoch(model, device, train_loader, batch_transforms, optimizer, scheduler, mb, amp=args.amp)

# Validation loop at the end of each epoch
val_loss, exact_match, partial_match = evaluate(model, rank, val_loader, batch_transforms, val_metric, amp=args.amp)
val_loss, exact_match, partial_match = evaluate(model, device, val_loader, batch_transforms, val_metric, amp=args.amp)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
# FIXME
Expand Down Expand Up @@ -372,7 +362,8 @@ def parse_args():
parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size for training")
parser.add_argument("--device", default=None, type=int, help="device")
parser.add_argument("--backend", default='gloo', type=str, help="Backend to use for `torch.distributed.init_process_group`")
parser.add_argument("--devices", default=None, nargs="+", type=int, help="GPU devices to use for training")
parser.add_argument("--input_size", type=int, default=32, help="input size H for the model, W = 4*H")
parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam)")
parser.add_argument("--wd", "--weight-decay", default=0, type=float, help="weight decay", dest="weight_decay")
Expand Down Expand Up @@ -400,7 +391,7 @@ def parse_args():

if __name__ == "__main__":
args = parse_args()
nprocs = 2
nprocs = len(args.devices)
# Environment variables which need to be
# set when using c10d's default "env"
# initialization mode.
Expand All @@ -409,4 +400,5 @@ def parse_args():
mp.spawn(main,
args=(nprocs, args),
nprocs=nprocs,
join=True)
join=True
)