from __future__ import print_function from data.dataset import Dataset from torch.utils import data from models.focal_loss import FocalLoss from models.metrics import * from utils.visualizer import Visualizer import time from config.config import Config from torch.nn import DataParallel from torch.optim.lr_scheduler import StepLR from val import * torch.manual_seed(3407) def save_model(model, save_path, name, iter_cnt): os.makedirs(save_path, exist_ok=True) save_name = os.path.join(save_path, name + '_' + str(iter_cnt) + '.pth') torch.save(model.state_dict(), save_name) return save_name if __name__ == '__main__': opt = Config() if opt.display: visualizer = Visualizer() device = torch.device("cuda") train_dataset = Dataset(opt.train_root, opt.train_list, phase='train', input_shape=opt.input_shape, mp3aug_ratio=opt.mp3aug_ratio, npy_aug=opt.npy_aug) trainloader = data.DataLoader(train_dataset, batch_size=opt.train_batch_size, shuffle=True, num_workers=opt.num_workers) val_path = os.path.join(opt.train_root, opt.val_list) print('{} train iters per epoch:'.format(len(trainloader))) if opt.loss == 'focal_loss': criterion = FocalLoss(gamma=2) else: criterion = torch.nn.CrossEntropyLoss() if opt.backbone == 'resnet18': model = resnet_face18(use_se=opt.use_se) elif opt.backbone == 'resnet34': model = resnet34() elif opt.backbone == 'resnet50': model = resnet50() if opt.metric == 'add_margin': metric_fc = AddMarginProduct(512, opt.num_classes, s=30, m=0.35) elif opt.metric == 'arc_margin': metric_fc = ArcMarginProduct(512, opt.num_classes, s=30, m=0.5, easy_margin=opt.easy_margin) elif opt.metric == 'sphere': metric_fc = SphereProduct(512, opt.num_classes, m=4) else: metric_fc = nn.Linear(512, opt.num_classes) model.to(device) model = DataParallel(model) metric_fc.to(device) metric_fc = DataParallel(metric_fc) if opt.optimizer == 'sgd': optimizer = torch.optim.SGD([{'params': model.parameters()}, {'params': metric_fc.parameters()}], lr=opt.lr, weight_decay=opt.weight_decay) else: optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': metric_fc.parameters()}], lr=opt.lr, weight_decay=opt.weight_decay) scheduler = StepLR(optimizer, step_size=opt.lr_step, gamma=0.1) start = time.time() mrr_best = 0 for i in range(1, opt.max_epoch + 1): scheduler.step() model.train() for ii, data in enumerate(trainloader): data_input, label = data data_input = data_input.to(device) label = label.to(device).long() feature = model(data_input) output = metric_fc(feature, label) loss = criterion(output, label) optimizer.zero_grad() loss.backward() optimizer.step() iters = i * len(trainloader) + ii if iters % opt.print_freq == 0: output = output.data.cpu().numpy() output = np.argmax(output, axis=1) label = label.data.cpu().numpy() acc = np.mean((output == label).astype(int)) speed = opt.print_freq / (time.time() - start) time_str = time.asctime(time.localtime(time.time())) print('{} train epoch {} iter {} {} iters/s loss {} acc {}'.format(time_str, i, ii, speed, loss.item(), acc)) start = time.time() if i % opt.save_interval == 0 or i == opt.max_epoch: print('calculating mrr .....') save_model(model, opt.checkpoints_path, opt.backbone, 'latest') model.eval() data_val = read_val(opt.val_list, opt.train_root) mrr = mrr_score(model, data_val, opt.input_shape) print(f'epoch {i}: MRR= {mrr}') if mrr > mrr_best: mrr_best = mrr save_model(model, opt.checkpoints_path, opt.backbone, 'best')