Skip to content

Commit

Permalink
add freeze encoder option
Browse files Browse the repository at this point in the history
  • Loading branch information
nick-yf committed May 24, 2023
1 parent dc9276a commit 3e11c16
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
7 changes: 7 additions & 0 deletions VAE/CVAE.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,13 @@ def eval_img(self, batch, cond_mask=None):
pred_img, recon_img = self.decode(z_mu, pose_embeds)
return pred_img, recon_img

def freeze_encoder(self):
self.ec1.requires_grad_(False)
self.ec2.requires_grad_(False)
self.ec3.requires_grad_(False)
self.ec4.requires_grad_(False)
self.fc1.requires_grad_(False)

class PoseMapping(nn.Module):
'''Map the pose(quaternion) to two vectors'''
def __init__(self, embed:int = 64) -> None:
Expand Down
12 changes: 10 additions & 2 deletions VAE/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,15 @@ def warmup(optimizer, step, last_step, last_lr):

def train(model, optimizer, loader, loader_val, writer, now, step, args):
a = 1
freezed = True ## No freezing
for e in range(args.num_epochs):
print(f'starting epoch {e}')

for img, R, T, K in tqdm(loader):
if not freezed and step > args.freeze_step:
print('freezing encoder')
model.module.freeze_encoder()
freezed = True
warmup(optimizer, step, args.warmup_step / args.batch_size, args.lr)

B = img.shape[0]
Expand Down Expand Up @@ -151,10 +156,13 @@ def validation(model, loader_val, writer, step, batch_size=8, device='cuda'):
parser.add_argument('--transfer', type=str, default="")
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--num_epochs', type=int, default=100000)
parser.add_argument('--warmup_step', type=int, default=10000000)

parser.add_argument('--warmup_step', type=int, default=0)
parser.add_argument('--freeze_step', type=float, default=5000)

parser.add_argument('--verbose_interval', type=int, default=500)
parser.add_argument('--validation_interval', type=int, default=1000)
parser.add_argument('--save_interval', type=int, default=20)
parser.add_argument('--save_interval', type=int, default=10)
parser.add_argument('--save_path', type=str, default="./results")
opts = parser.parse_args()
main(opts)

0 comments on commit 3e11c16

Please sign in to comment.