Skip to content

Commit

Permalink
[Fix] Redundant operations to transfer data between CPU and GPU (vqda…
Browse files Browse the repository at this point in the history
…ng#206)

* fix typo

* remove redundant operations
  • Loading branch information
Kaminyou authored Apr 27, 2022
1 parent 9b21c86 commit 842964d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self):
raise Exception("If using `original` mode, input shape must be [270,270] and output shape must be [80,80]")
if model_mode == "fast":
if act_shape != [256,256] or out_shape != [164,164]:
raise Exception("If using `original` mode, input shape must be [256,256] and output shape must be [164,164]")
raise Exception("If using `fast` mode, input shape must be [256,256] and output shape must be [164,164]")

self.dataset_name = "consep" # extracts dataset info from dataset.py
self.log_dir = "logs/" # where checkpoints will be saved
Expand Down
12 changes: 6 additions & 6 deletions models/hovernet/run_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def valid_step(batch_data, run_info):
imgs_gpu = imgs_gpu.permute(0, 3, 1, 2).contiguous()

# HWC
true_np = torch.squeeze(true_np).to("cuda").type(torch.int64)
true_hv = torch.squeeze(true_hv).to("cuda").type(torch.float32)
true_np = torch.squeeze(true_np).type(torch.int64)
true_hv = torch.squeeze(true_hv).type(torch.float32)

true_dict = {
"np": true_np,
Expand All @@ -135,7 +135,7 @@ def valid_step(batch_data, run_info):

if model.module.nr_types is not None:
true_tp = batch_data["tp_map"]
true_tp = torch.squeeze(true_tp).to("cuda").type(torch.int64)
true_tp = torch.squeeze(true_tp).type(torch.int64)
true_dict["tp"] = true_tp

# --------------------------------------------------------------
Expand All @@ -155,14 +155,14 @@ def valid_step(batch_data, run_info):
result_dict = { # protocol for contents exchange within `raw`
"raw": {
"imgs": imgs.numpy(),
"true_np": true_dict["np"].cpu().numpy(),
"true_hv": true_dict["hv"].cpu().numpy(),
"true_np": true_dict["np"].numpy(),
"true_hv": true_dict["hv"].numpy(),
"prob_np": pred_dict["np"].cpu().numpy(),
"pred_hv": pred_dict["hv"].cpu().numpy(),
}
}
if model.module.nr_types is not None:
result_dict["raw"]["true_tp"] = true_dict["tp"].cpu().numpy()
result_dict["raw"]["true_tp"] = true_dict["tp"].numpy()
result_dict["raw"]["pred_tp"] = pred_dict["tp"].cpu().numpy()
return result_dict

Expand Down

0 comments on commit 842964d

Please sign in to comment.