Skip to content

Commit

Permalink
AdamW NaN fix (tracel-ai#888)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Oct 24, 2023
1 parent 1fd5955 commit 0ab611b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 34 deletions.
74 changes: 41 additions & 33 deletions burn-core/src/optim/adamw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ impl AdaptiveMomentumW {
let factor = 1.0 - self.beta_2;
let moment_2 = grad.powf(2.0).mul_scalar(factor);

AdaptiveMomentumWState::new(0, moment_1, moment_2)
AdaptiveMomentumWState::new(1, moment_1, moment_2)
};

let time: i32 = (state.time as i32).elem();
Expand Down Expand Up @@ -228,7 +228,7 @@ mod tests {
assert_eq!(state_optim_before.len(), state_optim_after.len());
}

const ASSERT_PRECISION: usize = 6;
const ASSERT_PRECISION: usize = 2;

#[test]
fn test_adamw_optimizer_with_numbers() {
Expand Down Expand Up @@ -290,37 +290,6 @@ mod tests {
-0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
]);

let t_state_updated: Tensor<TestADBackend, 2> =
Tensor::from_data(state_updated.weight.to_data());
let t_state_expected: Tensor<TestADBackend, 2> =
Tensor::from_data(weights_expected.clone());

let t_actual_difference = t_state_updated.sub(t_state_expected);
let expected_difference: Tensor<TestADBackend, 2> = Tensor::from_floats([
[
-0.016695, -0.019573, -0.023942, -0.023132, -0.020668, -0.020566,
],
[
-0.020668, -0.018018, -0.016251, -0.022484, -0.021762, -0.016982,
],
[
-0.019703, -0.018548, -0.016955, -0.022418, -0.017039, -0.023019,
],
[
-0.016920, -0.015994, -0.016204, -0.016967, -0.019053, -0.021519,
],
[
-0.023185, -0.016026, -0.023617, -0.018215, -0.023598, -0.019593,
],
[
-0.019734, -0.018083, -0.021164, -0.021856, -0.020104, -0.023720,
],
]);

t_actual_difference
.into_data()
.assert_approx_eq(&expected_difference.into_data(), ASSERT_PRECISION);

let (weight_updated, bias_updated) = (
state_updated.weight.to_data(),
state_updated.bias.unwrap().to_data(),
Expand All @@ -330,6 +299,45 @@ mod tests {
weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION);
}

#[test]
fn test_adam_optimizer_no_nan() {
let linear = given_linear_layer(
Data::from([
[-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
[0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
[-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
[-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
[0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
[-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
]),
Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
);

let x = Tensor::from_floats([
[0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
[0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
])
.require_grad();

let mut optimizer = AdamWConfig::new()
.with_epsilon(1e-8)
.with_beta_1(0.9)
.with_beta_2(0.999)
.with_weight_decay(0.5)
.init();

let grads = linear.forward(x.clone()).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);

let grads = linear.forward(x).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(LEARNING_RATE, linear, grads);

let state_updated = linear.into_record();
assert!(!state_updated.weight.to_data().value[0].is_nan());
}

fn given_linear_layer(weight: Data<f32, 2>, bias: Data<f32, 1>) -> nn::Linear<TestADBackend> {
let record = nn::LinearRecord {
weight: Param::from(Tensor::from_data(weight)),
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::model::Model;

use burn::module::Module;
use burn::optim::decay::WeightDecayConfig;
use burn::optim::AdamConfig;
use burn::optim::{AdamConfig, AdamWConfig};
use burn::record::{CompactRecorder, NoStdTrainingRecorder};
use burn::train::metric::store::{Aggregate, Direction, Split};
use burn::train::metric::{CpuMemory, CpuTemperature, CpuUse};
Expand Down

0 comments on commit 0ab611b

Please sign in to comment.