diff --git a/VAE/CVAE.py b/VAE/CVAE.py index b5d5923..da97aba 100644 --- a/VAE/CVAE.py +++ b/VAE/CVAE.py @@ -297,6 +297,13 @@ def eval_img(self, batch, cond_mask=None): pred_img, recon_img = self.decode(z_mu, pose_embeds) return pred_img, recon_img + def freeze_encoder(self): + self.ec1.requires_grad_(False) + self.ec2.requires_grad_(False) + self.ec3.requires_grad_(False) + self.ec4.requires_grad_(False) + self.fc1.requires_grad_(False) + class PoseMapping(nn.Module): '''Map the pose(quaternion) to two vectors''' def __init__(self, embed:int = 64) -> None: diff --git a/VAE/train.py b/VAE/train.py index 607c5d0..643c62d 100644 --- a/VAE/train.py +++ b/VAE/train.py @@ -67,10 +67,15 @@ def warmup(optimizer, step, last_step, last_lr): def train(model, optimizer, loader, loader_val, writer, now, step, args): a = 1 + freezed = True ## No freezing for e in range(args.num_epochs): print(f'starting epoch {e}') for img, R, T, K in tqdm(loader): + if not freezed and step > args.freeze_step: + print('freezing encoder') + model.module.freeze_encoder() + freezed = True warmup(optimizer, step, args.warmup_step / args.batch_size, args.lr) B = img.shape[0] @@ -151,10 +156,13 @@ def validation(model, loader_val, writer, step, batch_size=8, device='cuda'): parser.add_argument('--transfer', type=str, default="") parser.add_argument('--lr', type=float, default=1e-4) parser.add_argument('--num_epochs', type=int, default=100000) - parser.add_argument('--warmup_step', type=int, default=10000000) + + parser.add_argument('--warmup_step', type=int, default=0) + parser.add_argument('--freeze_step', type=float, default=5000) + parser.add_argument('--verbose_interval', type=int, default=500) parser.add_argument('--validation_interval', type=int, default=1000) - parser.add_argument('--save_interval', type=int, default=20) + parser.add_argument('--save_interval', type=int, default=10) parser.add_argument('--save_path', type=str, default="./results") opts = parser.parse_args() main(opts)