"""Distributed (one process per GPU) CIFAR-10 training with NVIDIA apex AMP.

Launch one process per GPU, e.g. with ``torch.distributed.launch`` (passes
``--local_rank`` / ``--local-rank``) or ``torchrun`` (exports ``LOCAL_RANK``).
"""
# Step 1: import libraries (the apex imports below provide AMP and DDP).
import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

from apex import amp
from apex.parallel import DistributedDataParallel

# Lower-case callables exported by torchvision.models (e.g. 'resnet18').
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--data',
                    metavar='DIR',
                    default='/raid/xianchaow/pytorch-distributed/data/cifar10',
                    help='path to dataset')
parser.add_argument('-a',
                    '--arch',
                    metavar='ARCH',
                    default='resnet18',
                    choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) +
                    ' (default: resnet18)')
parser.add_argument('-j',
                    '--workers',
                    default=4,
                    type=int,
                    metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs',
                    default=10,
                    type=int,
                    metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch',
                    default=0,
                    type=int,
                    metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b',
                    '--batch-size',
                    default=3200,
                    type=int,
                    metavar='N',
                    help='mini-batch size (default: 3200), this is the total '
                    'batch size of all GPUs on the current node when '
                    'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr',
                    '--learning-rate',
                    default=0.1,
                    type=float,
                    metavar='LR',
                    help='initial learning rate',
                    dest='lr')
parser.add_argument('--momentum',
                    default=0.9,
                    type=float,
                    metavar='M',
                    help='momentum')
# Step 2: the rank of this process on its node. torch.distributed.launch passes
# --local_rank (--local-rank since PyTorch 2.0); torchrun sets LOCAL_RANK instead.
parser.add_argument('--local_rank',
                    '--local-rank',
                    dest='local_rank',
                    default=-1,
                    type=int,
                    help='local rank of this process on its node '
                    '(set by the launcher)')
parser.add_argument('--wd',
                    '--weight-decay',
                    default=1e-4,
                    type=float,
                    metavar='W',
                    help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p',
                    '--print-freq',
                    default=10,
                    type=int,
                    metavar='N',
                    help='print frequency (default: 10)')
parser.add_argument('-e',
                    '--evaluate',
                    dest='evaluate',
                    action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained',
                    dest='pretrained',
                    action='store_true',
                    help='use pre-trained model')
parser.add_argument('--seed',
                    default=None,
                    type=int,
                    help='seed for initializing training. ')


def reduce_mean(tensor, nprocs):
    """Return the mean of ``tensor`` over all processes in the default group.

    Step 3: all-reduce (sum) a copy of ``tensor`` with torch.distributed and
    divide by ``nprocs``; used to average loss and accuracy for logging.

    Args:
        tensor: tensor on this process's GPU; it is not modified.
        nprocs: number of processes contributing to the sum (the world size).

    Returns:
        A new tensor, detached from autograd, holding the averaged value.
    """
    # Detach so the in-place collective never touches a tensor that is part
    # of the autograd graph (e.g. the training loss before backward()).
    rt = tensor.detach().clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= nprocs
    return rt


class data_prefetcher():
    """Copy the next batch to the GPU on a side CUDA stream while compute runs.

    TODO: do not use this class with the loaders built in ``main_worker``.
    It normalizes with mean/std scaled by 255, i.e. it expects raw 0-255
    batches, while those loaders already apply ``ToTensor()`` + ``Normalize``,
    so inputs would be normalized twice.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        # CIFAR-10 per-channel statistics (mean=[0.4915, 0.4823, 0.4468],
        # std=[0.2470, 0.2435, 0.2616]) scaled by 255 and shaped (1, 3, 1, 1)
        # so they broadcast over channel-first image batches.
        self.mean = torch.tensor([0.4915 * 255, 0.4823 * 255,
                                  0.4468 * 255]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.2470 * 255, 0.2435 * 255,
                                 0.2616 * 255]).cuda().view(1, 3, 1, 1)
        # With Amp, it isn't necessary to manually convert data to half.
        self.preload()

    def preload(self):
        """Fetch the next batch and start its async host-to-device copy.

        Sets ``next_input``/``next_target`` to ``None`` once the loader is
        exhausted.
        """
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        # Copies and preprocessing are queued on the side stream; next() makes
        # the consumer stream wait on it before handing the tensors out.
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        """Return the prefetched ``(input, target)`` pair and queue the next one.

        Returns ``(None, None)`` when the loader is exhausted.
        """
        torch.cuda.current_stream().wait_stream(self.stream)
        batch_input = self.next_input
        batch_target = self.next_target
        # record_stream marks the tensors as used by the current stream so the
        # caching allocator does not reuse their memory too early.
        if batch_input is not None:
            batch_input.record_stream(torch.cuda.current_stream())
        if batch_target is not None:
            batch_target.record_stream(torch.cuda.current_stream())
        self.preload()
        return batch_input, batch_target


def main():
    """Parse arguments, apply optional seeding and start this process's worker."""
    args = parser.parse_args()
    if args.local_rank < 0:
        # torchrun does not pass --local_rank; it exports LOCAL_RANK instead.
        env_rank = os.environ.get('LOCAL_RANK')
        if env_rank is None:
            parser.error('no local rank given: pass --local_rank or launch '
                         'with torch.distributed.launch / torchrun')
        args.local_rank = int(env_rank)
    # Processes on this node; --batch-size is split across them. Launchers that
    # export LOCAL_WORLD_SIZE are honoured, otherwise one process per visible GPU
    # is assumed.
    args.nprocs = int(
        os.environ.get('LOCAL_WORLD_SIZE', torch.cuda.device_count()))

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # Step 4: run the worker for this process's local rank.
    main_worker(args.local_rank, args.nprocs, args)


def main_worker(local_rank, nprocs, args):
    """Set up distributed training on GPU ``local_rank`` and run all epochs.

    Args:
        local_rank: index of this process's GPU on its node.
        nprocs: number of processes on this node; ``args.batch_size`` is
            divided by it to get the per-process batch size.
        args: parsed command-line arguments (``args.batch_size`` is mutated).
    """
    best_acc1 = .0

    # Bind this process to its GPU *before* creating the NCCL process group so
    # that every collective (including the barriers below) uses this device.
    torch.cuda.set_device(local_rank)
    # Step 5: initialise the process group with the NCCL backend.
    dist.init_process_group(backend='nccl')

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    # Step 6: move the model onto this process's (current) GPU.
    model.cuda()
    # When using a single GPU per process and per
    # DistributedDataParallel, we need to divide the batch size
    # ourselves based on the total number of GPUs we have
    args.batch_size = int(args.batch_size / nprocs)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Step 7: let apex AMP wrap the model and optimizer for mixed precision.
    model, optimizer = amp.initialize(model, optimizer)
    # Step 8: wrap the model in apex's DistributedDataParallel.
    model = DistributedDataParallel(model)

    # cuDNN autotuning may pick nondeterministic kernels, so keep it off when
    # --seed requested deterministic mode.
    cudnn.benchmark = not cudnn.deterministic

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.4915, 0.4823, 0.4468],
                                     std=[0.2470, 0.2435, 0.2616])
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    # Only local rank 0 downloads; the other ranks wait at the barrier and then
    # load the files it wrote, so concurrent downloads cannot clobber each other.
    # Each rank calls dist.barrier() exactly once.
    if local_rank != 0:
        dist.barrier()
    train_dataset = datasets.CIFAR10(traindir,
                                     train=True,
                                     download=True,
                                     transform=transform)
    val_dataset = datasets.CIFAR10(valdir,
                                   train=False,
                                   download=True,
                                   transform=transform)
    if local_rank == 0:
        dist.barrier()

    # Step 9: distributed samplers give each process its own shard.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),  # the sampler does the shuffling
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler)

    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, local_rank, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        # Step 10: reseed the samplers so data is reshuffled every epoch.
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)

        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, local_rank,
              args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, local_rank, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if args.local_rank == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.module.state_dict(),
                    'best_acc1': best_acc1,
                }, is_best)


def train(train_loader, model, criterion, optimizer, epoch, local_rank, args):
    """Train ``model`` for one epoch, logging metrics averaged over all ranks."""
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    world_size = dist.get_world_size()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # Step 11: move the mini-batch onto this process's GPU.
        images = images.cuda(local_rank, non_blocking=True)
        target = target.cuda(local_rank, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # Steps 12/13: average loss and accuracy over every process. The
        # all_reduce inside reduce_mean is a collective all ranks must join,
        # so no separate barrier is needed beforehand.
        acc1, acc5 = accuracy(output, target, local_rank, topk=(1, 5))
        reduced_loss = reduce_mean(loss, world_size)
        reduced_acc1 = reduce_mean(acc1, world_size)
        reduced_acc5 = reduce_mean(acc5, world_size)

        losses.update(reduced_loss.item(), images.size(0))
        top1.update(reduced_acc1.item(), images.size(0))
        top5.update(reduced_acc5.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        # Step 14: backpropagate the AMP-scaled loss (mixed precision).
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


def validate(val_loader, model, criterion, local_rank, args):
    """Evaluate ``model`` on ``val_loader``.

    Returns:
        Top-1 accuracy in percent, averaged across all processes.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5],
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    world_size = dist.get_world_size()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.cuda(local_rank, non_blocking=True)
            target = target.cuda(local_rank, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss, averaged over all processes
            acc1, acc5 = accuracy(output, target, local_rank, topk=(1, 5))
            reduced_loss = reduce_mean(loss, world_size)
            reduced_acc1 = reduce_mean(acc1, world_size)
            reduced_acc5 = reduce_mean(acc5, world_size)

            losses.update(reduced_loss.item(), images.size(0))
            top1.update(reduced_acc1.item(), images.size(0))
            top5.update(reduced_acc5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1,
                                                                    top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Save ``state`` to '<arch>.<filename>' and copy it to
    '<arch>.model_best.pth.tar' when ``is_best``."""
    filename = state['arch'] + '.' + filename
    torch.save(state, filename)
    if is_best:
        filename2 = state['arch'] + '.model_best.pth.tar'
        shutil.copyfile(filename, filename2)


def save_checkpoint1(state, is_best, filename='checkpoint.pth.tar'):
    """Save ``state`` to ``filename``; copy it to 'model_best.pth.tar' when ``is_best``."""
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. the batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Prints '<prefix>[batch/total]' followed by every meter, tab-separated."""
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the batch index to the width of the total so columns line up.
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1**(epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, local_rank, topk=(1, )):
    """Computes the accuracy over the k top predictions for the specified values of k

    Returns:
        One 0-d tensor per entry of ``topk`` holding the accuracy in percent,
        placed on GPU ``local_rank`` so it can be all-reduced directly.
    """
    return [
        acc.reshape(()).cuda(local_rank)
        for acc in accuracy2(output, target, topk)
    ]


def accuracy2(output, target, topk=(1, )):
    """Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output: logits of shape (batch, num_classes).
        target: class indices of shape (batch,).
        topk: values of k to evaluate; k larger than num_classes counts every
            class as predicted.

    Returns:
        One 1-element tensor per entry of ``topk`` holding the accuracy in percent.
    """
    with torch.no_grad():
        # topk() cannot ask for more entries than there are classes.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape, not view: correct[:k] is non-contiguous after the
            # transpose above, which view() rejects.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == '__main__':
    main()