
Refactor/optim #272

Merged (13 commits) on Apr 5, 2023
Refactoring
nathanielsimard committed Apr 5, 2023
commit 8405ac80494c031dd4b79da2e44c9f6b2d567391
38 changes: 27 additions & 11 deletions burn-core/src/optim/adam.rs
@@ -2,10 +2,10 @@ use crate::{self as burn, module::ADModule, record::Record};

 use super::{
     decay::{WeightDecay, WeightDecayConfig, WeightDecayState},
-    SimpleOptimizer,
+    Optimizer, SimpleOptimizer,
 };
 use crate::config::Config;
-use crate::optim::SimpleModuleOptimizer;
+use crate::optim::adaptor::OptimizerAdaptor;
 use crate::tensor::{backend::ADBackend, Tensor};
 use burn_tensor::{backend::Backend, ElementConversion};
@@ -81,10 +81,8 @@ impl<B: Backend> SimpleOptimizer<B> for Adam<B> {
 }

 impl AdamConfig {
-    pub fn init<B: ADBackend, M: ADModule<B>>(
-        &self,
-    ) -> SimpleModuleOptimizer<Adam<B::InnerBackend>, M, B> {
-        let adam = Adam {
+    pub fn init<B: ADBackend, M: ADModule<B>>(&self) -> impl Optimizer<M, B> {
+        let optim = Adam {
             learning_rate: self.learning_rate.elem(),
             momentum: AdaptiveMomentum {
                 beta_1: self.beta_1,
@@ -96,8 +94,7 @@ impl AdamConfig {
                 .as_ref()
                 .map(|config| WeightDecay::new(config)),
         };
-
-        SimpleModuleOptimizer::new(adam)
+        OptimizerAdaptor::from(optim)
     }
 }

@@ -177,13 +174,13 @@ mod tests {
     use crate::optim::{GradientsParams, Optimizer};
     use crate::record::DebugRecordSettings;
     use crate::tensor::{Data, Distribution, Tensor};
-    use crate::{nn, TestADBackend};
+    use crate::{nn, TestADBackend, TestBackend};

     #[test]
     fn test_adam_optimizer_save_load_state() {
         let linear = nn::LinearConfig::new(6, 6).init();
         let x = Tensor::<TestADBackend, 2>::random([2, 6], Distribution::Standard);
-        let mut optimizer = AdamConfig::new(0.01).init();
+        let mut optimizer = create_adam();
         let grads = linear.forward(x).backward();
         let grads = GradientsParams::from_grads(grads, &linear);
         let _linear = optimizer.step(linear, grads);
@@ -194,9 +191,10 @@

         let state_optim_before = optimizer.to_record();
         let state_optim_before_copy = optimizer.to_record();
-        let optimizer = AdamConfig::new(0.01).init::<TestADBackend, nn::Linear<TestADBackend>>();
+        let optimizer = create_adam();
         let optimizer = optimizer.load_record(state_optim_before_copy);
         let state_optim_after = optimizer.to_record();
+
         assert_eq!(state_optim_before.len(), state_optim_after.len());
     }

@@ -279,4 +277,22 @@
             bias: Some(Param::from(Tensor::from_data(bias))),
         }
     }
+
+    fn create_adam() -> OptimizerAdaptor<Adam<TestBackend>, nn::Linear<TestADBackend>, TestADBackend>
+    {
+        let config = AdamConfig::new(0.01);
+        Adam {
+            learning_rate: config.learning_rate.elem(),
+            momentum: AdaptiveMomentum {
+                beta_1: config.beta_1,
+                beta_2: config.beta_2,
+                epsilon: config.epsilon,
+            },
+            weight_decay: config
+                .weight_decay
+                .as_ref()
+                .map(|config| WeightDecay::new(config)),
+        }
+        .into()
+    }
 }
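
With this refactor, callers build Adam through its config and only see the Optimizer trait; the concrete OptimizerAdaptor stays an implementation detail. A minimal caller-side sketch of the new flow, reusing the module type, backend, and values from the test above (purely illustrative, not part of the change):

    // Build a small Linear module and a batch of random inputs, as in the test.
    let module = nn::LinearConfig::new(6, 6).init();
    let x = Tensor::<TestADBackend, 2>::random([2, 6], Distribution::Standard);

    // `init` now returns `impl Optimizer<M, B>`, backed by an OptimizerAdaptor<Adam<_>, M, B>.
    let mut optimizer = AdamConfig::new(0.01).init::<TestADBackend, nn::Linear<TestADBackend>>();

    // One optimization step: backward pass, collect gradients per parameter, update the module.
    let grads = module.forward(x).backward();
    let grads = GradientsParams::from_grads(grads, &module);
    let module = optimizer.step(module, grads);
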
10 changes: 8 additions & 2 deletions burn-core/src/optim/base.rs
@@ -1,17 +1,23 @@
+use super::GradientsParams;
 use crate::module::ADModule;
 use crate::record::Record;
 use crate::tensor::backend::ADBackend;

-use super::GradientsParams;
-
 /// General trait to optimize [module](ADModule).
 pub trait Optimizer<M, B>: Send + Sync
 where
     M: ADModule<B>,
     B: ADBackend,
 {
+    /// Optimizer associative type to be used when saving and loading the state.
     type Record: Record;
+
+    /// Perform the optimizer step using the given gradients. The updated module will be returned.
     fn step(&mut self, module: M, grads: GradientsParams) -> M;
+
+    /// Get the current state of the optimizer as a [record](Record).
     fn to_record(&self) -> Self::Record;
+
+    /// Load the state of the optimizer as a [record](Record).
     fn load_record(self, record: Self::Record) -> Self;
 }
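
The trait above is the module-level half of the refactor; the other half is the per-tensor SimpleOptimizer, which OptimizerAdaptor lifts into an Optimizer (hence the OptimizerAdaptor::from and .into() calls in adam.rs and sgd.rs). Below is a self-contained, simplified analogue of that pattern, not burn's actual implementation: every type is an illustrative stand-in, and the "module" is just a Vec<f32> of parameters.

    // Simplified analogue of the SimpleOptimizer / OptimizerAdaptor split (illustrative only).

    /// Per-parameter optimizer: updates a single value given its gradient and optional state.
    trait SimpleOptim {
        type State: Clone;

        fn step(&self, param: f32, grad: f32, state: Option<Self::State>) -> (f32, Self::State);
    }

    /// Module-level optimizer, mirroring the shape of the `Optimizer` trait above.
    trait ModuleOptim<M> {
        type Record;

        fn step(&mut self, module: M, grads: Vec<f32>) -> M;
        fn to_record(&self) -> Self::Record;
        fn load_record(self, record: Self::Record) -> Self;
    }

    /// Adaptor lifting a `SimpleOptim` to a `ModuleOptim`, keeping one state slot per parameter.
    struct Adaptor<O: SimpleOptim> {
        optim: O,
        states: Vec<Option<O::State>>,
    }

    /// Implementing `From` is what makes the `Sgd { .. }.into()` style in the real code work.
    impl<O: SimpleOptim> From<O> for Adaptor<O> {
        fn from(optim: O) -> Self {
            Self { optim, states: Vec::new() }
        }
    }

    impl<O: SimpleOptim> ModuleOptim<Vec<f32>> for Adaptor<O> {
        type Record = Vec<Option<O::State>>;

        fn step(&mut self, module: Vec<f32>, grads: Vec<f32>) -> Vec<f32> {
            self.states.resize(module.len(), None);
            module
                .into_iter()
                .zip(grads)
                .enumerate()
                .map(|(i, (param, grad))| {
                    let (param, state) = self.optim.step(param, grad, self.states[i].clone());
                    self.states[i] = Some(state);
                    param
                })
                .collect()
        }

        fn to_record(&self) -> Self::Record {
            self.states.clone()
        }

        fn load_record(mut self, record: Self::Record) -> Self {
            self.states = record;
            self
        }
    }

    /// Plain SGD as the per-parameter optimizer.
    struct PlainSgd {
        learning_rate: f32,
    }

    impl SimpleOptim for PlainSgd {
        type State = ();

        fn step(&self, param: f32, grad: f32, _state: Option<()>) -> (f32, ()) {
            (param - self.learning_rate * grad, ())
        }
    }

    fn main() {
        // `.into()` mirrors `OptimizerAdaptor::from(optim)` in the hunks above.
        let mut optim: Adaptor<PlainSgd> = PlainSgd { learning_rate: 0.1 }.into();
        let params = vec![1.0_f32, 2.0];
        let grads = vec![0.5, -0.5];
        let updated = optim.step(params, grads);
        println!("{updated:?}"); // [0.95, 2.05]
    }

Keeping the per-parameter state inside the adaptor is what lets to_record and load_record round-trip the optimizer state, which is exactly what the Adam save/load test above exercises.
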
15 changes: 7 additions & 8 deletions burn-core/src/optim/sgd.rs
@@ -3,8 +3,9 @@ use crate::module::ADModule;

 use super::decay::{WeightDecay, WeightDecayConfig, WeightDecayState};
 use super::momentum::{MomemtumState, Momentum, MomentumConfig};
-use super::{SimpleModuleOptimizer, SimpleOptimizer};
+use super::SimpleOptimizer;
 use crate::config::Config;
+use crate::optim::adaptor::OptimizerAdaptor;
 use crate::record::Record;
 use crate::tensor::{ElementConversion, Tensor};
 use burn_tensor::backend::{ADBackend, Backend};
@@ -38,21 +39,20 @@ pub struct SgdState<B: Backend, const D: usize> {
 impl SgdConfig {
     pub fn init<B: ADBackend, M: ADModule<B>>(
         &self,
-    ) -> SimpleModuleOptimizer<Sgd<B::InnerBackend>, M, B> {
+    ) -> OptimizerAdaptor<Sgd<B::InnerBackend>, M, B> {
         let learning_rate = self.learning_rate.elem();
         let momentum = self.momentum.as_ref().map(|config| Momentum::new(config));
         let weight_decay = self
             .weight_decay
             .as_ref()
             .map(|config| WeightDecay::new(config));

-        let optim = Sgd {
+        Sgd {
             learning_rate,
             momentum,
             weight_decay,
-        };
-
-        SimpleModuleOptimizer::new(optim)
+        }
+        .into()
     }
 }

@@ -156,8 +156,7 @@ mod tests {
         LinearConfig::new(20, 20).with_bias(true).init()
     }

-    fn sgd_with_all(
-    ) -> SimpleModuleOptimizer<Sgd<TestBackend>, Linear<TestADBackend>, TestADBackend> {
+    fn sgd_with_all() -> OptimizerAdaptor<Sgd<TestBackend>, Linear<TestADBackend>, TestADBackend> {
         SgdConfig {
             learning_rate: 0.02,
             weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
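
Reading the two config impls together, the init signatures from the hunks above show two levels of exposure: SgdConfig returns the concrete adaptor type, while AdamConfig hides it behind an opaque return type, so Adam callers can only use the Optimizer trait methods.

    // SgdConfig::init (sgd.rs): concrete adaptor type in the public signature.
    pub fn init<B: ADBackend, M: ADModule<B>>(&self) -> OptimizerAdaptor<Sgd<B::InnerBackend>, M, B>

    // AdamConfig::init (adam.rs): opaque return type; only the Optimizer trait is visible.
    pub fn init<B: ADBackend, M: ADModule<B>>(&self) -> impl Optimizer<M, B>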