Skip to content

Commit

Permalink
Add loss weights to training script
Browse files Browse the repository at this point in the history
  • Loading branch information
oarriaga committed Dec 11, 2019
1 parent a3f9f65 commit 99f1d0d
Showing 1 changed file with 9 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
help='Number of filters in convolutional blocks')
parser.add_argument('-bs', '--batch_size', default=20, type=int,
help='Batch size for training')
parser.add_argument('-lw', '--loss_weights', nargs='+', type=float,
default=[1.0, 1.0, 1.0, 0.2, 0.5],
help='Loss weights in the following order:'
'[consistency, silhouette, separation, pose, variance]')
parser.add_argument('-lr', '--learning_rate', default=0.001, type=float,
help='Initial learning rate for Adam')
parser.add_argument('-is', '--image_size', default=128, type=int,
Expand Down Expand Up @@ -95,13 +99,15 @@
args.depth * 10, args.filters, args.alpha)

# loss instantiation
loss = KeypointNetLoss(args.num_keypoints, focal_length)
name = ['consistency', 'silhouette', 'separation', 'relative_pose', 'variance']
weights = dict(zip(name, args.loss_weights))
loss = KeypointNetLoss(args.num_keypoints, focal_length, loss_weights=weights)
losses = {'uvz_points-shared': loss.uvz_points,
'uv_volumes-shared': loss.uv_volumes}
uvz_point_losses = [loss.consistency, loss.separation, loss.relative_pose]

# metrics
metrics = {'uvz_points-shared': uvz_point_losses,
uvz_point_metrics = [loss.consistency, loss.separation, loss.relative_pose]
metrics = {'uvz_points-shared': uvz_point_metrics,
'uv_volumes-shared': [loss.silhouette, loss.variance]}

# model compilation
Expand Down

0 comments on commit 99f1d0d

Please sign in to comment.