Skip to content

Commit

Permalink
cam generator
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrey authored and Andrey committed Jan 29, 2018
1 parent c028508 commit bbe30b5
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 1 deletion.
104 changes: 104 additions & 0 deletions HeatmapGenerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import os
import numpy as np
import time
import sys
from PIL import Image

import cv2

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

from DensenetModels import DenseNet121
from DensenetModels import DenseNet169
from DensenetModels import DenseNet201

#--------------------------------------------------------------------------------
#---- Class to generate heatmaps (CAM)

class HeatmapGenerator ():

#---- Initialize heatmap generator
#---- pathModel - path to the trained densenet model
#---- nnArchitecture - architecture name DENSE-NET121, DENSE-NET169, DENSE-NET201
#---- nnClassCount - class count, 14 for chxray-14


def __init__ (self, pathModel, nnArchitecture, nnClassCount, transCrop):

#---- Initialize the network
if nnArchitecture == 'DENSE-NET-121': model = DenseNet121(nnClassCount, True).cuda()
elif nnArchitecture == 'DENSE-NET-169': model = DenseNet169(nnClassCount, True).cuda()
elif nnArchitecture == 'DENSE-NET-201': model = DenseNet201(nnClassCount, True).cuda()

model = torch.nn.DataParallel(model).cuda()

modelCheckpoint = torch.load(pathModel)
model.load_state_dict(modelCheckpoint['state_dict'])

self.model = model.module.densenet121.features
self.model.eval()

#---- Initialize the weights
self.weights = list(self.model.parameters())[-2]

#---- Initialize the image transform - resize + normalize
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
transformList = []
transformList.append(transforms.Resize(transCrop))
transformList.append(transforms.ToTensor())
transformList.append(normalize)

self.transformSequence = transforms.Compose(transformList)

#--------------------------------------------------------------------------------

def generate (self, pathImageFile, pathOutputFile, transCrop):

#---- Load image, transform, convert
imageData = Image.open(pathImageFile).convert('RGB')
imageData = self.transformSequence(imageData)
imageData = imageData.unsqueeze_(0)

input = torch.autograd.Variable(imageData)

self.model.cuda()
output = self.model(input.cuda())

#---- Generate heatmap
heatmap = None
for i in range (0, len(self.weights)):
map = output[0,i,:,:]
if i == 0: heatmap = self.weights[i] * map
else: heatmap += self.weights[i] * map

#---- Blend original and heatmap
npHeatmap = heatmap.cpu().data.numpy()

imgOriginal = cv2.imread(pathImageFile, 1)
imgOriginal = cv2.resize(imgOriginal, (transCrop, transCrop))

cam = npHeatmap / np.max(npHeatmap)
cam = cv2.resize(cam, (transCrop, transCrop))
heatmap = cv2.applyColorMap(np.uint8(255*cam), cv2.COLORMAP_JET)

img = heatmap * 0.5 + imgOriginal

cv2.imwrite(pathOutputFile, img)

#--------------------------------------------------------------------------------

pathInputImage = 'test/00009285_000.png'
pathOutputImage = 'test/heatmap.png'
pathModel = 'models/m-25012018-123527.pth.tar'

nnArchitecture = 'DENSE-NET-121'
nnClassCount = 14

transCrop = 224

h = HeatmapGenerator(pathModel, nnArchitecture, nnClassCount, transCrop)
h.generate(pathInputImage, pathOutputImage, transCrop)
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ The highest accuracy evaluated with AUROC was 0.8508 (see the model m-25012018-1
The same training (70%), validation (10%) and testing (20%) datasets were used as in [this](https://github.com/arnoweng/CheXNet)
implementation.

![alt text](test/heatmap.png)

## Prerequisites
* Python 3.5.2
* Pytorch
* OpenCV (for generating CAMs)

## Usage
* Download the ChestX-ray14 database from [here](https://nihcc.app.box.com/v/ChestXray-NIHCC/folder/37178474737)
Expand All @@ -19,7 +22,9 @@ implementation.
* Use the **runTrain()** function in the **Main.py** to train a model from scratch

This implementation allows to conduct experiments with 3 different densenet architectures: densenet-121, densenet-169 and
densenet-201.
densenet-201.

* To generate CAM of a test file run script HeatmapGenerator

## Results
The highest accuracy 0.8508 was achieved by the model m-25012018-123527 (see the models directory).
Expand Down
Binary file added test/00009285_000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/heatmap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit bbe30b5

Please sign in to comment.