import argparse
import multiprocessing as mp
import os
import random
import time
from pathlib import Path

import cv2
import numpy as np
import torch
import ttach as tta
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from train_supervision import *  # expected to provide Supervision_Train, Evaluator, py2cfg
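
# Inference script: restores a trained checkpoint, runs the (optionally
# TTA-wrapped) model over the test dataset, reports per-class F1/IoU and
# overall accuracy, and writes the predicted masks to disk in parallel.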


def seed_everything(seed):
    """Seed all RNGs for reproducible inference."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark must be False for cuDNN algorithm selection to stay deterministic
    torch.backends.cudnn.benchmark = False


def label2rgb(mask):
    """Convert an HxW integer label mask to an RGB colour image."""
    h, w = mask.shape[0], mask.shape[1]
    mask_rgb = np.zeros(shape=(h, w, 3), dtype=np.uint8)
    mask_rgb[mask == 0] = [255, 255, 255]
    mask_rgb[mask == 1] = [255, 0, 0]
    mask_rgb[mask == 2] = [255, 255, 0]
    mask_rgb[mask == 3] = [0, 255, 0]
    mask_rgb[mask == 4] = [0, 204, 255]
    mask_rgb[mask == 5] = [0, 0, 255]
    return mask_rgb


def img_writer(inp):
    """Write one predicted mask to disk; expects a (mask, mask_id, rgb) tuple."""
    (mask, mask_id, rgb) = inp
    if rgb:
        mask_name_png = mask_id + '.png'
        mask_png = label2rgb(mask)
        cv2.imwrite(mask_name_png, mask_png)
    else:
        mask_png = mask.astype(np.uint8)
        mask_name_png = mask_id + '.png'
        cv2.imwrite(mask_name_png, mask_png)


def get_args():
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg("-c", "--config_path", type=Path, required=True, help="Path to config.")
    arg("-o", "--output_path", type=Path, required=True, help="Path where to save the resulting masks.")
    arg("-t", "--tta", help="Test-time augmentation.", default=None, choices=[None, "d4", "lr"])
    arg("--rgb", help="Whether to output RGB images.", action='store_true')
    return parser.parse_args()
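
# Example invocation (script, config, and output paths are illustrative):
#   python test.py -c config/my_dataset/my_model.py -o results/masks -t lr --rgb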


def main():
    seed_everything(42)
    args = get_args()
    config = py2cfg(args.config_path)
    args.output_path.mkdir(exist_ok=True, parents=True)

    model = Supervision_Train.load_from_checkpoint(
        os.path.join(config.weights_path, config.test_weights_name + '.ckpt'), config=config)
    model.cuda()
    model.eval()
    evaluator = Evaluator(num_class=config.num_classes)
    evaluator.reset()

    if args.tta == "lr":
        transforms = tta.Compose(
            [
                tta.HorizontalFlip(),
                tta.VerticalFlip()
            ]
        )
        model = tta.SegmentationTTAWrapper(model, transforms)
    elif args.tta == "d4":
        transforms = tta.Compose(
            [
                tta.HorizontalFlip(),
                tta.VerticalFlip(),
                tta.Rotate90(angles=[90]),
                tta.Scale(scales=[0.5, 0.75, 1.0, 1.25, 1.5],
                          interpolation='bicubic', align_corners=False)
            ]
        )
        model = tta.SegmentationTTAWrapper(model, transforms)

    test_dataset = config.test_dataset
    with torch.no_grad():
        test_loader = DataLoader(
            test_dataset,
            batch_size=2,
            num_workers=4,
            pin_memory=True,
            drop_last=False,
        )
        results = []
        for batch in tqdm(test_loader):
            # raw_predictions: N x C x H x W logits
            raw_predictions = model(batch['img'].cuda())
            image_ids = batch["img_id"]
            masks_true = batch['gt_semantic_seg']
            # convert logits to per-class probabilities before taking the argmax
            raw_predictions = nn.Softmax(dim=1)(raw_predictions)
            predictions = raw_predictions.argmax(dim=1)
            for i in range(raw_predictions.shape[0]):
                mask = predictions[i].cpu().numpy()
                evaluator.add_batch(pre_image=mask, gt_image=masks_true[i].cpu().numpy())
                mask_name = image_ids[i]
                results.append((mask, str(args.output_path / mask_name), args.rgb))

        iou_per_class = evaluator.Intersection_over_Union()
        f1_per_class = evaluator.F1()
        OA = evaluator.OA()
        for class_name, class_iou, class_f1 in zip(config.classes, iou_per_class, f1_per_class):
            print('F1_{}: {}, IOU_{}: {}'.format(class_name, class_f1, class_name, class_iou))
        # mean scores exclude the last class (commonly a clutter/ignore category)
        print('F1: {}, mIOU: {}, OA: {}'.format(
            np.nanmean(f1_per_class[:-1]), np.nanmean(iou_per_class[:-1]), OA))

    t0 = time.time()
    # write all predicted masks in parallel across the available CPU cores
    with mp.Pool(processes=mp.cpu_count()) as pool:
        pool.map(img_writer, results)
    t1 = time.time()
    img_write_time = t1 - t0
    print('mask writing took {:.2f} s'.format(img_write_time))


if __name__ == "__main__":
    main()