Added correct expected weights, fixed adam state init (tracel-ai#621)
wdoppenberg authored Aug 9, 2023
1 parent c1ba355 commit 4c663b4
Showing 1 changed file with 29 additions and 30 deletions.
burn-core/src/optim/adam.rs
@@ -152,7 +152,7 @@ impl AdaptiveMomentum {
             let factor = 1.0 - self.beta_2;
             let moment_2 = grad.powf(2.0).mul_scalar(factor);

-            AdaptiveMomentumState::new(1, moment_1, moment_2)
+            AdaptiveMomentumState::new(0, moment_1, moment_2)
         };

         let time = (state.time as i32).elem();
@@ -219,6 +219,7 @@ mod tests {

         assert_eq!(state_optim_before.len(), state_optim_after.len());
     }
+    const ASSERT_PRECISION: usize = 6;

     #[test]
     fn test_adam_optimizer_with_numbers() {
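
ASSERT_PRECISION tightens the comparisons at the end of this test from two decimal places to six, which is why the expected values below are now written out to six decimals. A hypothetical fragment (assuming this module's imports, and assuming the second argument of assert_approx_eq counts decimal places, as its use here suggests):

    #[test]
    fn old_expectations_are_too_coarse() {
        // Values from the diff below: the old 4-decimal expectation sits
        // ~2.8e-5 away from the recomputed weight.
        let updated = Data::from([-0.340528f32]);
        let truncated = Data::from([-0.3405f32]);
        updated.assert_approx_eq(&truncated, 2); // passes at 2 decimal places
        // updated.assert_approx_eq(&truncated, 6); // would panic at 6
    }
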
@@ -248,6 +249,7 @@ mod tests {
             .with_epsilon(1e-8)
             .with_beta_1(0.9)
             .with_beta_2(0.999)
+            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
             .init();

         let grads = linear.forward(x_1).backward();
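
The configuration also switches on weight decay with a penalty of 0.5, so the expected weights below bake that in as well. In the classic coupled formulation (an assumption for illustration, not a claim about burn's WeightDecay internals), the L2 penalty's gradient is simply folded into the raw gradient before the moments are updated:

    // Coupled L2 weight decay, scalar case (assumed formulation): the
    // gradient of (lambda / 2) * theta^2 is lambda * theta.
    fn decayed_grad(grad: f64, param: f64, penalty: f64) -> f64 {
        grad + penalty * param
    }

    fn main() {
        let (g, theta, lambda) = (0.5, 0.2, 0.5);
        // 0.5 + 0.5 * 0.2 = 0.6
        println!("effective gradient = {}", decayed_grad(g, theta, lambda));
    }
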
@@ -259,45 +261,42 @@ mod tests {
         let linear = optimizer.step(LEARNING_RATE, linear, grads);

         let state_updated = linear.into_record();
-        let state_expected = given_linear_record(
-            Data::from([
-                [-0.3405, 0.1191, 0.3843, 0.3000, 0.0661, 0.0471],
-                [0.0577, -0.0367, -0.3846, 0.2360, 0.1756, -0.3122],
-                [-0.0389, 0.0150, -0.3161, 0.2284, -0.2978, 0.2930],
-                [-0.3180, -0.2396, -0.3915, -0.3181, -0.0960, 0.1427],
-                [0.3100, -0.2365, 0.3517, -0.1929, 0.3597, -0.0504],
-                [-0.0358, -0.0303, 0.1059, 0.1721, 0.0095, 0.3634],
-            ]),
-            Data::from([-0.4105, 0.0684, -0.1170, 0.0976, 0.1166, -0.0070]),
-        );
+        let weights_expected = Data::from([
+            [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154],
+            [
+                0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133,
+            ],
+            [
+                -0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047,
+            ],
+            [
+                -0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651,
+            ],
+            [
+                0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343,
+            ],
+            [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346],
+        ]);
+        let bias_expected = Data::from([
+            -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999,
+        ]);

         let (weight_updated, bias_updated) = (
             state_updated.weight.to_data(),
             state_updated.bias.unwrap().to_data(),
         );
-        let (weight_expected, bias_expected) = (
-            state_expected.weight.to_data(),
-            state_expected.bias.unwrap().to_data(),
-        );

-        bias_updated.assert_approx_eq(&bias_expected, 2);
-        weight_updated.assert_approx_eq(&weight_expected, 2);
+        bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION);
+        weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION);
     }

     fn given_linear_layer(weight: Data<f32, 2>, bias: Data<f32, 1>) -> nn::Linear<TestADBackend> {
-        let linear = nn::LinearConfig::new(6, 6).init();
-        let record = given_linear_record(weight, bias);
-
-        linear.load_record(record)
-    }
-
-    fn given_linear_record(
-        weight: Data<f32, 2>,
-        bias: Data<f32, 1>,
-    ) -> nn::LinearRecord<TestADBackend> {
-        nn::LinearRecord {
+        let record = nn::LinearRecord {
             weight: Param::from(Tensor::from_data(weight)),
             bias: Some(Param::from(Tensor::from_data(bias))),
-        }
+        };
+
+        nn::LinearConfig::new(6, 6).init_with(record)
     }

     fn create_adam() -> OptimizerAdaptor<Adam<TestBackend>, nn::Linear<TestADBackend>, TestADBackend>
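
The last hunk also folds given_linear_record into given_linear_layer: rather than initializing a Linear with random parameters and immediately overwriting them with load_record, the record is passed straight to init_with. The same pattern in isolation (a sketch with a hypothetical helper name, reusing this file's own imports and type aliases; both APIs appear in the diff above):

    // Hypothetical helper (sketch), shaped like the refactored test code:
    fn linear_from_record(
        weight: Data<f32, 2>,
        bias: Data<f32, 1>,
    ) -> nn::Linear<TestADBackend> {
        let record = nn::LinearRecord {
            weight: Param::from(Tensor::from_data(weight)),
            bias: Some(Param::from(Tensor::from_data(bias))),
        };

        // Before: `init()` allocated random weights that `load_record` then
        // discarded; `init_with` builds the module from the record directly.
        nn::LinearConfig::new(6, 6).init_with(record)
    }
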
