Skip to content

Commit

Permalink
Rename train.py to train_dtu.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Antonios Matakos committed Sep 15, 2021
1 parent c167438 commit 19a081d
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 19 deletions.
8 changes: 0 additions & 8 deletions datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +0,0 @@
import importlib


# find the dataset definition by name, for example dtu_yao (dtu_yao.py)
def find_dataset_def(dataset_name):
module_name = 'datasets.{}'.format(dataset_name)
module = importlib.import_module(module_name)
return getattr(module, "MVSDataset")
9 changes: 2 additions & 7 deletions train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,5 @@

# train on DTU's training set
MVS_TRAINING="/home/mvs_training/dtu/"

python train.py --dataset dtu_yao --batch_size 4 --epochs 8 \
--patchmatch_iteration 1 2 2 --patchmatch_range 6 4 2 \
--patchmatch_num_sample 8 8 16 --propagate_neighbors 0 8 16 --evaluate_neighbors 9 9 9 \
--patchmatch_interval_scale 0.005 0.0125 0.025 \
--trainpath=$MVS_TRAINING --trainlist lists/dtu/train.txt --vallist lists/dtu/val.txt \
--logdir ./checkpoints $@
python train_dtu.py --batch_size 4 --epochs 8 --trainpath=$MVS_TRAINING --trainlist lists/dtu/train.txt \
--vallist lists/dtu/val.txt --logdir ./checkpoints "$@"
5 changes: 1 addition & 4 deletions train.py → train_dtu.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
from typing import List
from torch.utils.tensorboard import SummaryWriter
from datasets import find_dataset_def
from datasets.dtu_yao import MVSDataset
from models import *
from utils import *
import sys
Expand All @@ -21,7 +21,6 @@
parser.add_argument('--mode', default='train', help='train or val', choices=['train', 'val'])
parser.add_argument('--model', default='PatchmatchNet', help='select model')

parser.add_argument('--dataset', default='dtu_yao', help='select dataset')
parser.add_argument('--trainpath', help='train datapath')
parser.add_argument('--valpath', help='validation datapath')
parser.add_argument('--trainlist', help='train list')
Expand Down Expand Up @@ -84,7 +83,6 @@
print_args(args)

# dataset, dataloader
MVSDataset = find_dataset_def(args.dataset)
train_dataset = MVSDataset(args.trainpath, args.trainlist, "train", 5, robust_train=True)
test_dataset = MVSDataset(args.valpath, args.vallist, "val", 5, robust_train=False)

Expand Down Expand Up @@ -326,4 +324,3 @@ def create_stage_images(image: torch.Tensor) -> List[torch.Tensor]:
train()
elif args.mode == "val":
test()

0 comments on commit 19a081d

Please sign in to comment.