Skip to content

Commit

Permalink
Update test_ood_squeezenet.py
Browse files Browse the repository at this point in the history
  • Loading branch information
KingJamesSong authored Nov 23, 2023
1 parent d8c53f6 commit 72c149d
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion test_ood_squeezenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,38 @@ def iterate_data_rankfeat(data_loader, model, temperature):

return np.array(confs)

#Our proposed RankFeat+RankWeight Score
def iterate_data_rankfeat(data_loader, model, temperature):
confs = []

weight = model.module.classifier[1].weight.data
B, C, H, W = weight.size()
weight = weight.view(B, C * H * W)
weight_sub = power_iteration(weight.unsqueeze(0), iter=100)
weight = weight - weight_sub.squeeze()
weight = weight.view(B, C, H, W)
model.module.classifier[1].weight.data = weight

for b, (x, y) in enumerate(data_loader):
if b % 100 == 0:
print('{} batches processed'.format(b))
inputs = x.cuda()

feat = model.module.features(inputs)
B, C, H, W = feat.size()
feat = feat.view(B, C, H * W)
u, s, v = torch.linalg.svd(feat)
feat = feat - s[:, 0:1].unsqueeze(2) * u[:, :, 0:1].bmm(v[:, 0:1, :])
# if you want to use PI for acceleration, comment the above 2 lines and uncomment the line below
# feat = feat - power_iteration(feat, iter=20)
feat = feat.view(B, C, H, W)
logits = model.module.classifier(feat)

conf = temperature * torch.logsumexp(logits / temperature, dim=1)
confs.extend(conf.data.cpu().numpy())

return np.array(confs)

def iterate_data_react(data_loader, model, temperature):
confs = []
for b, (x, y) in enumerate(data_loader):
Expand Down Expand Up @@ -246,6 +278,11 @@ def run_eval(model, in_loader, out_loader, logger, args, num_classes):
in_scores = iterate_data_rankfeat(in_loader, model, args.temperature_rankfeat)
logger.info("Processing out-of-distribution data...")
out_scores = iterate_data_rankfeat(out_loader, model, args.temperature_rankfeat)
elif args.score == 'RankFeatWeight':
logger.info("Processing in-distribution data...")
in_scores = iterate_data_rankfeatweight(in_loader, model, args.temperature_rankfeat)
logger.info("Processing out-of-distribution data...")
out_scores = iterate_data_rankfeatweight(out_loader, model, args.temperature_rankfeat)
elif args.score == 'React':
logger.info("Processing in-distribution data...")
in_scores = iterate_data_react(in_loader, model, args.temperature_react)
Expand Down Expand Up @@ -301,7 +338,7 @@ def main(args):
parser.add_argument("--in_datadir", help="Path to the in-distribution data folder.")
parser.add_argument("--out_datadir", help="Path to the out-of-distribution data folder.")

parser.add_argument('--score', choices=['MSP', 'ODIN', 'Energy', 'Mahalanobis', 'GradNorm','RankFeat','React'], default='RankFeat')
parser.add_argument('--score', choices=['MSP', 'ODIN', 'Energy', 'Mahalanobis', 'GradNorm','RankFeat','RankFeatWeight','React'], default='RankFeatWeight')

# arguments for ODIN
parser.add_argument('--temperature_odin', default=1000, type=int,
Expand Down

0 comments on commit 72c149d

Please sign in to comment.