Skip to content

Commit

Permalink
add random seed arg, add precrop args
Browse files Browse the repository at this point in the history
  • Loading branch information
bmild committed Apr 16, 2020
1 parent 9b6572e commit fd624c0
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions run_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def create_nerf(args):
if args.use_viewdirs:
embeddirs_fn, input_ch_views = get_embedder(
args.multires_views, args.i_embed)
output_ch = 5 if args.N_importance > 0 else 4
output_ch = 4
skips = [4]
model = init_nerf_model(
D=args.netdepth, W=args.netwidth,
Expand Down Expand Up @@ -495,6 +495,14 @@ def config_parser():
help='do not reload weights from saved ckpt')
parser.add_argument("--ft_path", type=str, default=None,
help='specific weights npy file to reload for coarse network')
parser.add_argument("--random_seed", type=int, default=None,
help='fix random seed for repeatability')

# pre-crop options
parser.add_argument("--precrop_iters", type=int, default=0,
help='number of steps to train on central crops')
parser.add_argument("--precrop_frac", type=float,
default=.5, help='fraction of img taken for central crops')

# rendering options
parser.add_argument("--N_samples", type=int, default=64,
Expand Down Expand Up @@ -568,6 +576,11 @@ def train():

parser = config_parser()
args = parser.parse_args()

if args.random_seed is not None:
print('Fixing random seed', args.random_seed)
np.random.seed(args.random_seed)
tf.compat.v1.set_random_seed(args.random_seed)

# Load data

Expand Down Expand Up @@ -768,8 +781,18 @@ def train():

if N_rand is not None:
rays_o, rays_d = get_rays(H, W, focal, pose)
coords = tf.stack(tf.meshgrid(
tf.range(H), tf.range(W), indexing='ij'), -1)
if i < args.precrop_iters:
dH = int(H//2 * args.precrop_frac)
dW = int(W//2 * args.precrop_frac)
coords = tf.stack(tf.meshgrid(
tf.range(H//2 - dH, H//2 + dH),
tf.range(W//2 - dW, W//2 + dW),
indexing='ij'), -1)
if i < 10:
print('precrop', dH, dW, coords[0,0], coords[-1,-1])
else:
coords = tf.stack(tf.meshgrid(
tf.range(H), tf.range(W), indexing='ij'), -1)
coords = tf.reshape(coords, [-1, 2])
select_inds = np.random.choice(
coords.shape[0], size=[N_rand], replace=False)
Expand Down

0 comments on commit fd624c0

Please sign in to comment.