-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Andrey
authored and
Andrey
committed
Jan 29, 2018
1 parent
c028508
commit bbe30b5
Showing
4 changed files
with
110 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import os | ||
import numpy as np | ||
import time | ||
import sys | ||
from PIL import Image | ||
|
||
import cv2 | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.backends.cudnn as cudnn | ||
import torchvision | ||
import torchvision.transforms as transforms | ||
|
||
from DensenetModels import DenseNet121 | ||
from DensenetModels import DenseNet169 | ||
from DensenetModels import DenseNet201 | ||
|
||
#-------------------------------------------------------------------------------- | ||
#---- Class to generate heatmaps (CAM) | ||
|
||
class HeatmapGenerator (): | ||
|
||
#---- Initialize heatmap generator | ||
#---- pathModel - path to the trained densenet model | ||
#---- nnArchitecture - architecture name DENSE-NET121, DENSE-NET169, DENSE-NET201 | ||
#---- nnClassCount - class count, 14 for chxray-14 | ||
|
||
|
||
def __init__ (self, pathModel, nnArchitecture, nnClassCount, transCrop): | ||
|
||
#---- Initialize the network | ||
if nnArchitecture == 'DENSE-NET-121': model = DenseNet121(nnClassCount, True).cuda() | ||
elif nnArchitecture == 'DENSE-NET-169': model = DenseNet169(nnClassCount, True).cuda() | ||
elif nnArchitecture == 'DENSE-NET-201': model = DenseNet201(nnClassCount, True).cuda() | ||
|
||
model = torch.nn.DataParallel(model).cuda() | ||
|
||
modelCheckpoint = torch.load(pathModel) | ||
model.load_state_dict(modelCheckpoint['state_dict']) | ||
|
||
self.model = model.module.densenet121.features | ||
self.model.eval() | ||
|
||
#---- Initialize the weights | ||
self.weights = list(self.model.parameters())[-2] | ||
|
||
#---- Initialize the image transform - resize + normalize | ||
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | ||
transformList = [] | ||
transformList.append(transforms.Resize(transCrop)) | ||
transformList.append(transforms.ToTensor()) | ||
transformList.append(normalize) | ||
|
||
self.transformSequence = transforms.Compose(transformList) | ||
|
||
#-------------------------------------------------------------------------------- | ||
|
||
def generate (self, pathImageFile, pathOutputFile, transCrop): | ||
|
||
#---- Load image, transform, convert | ||
imageData = Image.open(pathImageFile).convert('RGB') | ||
imageData = self.transformSequence(imageData) | ||
imageData = imageData.unsqueeze_(0) | ||
|
||
input = torch.autograd.Variable(imageData) | ||
|
||
self.model.cuda() | ||
output = self.model(input.cuda()) | ||
|
||
#---- Generate heatmap | ||
heatmap = None | ||
for i in range (0, len(self.weights)): | ||
map = output[0,i,:,:] | ||
if i == 0: heatmap = self.weights[i] * map | ||
else: heatmap += self.weights[i] * map | ||
|
||
#---- Blend original and heatmap | ||
npHeatmap = heatmap.cpu().data.numpy() | ||
|
||
imgOriginal = cv2.imread(pathImageFile, 1) | ||
imgOriginal = cv2.resize(imgOriginal, (transCrop, transCrop)) | ||
|
||
cam = npHeatmap / np.max(npHeatmap) | ||
cam = cv2.resize(cam, (transCrop, transCrop)) | ||
heatmap = cv2.applyColorMap(np.uint8(255*cam), cv2.COLORMAP_JET) | ||
|
||
img = heatmap * 0.5 + imgOriginal | ||
|
||
cv2.imwrite(pathOutputFile, img) | ||
|
||
#-------------------------------------------------------------------------------- | ||
|
||
pathInputImage = 'test/00009285_000.png' | ||
pathOutputImage = 'test/heatmap.png' | ||
pathModel = 'models/m-25012018-123527.pth.tar' | ||
|
||
nnArchitecture = 'DENSE-NET-121' | ||
nnClassCount = 14 | ||
|
||
transCrop = 224 | ||
|
||
h = HeatmapGenerator(pathModel, nnArchitecture, nnClassCount, transCrop) | ||
h.generate(pathInputImage, pathOutputImage, transCrop) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.