Skip to content

Commit

Permalink
Prevent VQ-VAE codebook update when training the transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
wilson1yan committed Jun 15, 2021
1 parent 76a7c24 commit b234b80
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions videogpt/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, args):
self.vqvae = VQVAE.load_from_checkpoint(args.vqvae)
for p in self.vqvae.parameters():
p.requires_grad = False
self.vqvae.codebook._need_init = False
self.vqvae.eval()

# ResNet34 for frame conditioning
Expand Down Expand Up @@ -132,6 +133,7 @@ def forward(self, x, targets, cond, decode_step=None, decode_idx=None):
return loss, logits

def training_step(self, batch, batch_idx):
self.vqvae.eval()
x = batch['video']

cond = dict()
Expand Down

0 comments on commit b234b80

Please sign in to comment.