From 4c663b4cb7db2ff41d94157ec527f57cf93aca44 Mon Sep 17 00:00:00 2001
From: Wouter Doppenberg
Date: Wed, 9 Aug 2023 23:55:39 +0200
Subject: [PATCH] Added correct expected weights, fixed adam state init (#621)

---
 burn-core/src/optim/adam.rs | 59 ++++++++++++++++++-------------------
 1 file changed, 29 insertions(+), 30 deletions(-)

diff --git a/burn-core/src/optim/adam.rs b/burn-core/src/optim/adam.rs
index dae3d88171..c5f7e0abe3 100644
--- a/burn-core/src/optim/adam.rs
+++ b/burn-core/src/optim/adam.rs
@@ -152,7 +152,7 @@ impl AdaptiveMomentum {
             let factor = 1.0 - self.beta_2;
             let moment_2 = grad.powf(2.0).mul_scalar(factor);
 
-            AdaptiveMomentumState::new(1, moment_1, moment_2)
+            AdaptiveMomentumState::new(0, moment_1, moment_2)
         };
 
         let time = (state.time as i32).elem();
@@ -219,6 +219,7 @@ mod tests {
 
         assert_eq!(state_optim_before.len(), state_optim_after.len());
     }
 
+    const ASSERT_PRECISION: usize = 6;
     #[test]
     fn test_adam_optimizer_with_numbers() {
@@ -248,6 +249,7 @@
             .with_epsilon(1e-8)
             .with_beta_1(0.9)
             .with_beta_2(0.999)
+            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
             .init();
 
         let grads = linear.forward(x_1).backward();
@@ -259,45 +261,42 @@
         let linear = optimizer.step(LEARNING_RATE, linear, grads);
         let state_updated = linear.into_record();
 
-        let state_expected = given_linear_record(
-            Data::from([
-                [-0.3405, 0.1191, 0.3843, 0.3000, 0.0661, 0.0471],
-                [0.0577, -0.0367, -0.3846, 0.2360, 0.1756, -0.3122],
-                [-0.0389, 0.0150, -0.3161, 0.2284, -0.2978, 0.2930],
-                [-0.3180, -0.2396, -0.3915, -0.3181, -0.0960, 0.1427],
-                [0.3100, -0.2365, 0.3517, -0.1929, 0.3597, -0.0504],
-                [-0.0358, -0.0303, 0.1059, 0.1721, 0.0095, 0.3634],
-            ]),
-            Data::from([-0.4105, 0.0684, -0.1170, 0.0976, 0.1166, -0.0070]),
-        );
+        let weights_expected = Data::from([
+            [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154],
+            [
+                0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133,
+            ],
+            [
+                -0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047,
+            ],
+            [
+                -0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651,
+            ],
+            [
+                0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343,
+            ],
+            [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346],
+        ]);
+        let bias_expected = Data::from([
+            -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999,
+        ]);
+
         let (weight_updated, bias_updated) = (
             state_updated.weight.to_data(),
             state_updated.bias.unwrap().to_data(),
         );
-        let (weight_expected, bias_expected) = (
-            state_expected.weight.to_data(),
-            state_expected.bias.unwrap().to_data(),
-        );
 
-        bias_updated.assert_approx_eq(&bias_expected, 2);
-        weight_updated.assert_approx_eq(&weight_expected, 2);
+        bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION);
+        weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION);
     }
 
     fn given_linear_layer(weight: Data<f32, 2>, bias: Data<f32, 1>) -> nn::Linear<TestADBackend> {
-        let linear = nn::LinearConfig::new(6, 6).init();
-        let record = given_linear_record(weight, bias);
-
-        linear.load_record(record)
-    }
-
-    fn given_linear_record(
-        weight: Data<f32, 2>,
-        bias: Data<f32, 1>,
-    ) -> nn::LinearRecord<TestADBackend> {
-        nn::LinearRecord {
+        let record = nn::LinearRecord {
             weight: Param::from(Tensor::from_data(weight)),
             bias: Some(Param::from(Tensor::from_data(bias))),
-        }
+        };
+
+        nn::LinearConfig::new(6, 6).init_with(record)
     }
 
     fn create_adam() -> OptimizerAdaptor<Adam<TestADBackend::InnerBackend>, nn::Linear<TestADBackend>, TestADBackend>
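
Note on the state-init fix: the first hunk changes the seed of the optimizer
state's step counter from 1 to 0, which reads as an off-by-one fix. In
standard Adam (Kingma & Ba), the stored counter feeds the bias-correction
terms 1 - beta_1^t and 1 - beta_2^t, which assume the very first update is
evaluated at t = 1; if the counter is seeded one step ahead, early updates
are corrected as if they were later steps. Below is a minimal scalar sketch
of that bookkeeping, illustrative only and not burn's API: AdamState,
adam_step, and their signatures are made up for this note.

    // Scalar sketch of Adam with bias correction (illustrative only; not
    // burn's API -- all names and signatures here are invented).
    struct AdamState {
        time: u32, // number of *completed* update steps
        moment_1: f64,
        moment_2: f64,
    }

    fn adam_step(
        weight: f64,
        grad: f64,
        state: Option<AdamState>,
        lr: f64,
        beta_1: f64,
        beta_2: f64,
        epsilon: f64,
    ) -> (f64, AdamState) {
        let mut state = match state {
            // Running state: fold the new gradient into both moments.
            Some(mut s) => {
                s.moment_1 = beta_1 * s.moment_1 + (1.0 - beta_1) * grad;
                s.moment_2 = beta_2 * s.moment_2 + (1.0 - beta_2) * grad * grad;
                s
            }
            // Fresh state: zero completed steps, hence time = 0.
            None => AdamState {
                time: 0,
                moment_1: (1.0 - beta_1) * grad,
                moment_2: (1.0 - beta_2) * grad * grad,
            },
        };

        // This update is step t = state.time; the corrections below
        // require t = 1 on the very first update.
        state.time += 1;
        let t = state.time as i32;

        let moment_1_corrected = state.moment_1 / (1.0 - beta_1.powi(t));
        let moment_2_corrected = state.moment_2 / (1.0 - beta_2.powi(t));
        let weight = weight - lr * moment_1_corrected / (moment_2_corrected.sqrt() + epsilon);

        (weight, state)
    }

In this sketch the first call gives moment_1_corrected == grad exactly,
since (1 - beta_1) * grad / (1 - beta_1^1) = grad, so the first step has
magnitude close to lr. A counter seeded at 1 would divide by 1 - beta^2 on
the first update instead and damp the early steps.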
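
Note on the test changes: the remaining hunks tighten the comparison from
two to six decimal places (ASSERT_PRECISION) and add a weight-decay term,
so the expected weights and biases are restated at full precision with
decay included. Per the builder calls visible above, the optimizer under
test is configured roughly as follows; the receiver is not shown in the
diff, so AdamConfig::new() is an assumption here:

    let mut optimizer = AdamConfig::new()
        .with_epsilon(1e-8)
        .with_beta_1(0.9)
        .with_beta_2(0.999)
        .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
        .init();

The tighter tolerance matters: a two-decimal comparison could mask
regressions like the state-init off-by-one, whereas asserting six decimals
makes the expected values sensitive to the exact update math.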