-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathtorch_test.py
147 lines (116 loc) · 3.87 KB
/
torch_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""Main script to run the testing of the model(ECGNet, Resnet101).
"""
__author__ = "Likith Reddy"
__version__ = "1.0.0"
__email__ = "likith012@gmail.com"
from typing import List
import os
import random
import argparse
import json
from tqdm import tqdm
import numpy as np
from sklearn.metrics import roc_auc_score
import torch
import torch.nn as nn
from preprocessing.preprocess import preprocess
from utils.torch_dataloader import DataGen
from utils.metrics import Metrics, AUC, metric_summary
# Random seed: fix every RNG this script relies on (Python, NumPy and
# PyTorch) so that repeated test runs are reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def epoch_run(
    model: nn.Module, dataset: torch.utils.data.Dataset, device: torch.device
) -> np.ndarray:
    """Run inference of ``model`` over every batch of ``dataset``.

    Parameters
    ----------
    model: nn.Module
        Model to be tested. It is moved to ``device`` and switched to
        evaluation mode in place.
    dataset: torch.utils.data.Dataset
        Indexable batch source: ``len(dataset)`` is the number of batches and
        ``dataset[i]`` returns an ``(inputs, labels)`` pair whose inputs are a
        3-D tensor. The labels are ignored here.
    device: torch.device
        Device to be used.

    Returns
    -------
    np.ndarray
        Model outputs for all batches, concatenated along the first axis.

    Raises
    ------
    ValueError
        If ``dataset`` contains no batches.
    """
    model.to(device)
    model.eval()
    pred_all = []
    # Inference only: disable autograd so no computation graph is kept for
    # each batch, which saves memory and time.
    with torch.no_grad():
        for batch_step in tqdm(range(len(dataset)), desc="test"):
            batch_x, _ = dataset[batch_step]
            # Swap the last two axes before feeding the model.
            # NOTE(review): presumably (batch, time, leads) -> (batch, leads, time)
            # for 1-D convolutions; confirm against DataGen's output layout.
            batch_x = batch_x.permute(0, 2, 1).to(device)
            pred = model(batch_x)
            pred_all.append(pred.detach().cpu().numpy())
    if not pred_all:
        raise ValueError("dataset yielded no batches; nothing to evaluate")
    return np.concatenate(pred_all, axis=0)
def _json_default(obj):
    """Convert NumPy values that :mod:`json` cannot serialise natively.

    Used as the ``default`` hook of :func:`json.dump`: arrays become nested
    lists and NumPy scalars become the matching Python scalars. Any other
    type still raises ``TypeError``, exactly as ``json`` would on its own.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def test(
    model: nn.Module,
    path: str = "data/ptb",
    batch_size: int = 32,
    name: str = "imle_net",
) -> None:
    """Data preprocessing and testing of the model.

    Evaluates ``model`` on the test split returned by ``preprocess``, prints
    the metrics and stores them in ``logs/<name>_test_logs.json`` under the
    current working directory.

    Parameters
    ----------
    model: nn.Module
        Model to be tested.
    path: str, optional
        Path to the directory containing the data. (default: 'data/ptb')
    batch_size: int, optional
        Batch size. (default: 32)
    name: str, optional
        Name of the model, used in the log file name. (default: 'imle_net')
    """
    _, _, X_test_scale, y_test, _, _ = preprocess(path=path)
    test_gen = DataGen(X_test_scale, y_test, batch_size=batch_size)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    pred = epoch_run(model, test_gen, device)

    roc_score = roc_auc_score(y_test, pred, average="macro")
    acc, mean_acc = Metrics(y_test, pred)
    class_auc = AUC(y_test, pred)
    summary = metric_summary(y_test, pred)

    print(f"class wise accuracy: {acc}")
    print(f"accuracy: {mean_acc}")
    print(f"roc_score : {roc_score}")
    print(f"class wise AUC : {class_auc}")
    print(f"F1 score (Max): {summary[0]}")
    print(f"class wise precision, recall, f1 score : {summary}")

    logs = {
        "roc_score": roc_score,
        "mean_acc": mean_acc,
        "accuracy": acc,
        "class_auc": class_auc,
        "F1 score (Max)": summary[0],
        "class_precision_recall_f1": summary,
    }

    logs_path = os.path.join(os.getcwd(), "logs")
    os.makedirs(logs_path, exist_ok=True)
    log_file = os.path.join(logs_path, f"{name}_test_logs.json")
    with open(log_file, "w", encoding="utf-8") as json_file:
        # The metric helpers may return NumPy arrays/scalars, which plain
        # json.dump rejects; _json_default converts them to builtins.
        json.dump(logs, json_file, default=_json_default)
if __name__ == "__main__":
    # Entry point: parse the CLI arguments, build the selected model, load its
    # trained weights and evaluate it on the test split.

    # Args parser
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir", type=str, default="data/ptb", help="Ptb-xl dataset location"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="ecgnet",
        # Restrict to supported models: any other value previously left
        # `model` undefined and crashed later with a NameError.
        choices=["ecgnet", "resnet101"],
        help="Select the model to test. (ecgnet, resnet101)",
    )
    parser.add_argument("--batchsize", type=int, default=32, help="Batch size")
    args = parser.parse_args()

    if args.model == "ecgnet":
        from models.ECGNet import ECGNet
        model = ECGNet()
    elif args.model == "resnet101":
        from models.resnet101 import resnet101
        model = resnet101()

    path_weights = os.path.join(os.getcwd(), "checkpoints", f"{args.model}_weights.pt")
    # map_location="cpu": a checkpoint saved from a GPU run would otherwise
    # fail to load on a CPU-only machine. test() moves the model to the
    # available device afterwards.
    model.load_state_dict(torch.load(path_weights, map_location="cpu"))
    test(model, path=args.data_dir, batch_size=args.batchsize, name=args.model)