Refactor/optim #272

Merged
merged 13 commits · Apr 5, 2023
Adapt burn train
nathanielsimard committed Apr 4, 2023
commit dd55207a77c297337bdd9b74bd51d8b682823493
5 changes: 4 additions & 1 deletion burn-core/src/module/param/primitive.rs
@@ -145,6 +145,7 @@ where
impl<const N: usize, T, B> Module<B> for [T; N]
where
T: Module<B> + Debug + Send + Sync + Clone + Copy,
T::Record: Debug,
B: Backend,
{
type Record = [T::Record; N];
@@ -201,7 +202,9 @@ where
impl<const N: usize, T, B> ADModule<B> for [T; N]
where
T: ADModule<B> + Debug + Send + Sync + Clone + Copy,
T::InnerModule: Copy,
T::InnerModule: Copy + Debug,
<T::InnerModule as Module<B::InnerBackend>>::Record: Debug,
<T as Module<B>>::Record: Debug,
B: ADBackend,
{
type InnerModule = [T::InnerModule; N];
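
A note on the new `T::Record: Debug` bounds above: constraining an associated type of a generic parameter in the `where` clause is plain Rust, not anything burn-specific. A minimal, burn-independent sketch of the pattern (the toy trait below is illustrative only):

```rust
use std::fmt::Debug;

// Toy stand-in for the Module/Record relationship, only to show the bound.
trait ToyModule {
    type Record;
    fn into_record(self) -> Self::Record;
}

// Printing each element's record requires `T::Record: Debug`, mirroring the
// bound added to the `[T; N]` impls in this commit.
fn debug_records<const N: usize, T>(modules: [T; N]) -> [T::Record; N]
where
    T: ToyModule,
    T::Record: Debug,
{
    let records = modules.map(ToyModule::into_record);
    for record in &records {
        println!("{record:?}");
    }
    records
}
```
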
25 changes: 15 additions & 10 deletions burn-core/src/optim/adam.rs
@@ -1,10 +1,11 @@
use crate::{self as burn, record::Record};
use crate::{self as burn, module::ADModule, record::Record};

use super::{
decay::{WeightDecay, WeightDecayConfig, WeightDecayState},
SimpleOptimizer,
};
use crate::config::Config;
use crate::optim::SimpleModuleOptimizer;
use crate::tensor::{backend::ADBackend, Tensor};
use burn_tensor::{backend::Backend, ElementConversion};

@@ -70,20 +71,24 @@ impl<B: Backend> SimpleOptimizer<B> for Adam<B> {
}
}

impl<B: ADBackend> Adam<B> {
pub fn new(config: &AdamConfig) -> Self {
Self {
learning_rate: config.learning_rate.elem(),
impl AdamConfig {
pub fn init<B: ADBackend, M: ADModule<B>>(
&self,
) -> SimpleModuleOptimizer<Adam<B::InnerBackend>, M, B> {
let adam = Adam {
learning_rate: self.learning_rate.elem(),
momentum: AdaptiveMomentum {
beta_1: config.beta_1,
beta_2: config.beta_2,
epsilon: config.epsilon,
beta_1: self.beta_1,
beta_2: self.beta_2,
epsilon: self.epsilon,
},
weight_decay: config
weight_decay: self
.weight_decay
.as_ref()
.map(|config| WeightDecay::new(config)),
}
};

SimpleModuleOptimizer::new(adam)
}
}

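
The hunk above replaces `Adam::new(&config)` with a config-driven `AdamConfig::init`, which returns the optimizer already wrapped in a `SimpleModuleOptimizer`. A minimal sketch of a call site; the `burn` re-export paths are assumptions, only the `init` signature itself comes from this diff:

```rust
use burn::module::ADModule;
use burn::optim::{Adam, AdamConfig, SimpleModuleOptimizer};
use burn::tensor::backend::ADBackend;

// The model type is now a generic parameter of `init`, so the returned
// optimizer is tied to a concrete module type `M` up front.
fn build_adam<B: ADBackend, M: ADModule<B>>(
    config: &AdamConfig,
) -> SimpleModuleOptimizer<Adam<B::InnerBackend>, M, B> {
    config.init::<B, M>()
}
```

Note that the wrapped `Adam` is parameterized by `B::InnerBackend`: the simple optimizer runs on the inner (non-autodiff) backend, while the wrapper handles the autodiff module. The same `Config::init` pattern is applied to `SgdConfig` further down in this commit.
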
4 changes: 3 additions & 1 deletion burn-core/src/optim/base.rs
@@ -2,14 +2,16 @@ use crate::module::ADModule;
use crate::record::Record;
use crate::tensor::backend::ADBackend;

use super::GradientsParams;

pub trait Optimizer<M, B>: Send + Sync
where
M: ADModule<B>,
B: ADBackend,
{
type Record: Record;

fn step(&mut self, module: M, grads: B::Gradients) -> M;
fn step(&mut self, module: M, grads: GradientsParams) -> M;
fn to_record(&self) -> Self::Record;
fn load_record(self, record: Self::Record) -> Self;
}
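
The `Optimizer` trait now consumes `GradientsParams` instead of the backend's raw gradient container. A sketch of a caller written purely against the trait as shown above; only the `burn` re-export paths are assumed:

```rust
use burn::module::ADModule;
use burn::optim::{GradientsParams, Optimizer};
use burn::tensor::backend::ADBackend;

// One update step: the optimizer takes the module and the per-parameter
// gradients by value and returns the updated module.
fn apply_update<B, M, O>(optim: &mut O, module: M, grads: GradientsParams) -> M
where
    B: ADBackend,
    M: ADModule<B>,
    O: Optimizer<M, B>,
{
    optim.step(module, grads)
}
```

The per-parameter lookup that makes this work appears later in `simple.rs`, where the mapper pulls each parameter's gradient out of `GradientsParams` by id (`self.grads.remove(id)`).
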
23 changes: 14 additions & 9 deletions burn-core/src/optim/sgd.rs
@@ -1,12 +1,13 @@
use crate as burn;
use crate::module::ADModule;

use super::decay::{WeightDecay, WeightDecayConfig, WeightDecayState};
use super::momentum::{MomemtumState, Momentum, MomentumConfig};
use super::SimpleOptimizer;
use super::{SimpleModuleOptimizer, SimpleOptimizer};
use crate::config::Config;
use crate::record::Record;
use crate::tensor::{ElementConversion, Tensor};
use burn_tensor::backend::Backend;
use burn_tensor::backend::{ADBackend, Backend};

/// Configuration to create the [Sgd](Sgd) optimizer.
#[derive(Config)]
@@ -34,20 +35,24 @@ pub struct SgdState<B: Backend, const D: usize> {
momentum: Option<MomemtumState<B, D>>,
}

impl<B: Backend> Sgd<B> {
pub fn new(config: &SgdConfig) -> Self {
let learning_rate = config.learning_rate.elem();
let momentum = config.momentum.as_ref().map(|config| Momentum::new(config));
let weight_decay = config
impl SgdConfig {
pub fn init<B: ADBackend, M: ADModule<B>>(
&self,
) -> SimpleModuleOptimizer<Sgd<B::InnerBackend>, M, B> {
let learning_rate = self.learning_rate.elem();
let momentum = self.momentum.as_ref().map(|config| Momentum::new(config));
let weight_decay = self
.weight_decay
.as_ref()
.map(|config| WeightDecay::new(config));

Self {
let optim = Sgd {
learning_rate,
momentum,
weight_decay,
}
};

SimpleModuleOptimizer::new(optim)
}
}

28 changes: 23 additions & 5 deletions burn-core/src/optim/simple.rs
@@ -1,6 +1,6 @@
use core::{any::Any, marker::PhantomData};

use super::Optimizer;
use super::{GradientsParams, Optimizer};
use crate::{
module::{ADModule, ModuleMapper, ParamId},
record::{Record, RecordSettings},
@@ -9,7 +9,8 @@ use burn_tensor::{
backend::{ADBackend, Backend},
Tensor,
};
use hashbrown::HashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Simple optimizer is a more opinionated trait where the state can be generic over the
/// dimension D and implements Record. This allows for simpler optimizer implementations where they
@@ -40,6 +41,21 @@ where
module: PhantomData<M>,
}

impl<O, B, M> SimpleModuleOptimizer<O, M, B>
where
B: ADBackend,
M: ADModule<B>,
O: SimpleOptimizer<B::InnerBackend>,
{
pub fn new(optim: O) -> Self {
Self {
optim,
records: HashMap::new(),
module: PhantomData::default(),
}
}
}

impl<O, B, M> Optimizer<M, B> for SimpleModuleOptimizer<O, M, B>
where
B: ADBackend,
@@ -48,7 +64,7 @@ where
{
type Record = HashMap<ParamId, SimpleOptimizerRecord<O, B::InnerBackend>>;

fn step(&mut self, module: M, mut grads: <B as ADBackend>::Gradients) -> M {
fn step(&mut self, module: M, mut grads: GradientsParams) -> M {
let mut mapper =
SimpleModuleOptimizerMapper::<M, B, O>::new(&self.optim, &mut self.records, &mut grads);
module.map(&mut mapper)
@@ -73,7 +89,7 @@ where
{
optimizer: &'a O,
records: &'a mut HashMap<ParamId, SimpleOptimizerRecord<O, B::InnerBackend>>,
grads: &'a mut B::Gradients,
grads: &'a mut GradientsParams,
phatom: PhantomData<M>,
}

@@ -84,7 +100,7 @@ where
O: SimpleOptimizer<B::InnerBackend>,
{
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
let grad = tensor.grad(self.grads);
let grad = self.grads.remove(id);

if let Some(grad) = grad {
let (key, record) = self.records.remove_entry(id).unzip();
@@ -134,6 +150,8 @@ impl<O: SimpleOptimizer<B>, B: Backend> Clone for SimpleOptimizerRecord<O, B> {
}
}

#[derive(Serialize, Deserialize)]
#[serde(bound = "<O::State<1> as Record>::Item<S>: Serialize + serde::de::DeserializeOwned")]
pub enum SimpleOptimizerRecordItem<O: SimpleOptimizer<B>, B: Backend, S: RecordSettings> {
Rank1(<O::State<1> as Record>::Item<S>),
Rank2(<O::State<2> as Record>::Item<S>),
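
The doc comment in this file describes the idea: a `SimpleOptimizer` only knows how to update a single tensor given its own record, while `SimpleModuleOptimizer` walks the module and keeps one record per parameter id. A toy, burn-independent illustration of that wrapper pattern (the types below are simplified stand-ins, not burn's API):

```rust
use std::collections::HashMap;

// A "simple" optimizer only updates one value given its per-parameter state.
trait ToySimpleOptimizer {
    type State;
    fn step(&self, value: f32, grad: f32, state: Option<Self::State>) -> (f32, Self::State);
}

// The wrapper owns the per-parameter state, keyed by parameter id, and feeds
// each parameter's previous state back into the simple optimizer.
struct ToyWrapper<O: ToySimpleOptimizer> {
    optim: O,
    records: HashMap<String, O::State>,
}

impl<O: ToySimpleOptimizer> ToyWrapper<O> {
    fn update(&mut self, id: &str, value: f32, grad: f32) -> f32 {
        let state = self.records.remove(id);
        let (updated, state) = self.optim.step(value, grad, state);
        self.records.insert(id.to_owned(), state);
        updated
    }
}
```
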
4 changes: 1 addition & 3 deletions burn-core/src/record/base.rs
@@ -8,7 +8,7 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};

/// Trait to define a family of types which can be recorded using any [settings](RecordSettings).
pub trait Record: Send + Sync {
type Item<S: RecordSettings>;
type Item<S: RecordSettings>: Serialize + DeserializeOwned;

/// Convert the current record into the corresponding item that follows the given [settings](RecordSettings).
fn into_item<S: RecordSettings>(self) -> Self::Item<S>;
@@ -20,7 +20,6 @@ pub trait Record: Send + Sync {
where
Self: Sized,
S: RecordSettings,
Self::Item<S>: Serialize + DeserializeOwned,
{
let metadata = BurnMetadata::new(
core::any::type_name::<S::FloatElem>().to_string(),
@@ -40,7 +39,6 @@
where
Self: Sized,
S: RecordSettings,
Self::Item<S>: Serialize + DeserializeOwned,
{
let record: BurnRecord<Self::Item<S>> =
RecorderType::<S>::load(args.clone()).map_err(|err| {
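
Moving `Serialize + DeserializeOwned` from the two method `where` clauses onto the associated type itself means every `Self::Item<S>` is serializable by construction. A toy sketch of the same move, using `serde`/`serde_json` purely for illustration:

```rust
use serde::{de::DeserializeOwned, Serialize};

trait ToyRecord {
    // The bound now lives on the associated type, so no method needs to
    // repeat `where Self::Item: Serialize + DeserializeOwned`.
    type Item: Serialize + DeserializeOwned;

    fn into_item(self) -> Self::Item;

    fn to_json(self) -> serde_json::Result<String>
    where
        Self: Sized,
    {
        serde_json::to_string(&self.into_item())
    }
}
```
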
14 changes: 9 additions & 5 deletions burn-core/src/record/primitive.rs
@@ -2,7 +2,7 @@ use super::{Record, RecordSettings};
use crate::module::{Param, ParamId, State};
use alloc::vec::Vec;
use burn_tensor::{DataSerialize, Element};
use hashbrown::HashMap;
use std::collections::HashMap;

impl Record for () {
type Item<S: RecordSettings> = ();
@@ -36,15 +36,19 @@ impl<T: Record> Record for Option<T> {
}
}

impl<const N: usize, T: Record> Record for [T; N] {
type Item<S: RecordSettings> = [T::Item<S>; N];
impl<const N: usize, T: Record + core::fmt::Debug> Record for [T; N] {
type Item<S: RecordSettings> = Vec<T::Item<S>>;

fn into_item<S: RecordSettings>(self) -> Self::Item<S> {
self.map(Record::into_item)
self.map(Record::into_item).into_iter().collect()
}

fn from_item<S: RecordSettings>(item: Self::Item<S>) -> Self {
item.map(Record::from_item)
item.into_iter()
.map(Record::from_item)
.collect::<Vec<_>>()
.try_into()
.expect(format!("An array of size {N}").as_str())
}
}

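
The array record is now serialized as a `Vec` and converted back on load; the `try_into` at the end is the only step that can fail (a length mismatch). A self-contained sketch of that conversion in plain `std` Rust:

```rust
// Convert a deserialized Vec back into a fixed-size array, panicking with a
// clear message if the stored length does not match `N`.
fn vec_to_array<const N: usize, T>(items: Vec<T>) -> [T; N] {
    items
        .try_into()
        .unwrap_or_else(|v: Vec<T>| panic!("Expected an array of size {}, got {}", N, v.len()))
}
```
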
5 changes: 4 additions & 1 deletion burn-derive/src/lib.rs
@@ -18,7 +18,10 @@ pub fn module_derive(input: TokenStream) -> TokenStream {
#[proc_macro_derive(Record)]
pub fn record_derive(input: TokenStream) -> TokenStream {
let input = syn::parse(input).unwrap();
record_derive_impl(&input)
let gen = record_derive_impl(&input);

// panic!("{}", gen);
gen
}

#[proc_macro_derive(Config, attributes(config))]
4 changes: 2 additions & 2 deletions burn-train/src/checkpoint/file.rs
@@ -47,7 +47,7 @@ where
log::info!("Saving checkpoint {} to {}", epoch, file_path);

record
.record(file_path.into())
.record::<S>(file_path.into())
.map_err(CheckpointerError::RecorderError)?;

if self.num_keep > epoch {
@@ -67,7 +67,7 @@
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError> {
let file_path = self.path_for_epoch(epoch);
log::info!("Restoring checkpoint {} from {}", epoch, file_path);
let record = R::load(file_path.into()).map_err(CheckpointerError::RecorderError)?;
let record = R::load::<S>(file_path.into()).map_err(CheckpointerError::RecorderError)?;

Ok(record)
}
16 changes: 8 additions & 8 deletions burn-train/src/learner/base.rs
@@ -1,8 +1,8 @@
use crate::checkpoint::Checkpointer;
use crate::LearnerCallback;
use burn_core::module::{ADModule, Module, State};
use burn_core::module::{ADModule, Module};
use burn_core::optim::Optimizer;
use burn_core::tensor::backend::{ADBackend, Backend};
use burn_core::tensor::backend::ADBackend;

/// Learner struct encapsulating all components necessary to train a Neural Network model.
///
@@ -19,13 +19,13 @@ where
pub(super) callback: Box<dyn LearnerCallback<TO, VO>>,
pub(super) checkpoint: Option<usize>,
pub(super) checkpointer_model: CheckpointModel<M, B>,
pub(super) checkpointer_optimizer: CheckpointOptim<B>,
pub(super) checkpointer_optimizer: CheckpointOptim<O, M, B>,
pub(super) grad_accumulation: Option<usize>,
pub(super) devices: Vec<B::Device>,
}

type CheckpointModel<M, B> = Option<Box<dyn Checkpointer<<M as Module<B>>::Record>>>;
type CheckpointOptim<B> = Option<Box<dyn Checkpointer<State<<B as Backend>::FloatElem>>>>;
type CheckpointOptim<O, M, B> = Option<Box<dyn Checkpointer<<O as Optimizer<M, B>>::Record>>>;

impl<B, M, O, TO, VO> Learner<B, M, O, TO, VO>
where
@@ -39,7 +39,7 @@
model: &M,
optim: &O,
checkpointer_model: &CheckpointModel<M, B>,
checkpointer_optimizer: &CheckpointOptim<B>,
checkpointer_optimizer: &CheckpointOptim<O, M, B>,
epoch: usize,
) {
if let Some(checkpointer) = &checkpointer_model {
@@ -48,7 +48,7 @@
.unwrap();
}
if let Some(checkpointer) = &checkpointer_optimizer {
checkpointer.save(epoch, optim.state(model)).unwrap();
checkpointer.save(epoch, optim.to_record()).unwrap();
}
}

@@ -59,8 +59,8 @@
}

if let Some(checkpointer) = &self.checkpointer_optimizer {
let state = checkpointer.restore(epoch).unwrap();
self.optim.load(&self.model, &state).unwrap();
let record = checkpointer.restore(epoch).unwrap();
self.optim = self.optim.load_record(record);
}

self
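
The learner now checkpoints the optimizer through its `Record` (`to_record` / `load_record`) instead of a backend-elem `State`. A sketch of that round trip written against the traits in this PR; the `Checkpointer` method signatures and import paths are assumptions inferred from the calls above:

```rust
use burn::module::ADModule;
use burn::optim::Optimizer;
use burn::tensor::backend::ADBackend;
use burn_train::checkpoint::{Checkpointer, CheckpointerError};

// Save the optimizer state for an epoch, then restore it into the optimizer.
fn roundtrip_optimizer<B, M, O, C>(
    optim: O,
    checkpointer: &C,
    epoch: usize,
) -> Result<O, CheckpointerError>
where
    B: ADBackend,
    M: ADModule<B>,
    O: Optimizer<M, B>,
    C: Checkpointer<O::Record>,
{
    checkpointer.save(epoch, optim.to_record())?;
    let record = checkpointer.restore(epoch)?;
    Ok(optim.load_record(record))
}
```
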