Skip to content

Commit

Permalink
fix some problems in pi-gan
Browse files Browse the repository at this point in the history
  • Loading branch information
FangYuhai committed May 20, 2023
1 parent 461a2e8 commit 8faae01
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion GAN/pi-GAN/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@

class SRN(Dataset):

def __init__(self, dataset_path, img_size, **kwargs):
def __init__(self, dataset_path='../../data/SRN/cars_train/*', img_size=128, **kwargs):
super().__init__()
paths = glob.glob(dataset_path)
self.data = []
Expand Down
7 changes: 3 additions & 4 deletions GAN/pi-GAN/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,6 @@ def train(rank, world_size, opt):
grad_penalty = 0

g_preds, g_pred_latent, g_pred_position = discriminator_ddp(gen_imgs, alpha, **metadata)
print(g_preds, g_pred_latent, g_pred_position)

if metadata['z_lambda'] > 0 or metadata['pos_lambda'] > 0:
latent_penalty = torch.nn.MSELoss()(g_pred_latent, z) * metadata['z_lambda']
Expand Down Expand Up @@ -385,14 +384,14 @@ def train(rank, world_size, opt):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=3000, help="number of epochs of training")
parser.add_argument("--sample_interval", type=int, default=200, help="interval between image sampling")
parser.add_argument("--n_epochs", type=int, default=30, help="number of epochs of training")
parser.add_argument("--sample_interval", type=int, default=1000, help="interval between image sampling")
parser.add_argument('--output_dir', type=str, default='debug')
parser.add_argument('--load_dir', type=str, default='')
parser.add_argument('--curriculum', type=str, required=True)
parser.add_argument('--eval_freq', type=int, default=5000)
parser.add_argument('--port', type=str, default='12355')
parser.add_argument('--set_step', type=int, default=None)
parser.add_argument('--set_step', type=int, default=10)
parser.add_argument('--model_save_interval', type=int, default=5000)

opt = parser.parse_args()
Expand Down

0 comments on commit 8faae01

Please sign in to comment.