Skip to content

Commit

Permalink
master
Browse files Browse the repository at this point in the history
  • Loading branch information
iser97 committed Oct 21, 2021
1 parent 5baf28d commit cb464f0
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@
mpl.use("agg")

import numpy as np
import sklearn as sk
import sklearn.metrics as skm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers.utils import logging
import logging

from scripts.model.transformer_single_layer import my_transformer
from scripts.data.dataset_mnist_8_8 import DatasetMnist
logger = logging.get_logger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
torch.manual_seed(71)

def make_train_step(model, loss_fn, optimizer):
Expand Down Expand Up @@ -51,10 +56,10 @@ def test_step(model, data_loader):
pred = pred.cpu().tolist()
preds = preds + pred
labels = labels + y_batch.cpu().tolist()
confusion_matrix = sk.metrics.confusion_matrix(labels, preds)
acc = sk.metrics.accuracy_score(labels, preds)
recall = sk.metrics.recall_score(labels, preds, average='macro')
f1 = sk.metrics.f1_score(np.array(labels), np.array(preds), average='macro')
confusion_matrix = skm.confusion_matrix(labels, preds)
acc = skm.accuracy_score(labels, preds)
recall = skm.recall_score(labels, preds, average='macro')
f1 = skm.f1_score(np.array(labels), np.array(preds), average='macro')

logger.info(f"confusion_matrix = {confusion_matrix}")
logger.info(f"ACC = {acc}")
Expand Down Expand Up @@ -101,9 +106,10 @@ def main():


if __name__ == '__main__':

device = 'cuda' if torch.cuda.is_available() else 'cpu'
### training parameters
batch_size = 10
batch_size = 256
lr = 1e-4
mom = 0.91
n_epochs = 2000
Expand All @@ -115,7 +121,6 @@ def main():
n_heads = 4
num_classes = 10


main()


0 comments on commit cb464f0

Please sign in to comment.