Skip to content

Commit

Permalink
fix res256 pretrain
Browse files Browse the repository at this point in the history
  • Loading branch information
Manchery committed Jan 2, 2025
1 parent 71bec00 commit 92844e7
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion scripts/pretrain/oxe-256-act-free.sh
Original file line number Diff line number Diff line change
@@ -4,6 +4,9 @@
# dataset_path: path to preprocessed OXE dataset
# sthsth_root_path: path to preprocessed SSv2 dataset

# * Load the pre-trained weights from aMUSEd:
# download https://huggingface.co/amused/amused-256/blob/main/vqvae/diffusion_pytorch_model.safetensors into pretrained_models/amused/vqvae

accelerate launch train_tokenizer.py \
--exp_name oxe-256-act-free-tokenizer --output_dir log_vqgan --seed 0 --mixed_precision bf16 \
--model_type ctx_vqgan \
@@ -12,7 +15,8 @@ accelerate launch train_tokenizer.py \
--oxe_data_mixes_type select --resolution 256 --dataloader_num_workers 16 \
--rand_select --video_stepsize 1 --segment_horizon 16 --segment_length 8 --context_length 2 \
--dataset_path {path to preprocessed_OXE} \
--sthsth_root_path {path to preprocessed_SSv2}
--sthsth_root_path {path to preprocessed_SSv2} \
--pretrained_model_name_or_path pretrained_models/amused/vqvae


# Pre-training transformer using four A100-40GB GPUs
2 changes: 1 addition & 1 deletion train_tokenizer.py
Original file line number Diff line number Diff line change
@@ -367,7 +367,7 @@ def main():
low_cpu_mem_usage=False, device_map=None,
ignore_mismatched_sizes=True
)
if args.pretrained_model_name_or_path == "configs/ctx_vae":
if args.pretrained_model_name_or_path == "pretrained_models/amused/vqvae":
model.init_modules()
if args.context_length != model.context_length:
print(

0 comments on commit 92844e7

Please sign in to comment.