Refactor/optim #272

Merged · 13 commits · Apr 5, 2023
Update burn-core
nathanielsimard committed Apr 4, 2023
commit 14372568f1f0e1d3a08c6027a8b54c3352aa1688
408 changes: 173 additions & 235 deletions burn-core/src/optim/adam.rs

Large diffs are not rendered by default.

116 changes: 2 additions & 114 deletions burn-core/src/optim/base.rs
@@ -1,13 +1,8 @@
use super::mapper::ModuleTensorUpdater;
use super::visitor::{GradientsLoader, GradientsRegister};
use super::GradientsParams;

use crate::module::{ADModule, LoadingError, ParamId, State, StateNamed};
use crate::module::ADModule;
use crate::record::Record;
use crate::tensor::backend::ADBackend;
use crate::tensor::{Data, Tensor};

pub trait ModuleOptimizer<M, B>: Send + Sync
pub trait Optimizer<M, B>: Send + Sync
where
M: ADModule<B>,
B: ADBackend,
@@ -18,110 +13,3 @@ where
fn to_record(&self) -> Self::Record;
fn load_record(self, record: Self::Record) -> Self;
}

pub trait Optimizer<M, B>: Send + Sync
where
M: ADModule<B>,
B: ADBackend,
{
/// Update the tensor parameter using the given gradients.
fn update_tensor<const D: usize>(
&mut self,
id: &ParamId,
tensor: Tensor<B, D>,
grad: Tensor<B::InnerBackend, D>,
) -> Tensor<B, D>;

/// Update the parameters of the given module using the given gradients.
fn update_module(&mut self, module: M, grads: GradientsParams) -> M
where
Self: Sized,
{
let mut mapper = ModuleTensorUpdater::new(self, grads);
module.map(&mut mapper)
}

/// Register the optimizer state for a given parameter.
///
/// # Notes
///
/// This should only be called by generated code.
fn register_param_state<const D: usize>(
&self,
_id: &ParamId,
_state: &mut StateNamed<B::FloatElem>,
) {
// By default there is no state to register
}

/// Load the optimizer state for a given parameter.
///
/// # Notes
///
/// This should only be called by generated code.
fn load_param_state<const D: usize>(
&mut self,
_id: &ParamId,
_state: &StateNamed<B::FloatElem>,
_device: &B::Device,
) {
// By default there is no state to load
}

/// Get the optimizer state for a given module.
fn state(&self, module: &M) -> State<B::FloatElem>
where
Self: Sized,
{
let mut state_named = StateNamed::new();
let mut visitor = GradientsRegister::new(self, &mut state_named);

module.visit(&mut visitor);
State::StateNamed(state_named)
}

/// Load the optimizer state for a given module.
fn load(&mut self, module: &M, state: &State<B::FloatElem>) -> Result<(), LoadingError>
where
Self: Sized,
{
let state_named = match state {
State::StateNamed(state) => state,
_ => {
return Err(LoadingError::new(
"Can't load state wrapper to fetch id and data".to_string(),
))
}
};

let mut visitor = GradientsLoader::new(self, state_named);
module.visit(&mut visitor);

Ok(())
}
}

pub(super) fn register_state_gradients<const D: usize, B: ADBackend, F: Fn(&ParamId) -> String>(
id: &ParamId,
state: &mut StateNamed<B::FloatElem>,
grads: &GradientsParams,
id_to_key: F,
) {
if let Some(grad) = grads.get::<B::InnerBackend, D>(id) {
let data = State::Data(grad.into_data().serialize());
state.register_state(id_to_key(id).as_str(), data);
};
}

pub(super) fn load_state_gradients<const D: usize, B: ADBackend, F: Fn(&ParamId) -> String>(
id: &ParamId,
state: &StateNamed<B::FloatElem>,
grads: &mut GradientsParams,
id_to_key: F,
device: &B::Device,
) {
if let Some(State::Data(data)) = state.get(id_to_key(id).as_str()) {
let tensor = Tensor::<B::InnerBackend, D>::from_data_device(Data::from(data), device);
grads.register::<B::InnerBackend, D>(id.clone(), tensor);
};
}
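
Taken together, the base.rs change collapses the old pair of traits — `ModuleOptimizer` plus the tensor-level `Optimizer` with its `StateNamed` bookkeeping — into a single record-based `Optimizer` trait. Below is a minimal, self-contained sketch of that shape, not burn's actual API: plain Rust types stand in for the `ADModule`/`ADBackend`/`GradientsParams` generics, and the `step` method and the `Record = ()` choice are assumptions for illustration, since only `to_record`/`load_record` are visible in this hunk.

```rust
// Simplified sketch of a record-based optimizer trait.
// `step` and `Record = ()` are assumptions; only to_record/load_record
// appear in the hunk above.
trait Optimizer<M> {
    type Record;

    /// Apply one optimization step and return the updated module.
    fn step(&mut self, module: M, grads: Vec<f32>) -> M;

    /// Serialize the optimizer state into a record.
    fn to_record(&self) -> Self::Record;

    /// Rebuild the optimizer from a previously saved record.
    fn load_record(self, record: Self::Record) -> Self;
}

/// Toy "module": a single weight.
struct Linear {
    weight: f32,
}

/// Plain SGD with no state worth persisting.
struct Sgd {
    lr: f32,
}

impl Optimizer<Linear> for Sgd {
    type Record = ();

    fn step(&mut self, mut module: Linear, grads: Vec<f32>) -> Linear {
        module.weight -= self.lr * grads[0];
        module
    }

    fn to_record(&self) -> Self::Record {}

    fn load_record(self, _record: Self::Record) -> Self {
        self
    }
}

fn main() {
    let mut optim = Sgd { lr: 0.1 };
    let module = Linear { weight: 1.0 };
    let module = optim.step(module, vec![0.5]);
    println!("updated weight: {}", module.weight); // 1.0 - 0.1 * 0.5 = 0.95
}
```

Hiding serialization behind an associated `Record` type lets each optimizer decide what state, if any, it persists, instead of routing everything through the `register_state_gradients`/`load_state_gradients` helpers deleted above.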
57 changes: 19 additions & 38 deletions burn-core/src/optim/decay.rs
@@ -1,10 +1,9 @@
use crate as burn;
use burn_tensor::backend::Backend;

use super::{load_state_gradients, register_state_gradients, GradientsParams};
use crate as burn;
use crate::record::Record;

use crate::config::Config;
use crate::module::{ParamId, StateNamed};
use crate::tensor::backend::ADBackend;
use crate::tensor::{ElementConversion, Tensor};

/// Configuration to create [WeightDecay](WeightDecay).
@@ -14,53 +13,35 @@ pub struct WeightDecayConfig {
pub penalty: f64,
}

#[derive(Record, Clone, new)]
pub struct WeightDecayState<B: Backend, const D: usize> {
grad_last_step: Tensor<B, D>,
}

/// Weight decay implementation that transforms gradients.
pub struct WeightDecay<B: ADBackend> {
pub struct WeightDecay<B: Backend> {
penalty: B::FloatElem,
gradients: GradientsParams,
}

impl<B: ADBackend> WeightDecay<B> {
impl<B: Backend> WeightDecay<B> {
pub fn new(config: &WeightDecayConfig) -> Self {
Self {
penalty: config.penalty.elem(),
gradients: GradientsParams::new(),
}
}

pub fn transform<const D: usize>(
&mut self,
id: &ParamId,
grad: Tensor<B::InnerBackend, D>,
) -> Tensor<B::InnerBackend, D> {
let grad = match self.gradients.remove::<B::InnerBackend, D>(id) {
Some(grad_last_step) => grad_last_step.mul_scalar(self.penalty).add(grad),
None => grad,
};

// Update gradients
self.gradients.register(id.clone(), grad.clone());

grad
}
pub fn register_state<const D: usize>(
&self,
id: &ParamId,
state: &mut StateNamed<B::FloatElem>,
) {
register_state_gradients::<D, B, _>(id, state, &self.gradients, Self::state_key);
}
grad: Tensor<B, D>,
state: Option<WeightDecayState<B, D>>,
) -> (Tensor<B, D>, WeightDecayState<B, D>) {
let grad_last_step = grad.clone();

pub fn load_state<const D: usize>(
&mut self,
id: &ParamId,
state: &StateNamed<B::FloatElem>,
device: &B::Device,
) {
load_state_gradients::<D, B, _>(id, state, &mut self.gradients, Self::state_key, device);
}
let grad = match state {
Some(state) => state.grad_last_step.mul_scalar(self.penalty).add(grad),
None => grad,
};

fn state_key(id: &ParamId) -> String {
format!("weight-decay-{id}")
(grad, WeightDecayState::new(grad_last_step))
}
}
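
The decay.rs rewrite replaces the internal `GradientsParams` buffer with an explicit state-in/state-out API: `transform` now takes an `Option<WeightDecayState>` and returns the updated state alongside the transformed gradient, so the caller owns and threads the state. A simplified, standalone sketch of that calling pattern, with `f32` standing in for `Tensor<B, D>` (everything beyond the names shown in the diff is illustrative):

```rust
/// Simplified stand-in for the per-parameter state introduced in the diff:
/// the gradient seen at the previous step.
struct WeightDecayState {
    grad_last_step: f32,
}

struct WeightDecay {
    penalty: f32,
}

impl WeightDecay {
    /// State-in/state-out transform mirroring the new signature.
    fn transform(&self, grad: f32, state: Option<WeightDecayState>) -> (f32, WeightDecayState) {
        // Remember the raw incoming gradient for the next step.
        let grad_last_step = grad;

        // Fold the previous step's gradient in, scaled by the penalty.
        let grad = match state {
            Some(state) => state.grad_last_step * self.penalty + grad,
            None => grad,
        };

        (grad, WeightDecayState { grad_last_step })
    }
}

fn main() {
    let decay = WeightDecay { penalty: 0.01 };

    // First call: no previous state, the gradient passes through unchanged.
    let (g1, s1) = decay.transform(0.5, None);
    // Second call: the caller threads the returned state back in.
    let (g2, _s2) = decay.transform(0.4, Some(s1));

    println!("step 1: {g1}, step 2: {g2}"); // 0.5, then 0.4 + 0.01 * 0.5 = 0.405
}
```

Because nothing is cached inside `WeightDecay` anymore, the per-parameter state can live wherever the calling optimizer keeps its record, which is what the trait refactor above relies on.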
29 changes: 0 additions & 29 deletions burn-core/src/optim/mapper.rs

This file was deleted.

1 change: 0 additions & 1 deletion burn-core/src/optim/mod.rs
@@ -5,7 +5,6 @@ mod adam;
mod base;
mod grad_accum;
mod grads;
mod mapper;
mod sgd;
mod simple;
mod visitor;
62 changes: 20 additions & 42 deletions burn-core/src/optim/momentum.rs
@@ -1,10 +1,9 @@
use crate as burn;

use super::{load_state_gradients, register_state_gradients, GradientsParams};
use crate::config::Config;
use crate::module::{ParamId, StateNamed};
use crate::tensor::backend::ADBackend;
use crate::record::Record;
use crate::tensor::{ElementConversion, Tensor};
use burn_tensor::backend::Backend;

/// Configuration to create [Momentum](Momentum).
#[derive(Config)]
@@ -20,66 +19,45 @@ pub struct MomentumConfig {
pub nesterov: bool,
}

#[derive(Record, Clone, new)]
pub struct MomemtumState<B: Backend, const D: usize> {
velocity: Tensor<B, D>,
}

/// Momentum implementation that transforms gradients.
pub struct Momentum<B: ADBackend> {
pub struct Momentum<B: Backend> {
momentum: B::FloatElem,
dampening: f64,
nesterov: bool,
velocity: GradientsParams,
}

impl<B: ADBackend> Momentum<B> {
impl<B: Backend> Momentum<B> {
pub fn new(config: &MomentumConfig) -> Self {
Self {
momentum: config.momentum.elem(),
dampening: config.dampening,
velocity: GradientsParams::new(),
nesterov: config.nesterov,
}
}

pub fn transform<const D: usize>(
&mut self,
id: &ParamId,
grad: Tensor<B::InnerBackend, D>,
) -> Tensor<B::InnerBackend, D> {
let velocity = match self.velocity.remove::<B::InnerBackend, D>(id) {
Some(grad_last_step) => grad
.clone()
&self,
grad: Tensor<B, D>,
state: Option<MomemtumState<B, D>>,
) -> (Tensor<B, D>, MomemtumState<B, D>) {
let velocity = if let Some(state) = state {
grad.clone()
.mul_scalar(1.0 - self.dampening)
.add(grad_last_step.mul_scalar(self.momentum)),
None => grad.clone(),
.add(state.velocity.mul_scalar(self.momentum))
} else {
grad.clone()
};

let output = match self.nesterov {
let grad = match self.nesterov {
true => velocity.clone().mul_scalar(self.momentum).add(grad),
false => velocity.clone(),
};

// Update velocity
self.velocity.register(id.clone(), velocity);

output
}

pub fn register_state<const D: usize>(
&self,
id: &ParamId,
state: &mut StateNamed<B::FloatElem>,
) {
register_state_gradients::<D, B, _>(id, state, &self.velocity, Self::state_key);
}

pub fn load_state<const D: usize>(
&mut self,
id: &ParamId,
state: &StateNamed<B::FloatElem>,
device: &B::Device,
) {
load_state_gradients::<D, B, _>(id, state, &mut self.velocity, Self::state_key, device);
}

fn state_key(id: &ParamId) -> String {
format!("momentum-{id}")
(grad, MomemtumState::new(velocity))
}
}
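
momentum.rs follows the same pattern: the velocity buffer moves out of the struct and into a per-parameter record (spelled `MomemtumState` in the diff) that the caller passes back in on each call. A standalone sketch of that flow with scalar gradients in place of tensors, and the struct renamed `MomentumState` here purely for readability:

```rust
struct MomentumState {
    velocity: f32,
}

struct Momentum {
    momentum: f32,
    dampening: f32,
    nesterov: bool,
}

impl Momentum {
    /// Mirrors the new state-in/state-out signature: the velocity is no longer
    /// kept in a GradientsParams map inside the struct, it is returned to the
    /// caller alongside the transformed gradient.
    fn transform(&self, grad: f32, state: Option<MomentumState>) -> (f32, MomentumState) {
        // v = (1 - dampening) * grad + momentum * v_prev, or just grad on the first step.
        let velocity = match state {
            Some(state) => grad * (1.0 - self.dampening) + state.velocity * self.momentum,
            None => grad,
        };

        // Nesterov applies the momentum term on top of the current gradient.
        let grad = if self.nesterov {
            velocity * self.momentum + grad
        } else {
            velocity
        };

        (grad, MomentumState { velocity })
    }
}

fn main() {
    let momentum = Momentum { momentum: 0.9, dampening: 0.0, nesterov: false };

    let (g1, s1) = momentum.transform(1.0, None);
    let (g2, _s2) = momentum.transform(1.0, Some(s1));

    println!("step 1: {g1}, step 2: {g2}"); // 1.0, then 1.0 * (1 - 0) + 0.9 * 1.0 = 1.9
}
```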