Skip to content

Commit

Permalink
Merge branch 'main' into refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Antonios Matakos committed Sep 27, 2021
2 parents 806bb7c + 2b3520b commit ebba57e
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 36 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
official source code of paper 'PatchmatchNet: Learned Multi-View Patchmatch Stereo'
![](imgs/structure_teaser.jpg)

## Updates
27.09.2021: The codes now allows for Torchscript export.

## Introduction
PatchmatchNet is a novel cascade formulation of learning-based Patchmatch which aims at decreasing memory consumption and computation time for high-resolution multi-view stereo. If you find this project useful for your research, please cite:

Expand Down
45 changes: 19 additions & 26 deletions datasets/dtu_yao.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def read_cam_file(self, filename):
def read_img(self, filename):
img = Image.open(filename)
# scale 0~255 to 0~1

np_img = np.array(img, dtype=np.float32) / 255.
h, w, _ = np_img.shape
np_img_ms = {
Expand All @@ -90,38 +89,34 @@ def prepare_img(self, hr_img):

return hr_img_crop

def read_mask_hr(self, filename):
def read_mask(self, filename):
img = Image.open(filename)
np_img = np.array(img, dtype=np.float32)
np_img = (np_img > 10).astype(np.float32)
np_img = self.prepare_img(np_img)

h, w = np_img.shape
np_img_ms = {
"stage_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_NEAREST),
"stage_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_NEAREST),
"stage_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST),
"stage_0": np_img
}
return np_img_ms

return np_img

def read_depth_hr(self, filename):

def read_depth_mask(self, filename, mask_filename, depth_min, depth_max):
depth_hr = np.array(read_pfm(filename)[0], dtype=np.float32)
depth_hr = np.squeeze(depth_hr,2)
depth_lr = self.prepare_img(depth_hr)

mask = self.read_mask(mask_filename)
mask = self.prepare_img(mask)
mask = mask.astype(np.bool_)
mask = mask & (depth_lr>=depth_min) & (depth_lr<=depth_max)
mask = mask.astype(np.float32)

h, w = depth_lr.shape
depth_lr_ms = {
"stage_3": cv2.resize(depth_lr, (w//8, h//8), interpolation=cv2.INTER_NEAREST),
"stage_2": cv2.resize(depth_lr, (w//4, h//4), interpolation=cv2.INTER_NEAREST),
"stage_1": cv2.resize(depth_lr, (w//2, h//2), interpolation=cv2.INTER_NEAREST),
"stage_0": depth_lr
}
return depth_lr_ms
depth_lr_ms = {}
mask_ms = {}

for i in range(self.stages):
depth_cur = cv2.resize(depth_lr, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)
mask_cur = cv2.resize(mask, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)
depth_lr_ms[f"stage_{i}"] = depth_cur
mask_ms[f"stage_{i}"] = mask_cur

return depth_lr_ms, mask_ms

def __getitem__(self, idx):
meta = self.metas[idx]
Expand Down Expand Up @@ -193,9 +188,7 @@ def __getitem__(self, idx):
if i == 0: # reference view
depth_min = depth_min_
depth_max = depth_max_

mask = self.read_mask_hr(mask_filename_hr)
depth = self.read_depth_hr(depth_filename_hr)
depth, mask = self.read_depth_mask(depth_filename_hr, mask_filename_hr, depth_min, depth_max)
for l in range(self.stages):
mask[f'stage_{l}'] = np.expand_dims(mask[f'stage_{l}'],2)
mask[f'stage_{l}'] = mask[f'stage_{l}'].transpose([2,0,1])
Expand Down
5 changes: 2 additions & 3 deletions models/patchmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,8 @@ def forward(
), "Patchmatch Evaluation: Different number of images and view weights"

# Change to a tensor with value 1e-5
pixel_wise_weight_sum = 1e-5
# pixel_wise_weight_sum = torch.zeros((batch, 1, 1, height, width), dtype=torch.float32, device=device)
ref_feature = ref_feature.view(batch, self.G, feature_channel // self.G, 1, height, width)
pixel_wise_weight_sum = 1e-5 * torch.ones((batch, 1, 1, height, width), dtype=torch.float32, device=device)
ref_feature = ref_feature.view(batch, self.G, feature_channel // self.G, height, width)
similarity_sum = torch.zeros((batch, self.G, num_depth, height, width), dtype=torch.float32, device=device)

i = 0
Expand Down
14 changes: 7 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
parser.add_argument('--summary_freq', type=int, default=20, help='print and summary frequency')
parser.add_argument('--save_freq', type=int, default=1, help='save checkpoint frequency')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed')
parser.add_argument('--parallel', action='store_true', default=False, help='If set use DataParallel; this prevents TorchScript module export.')

parser.add_argument('--patchmatch_iteration', nargs='+', type=int, default=[1,2,2],
help='num of iteration of patchmatch on stages 1,2,3')
Expand Down Expand Up @@ -98,7 +97,7 @@
propagation_range = args.patchmatch_range, patchmatch_iteration=args.patchmatch_iteration,
patchmatch_num_sample = args.patchmatch_num_sample,
propagate_neighbors=args.propagate_neighbors, evaluate_neighbors=args.evaluate_neighbors)
if args.parallel and args.mode in ["train", "val"]:
if args.mode in ["train", "val"]:
model = nn.DataParallel(model)
model.cuda()
model_loss = patchmatchnet_loss
Expand Down Expand Up @@ -162,11 +161,12 @@ def train():
'model': model.state_dict(),
'optimizer': optimizer.state_dict()},
"{}/model_{:0>6}.ckpt".format(args.logdir, epoch_idx))
if not args.parallel:
model.eval()
sm = torch.jit.script(model)
sm.save(os.path.join(args.logdir, 'module_{:0>6}.pt'.format(epoch_idx)))
model.train()
# There is only one child here (PatchmatchNet module), but we have to use the iterator to access it
for child_model in model.children():
child_model.eval()
sm = torch.jit.script(child_model)
sm.save(os.path.join(args.logdir, "module_{:0>6}.pt".format(epoch_idx)))
child_model.train()

# testing
avg_test_scalars = DictAverageMeter()
Expand Down

0 comments on commit ebba57e

Please sign in to comment.