From cb464f06a0be4e8981423057b5c3de408894845c Mon Sep 17 00:00:00 2001 From: iser97 <1156653418@qq.com> Date: Thu, 21 Oct 2021 18:12:43 +0800 Subject: [PATCH] master --- main.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index ecec4fc..109526f 100644 --- a/main.py +++ b/main.py @@ -5,18 +5,23 @@ mpl.use("agg") import numpy as np -import sklearn as sk +import sklearn.metrics as skm import matplotlib.pyplot as plt import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader -from transformers.utils import logging +import logging from scripts.model.transformer_single_layer import my_transformer from scripts.data.dataset_mnist_8_8 import DatasetMnist -logger = logging.get_logger(__name__) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) torch.manual_seed(71) def make_train_step(model, loss_fn, optimizer): @@ -51,10 +56,10 @@ def test_step(model, data_loader): pred = pred.cpu().tolist() preds = preds + pred labels = labels + y_batch.cpu().tolist() - confusion_matrix = sk.metrics.confusion_matrix(labels, preds) - acc = sk.metrics.accuracy_score(labels, preds) - recall = sk.metrics.recall_score(labels, preds, average='macro') - f1 = sk.metrics.f1_score(np.array(labels), np.array(preds), average='macro') + confusion_matrix = skm.confusion_matrix(labels, preds) + acc = skm.accuracy_score(labels, preds) + recall = skm.recall_score(labels, preds, average='macro') + f1 = skm.f1_score(np.array(labels), np.array(preds), average='macro') logger.info(f"confusion_matrix = {confusion_matrix}") logger.info(f"ACC = {acc}") @@ -101,9 +106,10 @@ def main(): if __name__ == '__main__': + device = 'cuda' if torch.cuda.is_available() else 'cpu' ### training parameters - batch_size = 10 + batch_size = 256 lr = 1e-4 mom = 0.91 n_epochs = 2000 @@ -115,7 +121,6 @@ def main(): n_heads = 4 num_classes = 10 - main()