Skip to content

Commit

Permalink
Rename heatmaps to preds to avoid confusion
Browse files Browse the repository at this point in the history
  • Loading branch information
zhou13 committed Jan 24, 2020
1 parent eccf0e4 commit 33176e4
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
14 changes: 7 additions & 7 deletions lcnn/models/line_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, backbone):

def forward(self, input_dict):
result = self.backbone(input_dict)
h = result["heatmaps"]
h = result["preds"]
x = self.fc1(result["feature"])
n_batch, n_channel, row, col = x.shape

Expand Down Expand Up @@ -134,16 +134,16 @@ def sum_batch(x):
jcs[i][j] = jcs[i][j][
None, torch.arange(M.n_out_junc) % len(jcs[i][j])
]
result["heatmaps"]["lines"] = torch.cat(lines)
result["heatmaps"]["score"] = torch.cat(score)
result["heatmaps"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
result["preds"]["lines"] = torch.cat(lines)
result["preds"]["score"] = torch.cat(score)
result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
if len(jcs[i]) > 1:
result["heatmaps"]["junts"] = torch.cat(
result["preds"]["junts"] = torch.cat(
[jcs[i][1] for i in range(n_batch)]
)
else:
if "heatmaps" in result:
del result["heatmaps"]
if "preds" in result:
del result["preds"]
return result

def sample_lines(self, meta, jmap, joff, do_evaluation):
Expand Down
2 changes: 1 addition & 1 deletion lcnn/models/multitask_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def forward(self, input_dict, output_feature=True):
lmap = output[offset[0] : offset[1]].squeeze(0)
joff = output[offset[1] : offset[2]].reshape(n_jtyp, 2, batch, row, col)
if stack == 0:
result["heatmaps"] = {
result["preds"] = {
"jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
"lmap": lmap.sigmoid(),
"joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
Expand Down
2 changes: 1 addition & 1 deletion lcnn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def validate(self):

total_loss += self._loss(result)

H = result["heatmaps"]
H = result["preds"]
for i in range(H["jmap"].shape[0]):
index = batch_idx * self.batch_size + i
np.savez(
Expand Down
2 changes: 1 addition & 1 deletion process.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def main():
"target": recursive_to(target, device),
"do_evaluation": True,
}
H = model(input_dict)["heatmaps"]
H = model(input_dict)["preds"]
for i in range(M.batch_size):
index = batch_idx * M.batch_size + i
np.savez(
Expand Down

0 comments on commit 33176e4

Please sign in to comment.