Skip to content

Commit

Permalink
vqdang#95 UPD: migrate old view code
Browse files Browse the repository at this point in the history
  • Loading branch information
vqdang committed Jan 29, 2021
1 parent 0ebc2f5 commit b575e66
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 22 deletions.
39 changes: 26 additions & 13 deletions models/hovernet/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,26 +115,39 @@ def gen_targets(ann, crop_shape, **kwargs):


####
def prep_sample(data, **kwargs):
shape_array = [np.array(v.shape[:2]) for v in data.values()]
shape = np.maximum(*shape_array)

def prep_sample(data, is_batch=False, **kwargs):
"""
Designed to process direct output from loader
"""
cmap = plt.get_cmap("jet")

def colorize(ch, vmin, vmax):
def colorize(ch, vmin, vmax, shape):
ch = np.squeeze(ch.astype("float32"))
ch = ch / (vmax - vmin + 1.0e-16)
# take RGB from RGBA heat map
ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8")
ch_cmap = center_pad_to_shape(ch_cmap, shape)
return ch_cmap

viz_list = []
def prep_one_sample(data):
shape_array = [np.array(v.shape[:2]) for v in data.values()]
shape = np.maximum(*shape_array)
viz_list = []
viz_list.append(colorize(data["np_map"], 0, 1, shape))
# map to [0,2] for better visualisation.
# Note, [-1,1] is used for training.
viz_list.append(colorize(data["hv_map"][..., 0] + 1, 0, 2, shape))
viz_list.append(colorize(data["hv_map"][..., 1] + 1, 0, 2, shape))
img = center_pad_to_shape(data["img"], shape)
return np.concatenate([img] + viz_list, axis=1)

# cmap may randomly fails if of other types
viz_list.append(colorize(data["np_map"], 0, 1))
# map to [0,2] for better visualisation.
# Note, [-1,1] is used for training.
viz_list.append(colorize(data["hv_map"][..., 0] + 1, 0, 2))
viz_list.append(colorize(data["hv_map"][..., 1] + 1, 0, 2))
img = center_pad_to_shape(data["img"], shape)
return np.concatenate([img] + viz_list, axis=1)
if is_batch:
viz_list = []
data_shape = list(data.values())[0].shape
for batch_idx in range(data_shape[0]):
sub_data = {k : v[batch_idx] for k, v in data.items()}
viz_list.append(prep_one_sample(sub_data))
return np.concatenate(viz_list, axis=0)
else:
return prep_one_sample(data)
24 changes: 15 additions & 9 deletions run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,25 @@ def __init__(self):

####
def view_dataset(self, mode="train"):
"""
Manually change to plt.savefig or plt.show
if using on headless machine or not
"""
self.nr_gpus = 1
import matplotlib.pyplot as plt
check_manual_seed(self.seed)
# TODO: what if each phase want diff annotation ?
phase_list = self.model_config["phase_list"][0]
target_info = phase_list["target_info"]
dataloader = self.get_datagen(1, mode, target_info["gen"])
for batch_data in dataloader: # convert from Tensor to Numpy
batch_data_np = {k: v.numpy() for k, v in batch_data.items()}
# TODO: a separate func, not static method ?
FileLoader.view(batch_data_np, target_info["viz"])
continue
prep_func, prep_kwargs = target_info["viz"]
dataloader = self._get_datagen(2, mode, target_info["gen"])
for batch_data in dataloader:
# convert from Tensor to Numpy
batch_data = {k: v.numpy() for k, v in batch_data.items()}
viz = prep_func(batch_data, is_batch=True, **prep_kwargs)
plt.imshow(viz)
plt.show()
self.nr_gpus = -1
return

####
Expand Down Expand Up @@ -287,9 +296,6 @@ def run(self):
args = docopt(__doc__, version="HoVer-Net v1.0")
trainer = TrainManager()

if args["--view"] and args["--gpu"]:
raise Exception("Supply only one of --view and --gpu.")

if args["--view"]:
if args["--view"] != "train" and args["--view"] != "valid":
raise Exception('Use "train" or "valid" for --view.')
Expand Down

0 comments on commit b575e66

Please sign in to comment.