Skip to content

Commit

Permalink
Merge pull request vqdang#109 from alfrei/fix-batch-size-1-error
Browse files Browse the repository at this point in the history
torch.squeeze causes an error for valid batch size 1
  • Loading branch information
vqdang authored Apr 6, 2021
2 parents be8ae2d + 0306601 commit abf6d89
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions models/hovernet/run_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def train_step(batch_data, run_info):
imgs = imgs.permute(0, 3, 1, 2).contiguous()

# HWC
true_np = torch.squeeze(true_np).to("cuda").type(torch.int64)
true_hv = torch.squeeze(true_hv).to("cuda").type(torch.float32)
true_np = true_np.to("cuda").type(torch.int64)
true_hv = true_hv.to("cuda").type(torch.float32)

true_np_onehot = (F.one_hot(true_np, num_classes=2)).type(torch.float32)
true_dict = {
Expand Down

0 comments on commit abf6d89

Please sign in to comment.