From 8faae0105d3ce3e8eb1e313c940380718a1a7e49 Mon Sep 17 00:00:00 2001 From: pengyc Date: Sat, 20 May 2023 18:32:46 +0800 Subject: [PATCH] fix some problems in pi-gan --- GAN/pi-GAN/datasets.py | 2 +- GAN/pi-GAN/train.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/GAN/pi-GAN/datasets.py b/GAN/pi-GAN/datasets.py index 7a1702f..e1560b5 100644 --- a/GAN/pi-GAN/datasets.py +++ b/GAN/pi-GAN/datasets.py @@ -76,7 +76,7 @@ class SRN(Dataset): - def __init__(self, dataset_path, img_size, **kwargs): + def __init__(self, dataset_path='../../data/SRN/cars_train/*', img_size=128, **kwargs): super().__init__() paths = glob.glob(dataset_path) self.data = [] diff --git a/GAN/pi-GAN/train.py b/GAN/pi-GAN/train.py index 0dc3e46..b7d069b 100644 --- a/GAN/pi-GAN/train.py +++ b/GAN/pi-GAN/train.py @@ -235,7 +235,6 @@ def train(rank, world_size, opt): grad_penalty = 0 g_preds, g_pred_latent, g_pred_position = discriminator_ddp(gen_imgs, alpha, **metadata) - print(g_preds, g_pred_latent, g_pred_position) if metadata['z_lambda'] > 0 or metadata['pos_lambda'] > 0: latent_penalty = torch.nn.MSELoss()(g_pred_latent, z) * metadata['z_lambda'] @@ -385,14 +384,14 @@ def train(rank, world_size, opt): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("--n_epochs", type=int, default=3000, help="number of epochs of training") - parser.add_argument("--sample_interval", type=int, default=200, help="interval between image sampling") + parser.add_argument("--n_epochs", type=int, default=30, help="number of epochs of training") + parser.add_argument("--sample_interval", type=int, default=1000, help="interval between image sampling") parser.add_argument('--output_dir', type=str, default='debug') parser.add_argument('--load_dir', type=str, default='') parser.add_argument('--curriculum', type=str, required=True) parser.add_argument('--eval_freq', type=int, default=5000) parser.add_argument('--port', type=str, default='12355') - parser.add_argument('--set_step', type=int, default=None) + parser.add_argument('--set_step', type=int, default=10) parser.add_argument('--model_save_interval', type=int, default=5000) opt = parser.parse_args()