diff --git a/robrank/defenses/amd.py b/robrank/defenses/amd.py index f4cc2e5..fa89908 100644 --- a/robrank/defenses/amd.py +++ b/robrank/defenses/amd.py @@ -622,18 +622,11 @@ def hm_training_step(model: th.nn.Module, batch, batch_idx, *, if gradual: model._hm_prev_loss = loss.item() if ics: - loss = loss + 0.5 * ( - model.lossfunc.raw( + loss = loss + 1.0 * model.lossfunc.raw( ap_orig[:len(ap_orig)//2], pnemb[:len(pnemb) // 3], ap_orig[len(ap_orig)//2:], override_margin=0.0) - + model.lossfunc.raw( - output_orig[pos, :], - pnemb[len(pnemb) // 3 : 2*len(pnemb)//3], - output_orig[anc, :], - override_margin=0.0) - ) # logging model.log('Train/loss_orig', loss_orig.item()) model.log('Train/loss_adv', loss.item())