Refactor/optim #272

Merged · 13 commits · Apr 5, 2023

Commit: Wip
nathanielsimard committed Apr 4, 2023
commit 46b74e61e9bfbf6deb309529157c4356cd86a3e7
38 changes: 35 additions & 3 deletions burn-core/src/optim/adam.rs
@@ -1,14 +1,14 @@
-use crate::{self as burn, module::ADModule};
+use crate::{self as burn, module::ADModule, record::Record};

use super::{
decay::{WeightDecay, WeightDecayConfig},
-load_state_gradients, register_state_gradients, GradientsParams,
+load_state_gradients, register_state_gradients, GradientsParams, SimpleOptimizer,
};
use crate::config::Config;
use crate::module::{ParamId, StateNamed};
use crate::optim::Optimizer;
use crate::tensor::{backend::ADBackend, Tensor};
-use burn_tensor::ElementConversion;
+use burn_tensor::{backend::Backend, ElementConversion};

#[derive(Config)]
pub struct AdamConfig {
@@ -34,6 +34,20 @@ pub struct Adam<B: ADBackend> {
weight_decay: Option<WeightDecay<B>>,
}

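// WIP scaffold (note the todo!): a placeholder impl on usize, presumably here just to
// exercise the new SimpleOptimizer trait while the real implementations catch up.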
impl<B: Backend> SimpleOptimizer<B> for usize {
type State<const D: usize> = Tensor<B, D>;

fn step<const D: usize>(
&self,
id: &ParamId,
tensor: Tensor<B, D>,
grad: Tensor<B, D>,
state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>) {
todo!()
}
}

impl<B: ADBackend> Adam<B> {
pub fn new(config: &AdamConfig) -> Self {
Self {
@@ -97,6 +111,24 @@ impl<M: ADModule<B>, B: ADBackend> Optimizer<M, B> for Adam<B> {
}
}

struct AdaptiveMomentumState<B: Backend, const D: usize> {
time: usize,
moment_1: Tensor<B, D>,
moment_2: Tensor<B, D>,
}

// impl<B: Backend, const D: usize> Record for AdaptiveMomentumState<B, D> {
// type Item<S: burn::record::RecordSettings>;
//
// fn into_item<S: burn::record::RecordSettings>(self) -> Self::Item<S> {
// todo!()
// }
//
// fn from_item<S: burn::record::RecordSettings>(item: Self::Item<S>) -> Self {
// todo!()
// }
// }

struct AdaptiveMomentum {
beta_1: f32,
beta_2: f32,
…
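
The new AdaptiveMomentumState struct holds the usual Adam quantities: a step count and two moment estimates. As a reference for the math those fields conventionally track, here is a minimal, self-contained sketch of the textbook Adam recurrence over plain slices — an illustration only, not burn's implementation; the epsilon value and the returned update direction are assumptions.

// Textbook Adam moment update, sketched over plain slices rather than burn tensors.
// `time`, `moment_1` and `moment_2` mirror the AdaptiveMomentumState fields above.
fn adam_update(
    grad: &[f32],
    moment_1: &mut [f32], // exponential moving average of gradients
    moment_2: &mut [f32], // exponential moving average of squared gradients
    time: usize,          // step count, used for bias correction
    beta_1: f32,
    beta_2: f32,
) -> Vec<f32> {
    let bias_1 = 1.0 - beta_1.powi(time as i32);
    let bias_2 = 1.0 - beta_2.powi(time as i32);
    grad.iter()
        .zip(moment_1.iter_mut().zip(moment_2.iter_mut()))
        .map(|(g, (m1, m2))| {
            *m1 = beta_1 * *m1 + (1.0 - beta_1) * g;
            *m2 = beta_2 * *m2 + (1.0 - beta_2) * g * g;
            // Bias-corrected update direction: m_hat / (sqrt(v_hat) + eps).
            (*m1 / bias_1) / ((*m2 / bias_2).sqrt() + 1e-8)
        })
        .collect()
}
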
13 changes: 13 additions & 0 deletions burn-core/src/optim/base.rs
@@ -3,9 +3,22 @@ use super::visitor::{GradientsLoader, GradientsRegister};
use super::GradientsParams;

use crate::module::{ADModule, LoadingError, ParamId, State, StateNamed};
use crate::record::Record;
use crate::tensor::backend::ADBackend;
use crate::tensor::{Data, Tensor};

pub trait ModuleOptimizer<M, B>: Send + Sync
where
M: ADModule<B>,
B: ADBackend,
{
type Record: Record;

fn step(&mut self, module: M, grads: B::Gradients) -> M;
fn to_record(&self) -> Self::Record;
fn load_record(self, record: Self::Record) -> Self;
}

pub trait Optimizer<M, B>: Send + Sync
where
M: ADModule<B>,
…
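
The new ModuleOptimizer trait couples a step function with an associated Record type so optimizer state can be checkpointed and restored. Here is a toy model of that contract, with f32 standing in for a module and a bare step counter standing in for real optimizer state — all names are hypothetical, none of this is burn API.

// Toy model of the ModuleOptimizer contract: step mutates internal state,
// while to_record/load_record round-trip that state across a checkpoint.
trait ToyModuleOptimizer {
    type Record;

    fn step(&mut self, module: f32, grad: f32) -> f32;
    fn to_record(&self) -> Self::Record;
    fn load_record(self, record: Self::Record) -> Self;
}

struct ToySgd {
    lr: f32,
    steps: usize, // stands in for real optimizer state
}

impl ToyModuleOptimizer for ToySgd {
    type Record = usize; // only the step count needs to survive a checkpoint

    fn step(&mut self, module: f32, grad: f32) -> f32 {
        self.steps += 1;
        module - self.lr * grad
    }

    fn to_record(&self) -> usize {
        self.steps
    }

    fn load_record(mut self, record: usize) -> Self {
        self.steps = record;
        self
    }
}

fn main() {
    let mut optim = ToySgd { lr: 0.1, steps: 0 };
    let module = optim.step(1.0, 0.5);
    assert!((module - 0.95).abs() < 1e-6);

    // Persist the state, then restore it into a fresh optimizer.
    let record = optim.to_record();
    let restored = ToySgd { lr: 0.1, steps: 0 }.load_record(record);
    assert_eq!(restored.steps, 1);
}
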
2 changes: 2 additions & 0 deletions burn-core/src/optim/mod.rs
@@ -7,10 +7,12 @@ mod grad_accum;
mod grads;
mod mapper;
mod sgd;
mod simple;
mod visitor;

pub use adam::*;
pub use base::*;
pub use grad_accum::*;
pub use grads::*;
pub use sgd::*;
pub use simple::*;
253 changes: 253 additions & 0 deletions burn-core/src/optim/simple.rs
@@ -0,0 +1,253 @@
use core::{any::Any, marker::PhantomData};

use crate::{
module::{ADModule, ModuleMapper, ParamId},
record::{Record, RecordSettings},
};
use burn_tensor::{
backend::{ADBackend, Backend},
Tensor,
};
use hashbrown::HashMap;

use super::ModuleOptimizer;

/// `SimpleOptimizer` is a more opinionated trait where the state is generic over the
/// dimension D and implements `Record`. This allows for simpler optimizer implementations,
/// since they don't have to handle missing gradients, load and export records, or navigate
/// the module parameter structure.
pub trait SimpleOptimizer<B>: Send + Sync
where
B: Backend,
{
type State<const D: usize>: Record + Clone;

fn step<const D: usize>(
&self,
id: &ParamId,
tensor: Tensor<B, D>,
grad: Tensor<B, D>,
state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>);
}

pub struct SimpleModuleOptimizer<O, M, B>
where
O: SimpleOptimizer<B::InnerBackend>,
M: ADModule<B>,
B: ADBackend,
{
optim: O,
records: HashMap<ParamId, SimpleOptimizerRecord<O, B::InnerBackend>>,
module: PhantomData<M>,
}

impl<O, B, M> ModuleOptimizer<M, B> for SimpleModuleOptimizer<O, M, B>
where
B: ADBackend,
M: ADModule<B>,
O: SimpleOptimizer<B::InnerBackend>,
{
type Record = HashMap<ParamId, SimpleOptimizerRecord<O, B::InnerBackend>>;

fn step(&mut self, module: M, mut grads: <B as ADBackend>::Gradients) -> M {
let mut mapper =
SimpleModuleOptimizerMapper::<M, B, O>::new(&self.optim, &mut self.records, &mut grads);
module.map(&mut mapper)
}

fn to_record(&self) -> Self::Record {
self.records.clone()
}

fn load_record(mut self, record: Self::Record) -> Self {
self.records = record;
self
}
}

#[derive(new)]
pub struct SimpleModuleOptimizerMapper<'a, M, B, O>
where
M: ADModule<B>,
B: ADBackend,
O: SimpleOptimizer<B::InnerBackend>,
{
optimizer: &'a O,
records: &'a mut HashMap<ParamId, SimpleOptimizerRecord<O, B::InnerBackend>>,
grads: &'a mut B::Gradients,
phantom: PhantomData<M>,
}

impl<'a, M, B, O> ModuleMapper<B> for SimpleModuleOptimizerMapper<'a, M, B, O>
where
M: ADModule<B>,
B: ADBackend,
O: SimpleOptimizer<B::InnerBackend>,
{
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
let grad = tensor.grad(self.grads);

if let Some(grad) = grad {
let (key, record) = self.records.remove_entry(id).unzip();
let (tensor, state) = self.optimizer.step(
id,
tensor.inner(),
grad,
record.map(|record| record.into_state()),
);

if let Some(state) = state {
self.records.insert(
key.unwrap_or_else(|| id.clone()),
SimpleOptimizerRecord::from_state(state),
);
}

return Tensor::from_inner(tensor);
}

tensor
}
}

pub enum SimpleOptimizerRecord<O: SimpleOptimizer<B>, B: Backend> {
Rank1(O::State<1>),
Rank2(O::State<2>),
Rank3(O::State<3>),
Rank4(O::State<4>),
Rank5(O::State<5>),
Rank6(O::State<6>),
Rank7(O::State<7>),
Rank8(O::State<8>),
}

impl<O: SimpleOptimizer<B>, B: Backend> Clone for SimpleOptimizerRecord<O, B> {
fn clone(&self) -> Self {
match self {
SimpleOptimizerRecord::Rank1(record) => SimpleOptimizerRecord::Rank1(record.clone()),
SimpleOptimizerRecord::Rank2(record) => SimpleOptimizerRecord::Rank2(record.clone()),
SimpleOptimizerRecord::Rank3(record) => SimpleOptimizerRecord::Rank3(record.clone()),
SimpleOptimizerRecord::Rank4(record) => SimpleOptimizerRecord::Rank4(record.clone()),
SimpleOptimizerRecord::Rank5(record) => SimpleOptimizerRecord::Rank5(record.clone()),
SimpleOptimizerRecord::Rank6(record) => SimpleOptimizerRecord::Rank6(record.clone()),
SimpleOptimizerRecord::Rank7(record) => SimpleOptimizerRecord::Rank7(record.clone()),
SimpleOptimizerRecord::Rank8(record) => SimpleOptimizerRecord::Rank8(record.clone()),
}
}
}

pub enum SimpleOptimizerRecordItem<O: SimpleOptimizer<B>, B: Backend, S: RecordSettings> {
Rank1(<O::State<1> as Record>::Item<S>),
Rank2(<O::State<2> as Record>::Item<S>),
Rank3(<O::State<3> as Record>::Item<S>),
Rank4(<O::State<4> as Record>::Item<S>),
Rank5(<O::State<5> as Record>::Item<S>),
Rank6(<O::State<6> as Record>::Item<S>),
Rank7(<O::State<7> as Record>::Item<S>),
Rank8(<O::State<8> as Record>::Item<S>),
}

impl<O, B> SimpleOptimizerRecord<O, B>
where
O: SimpleOptimizer<B>,
B: Backend,
{
pub fn into_state<const D: usize>(self) -> O::State<D> {
let boxed_state: Box<dyn Any> = match self {
SimpleOptimizerRecord::Rank1(s) => Box::new(s),
SimpleOptimizerRecord::Rank2(s) => Box::new(s),
SimpleOptimizerRecord::Rank3(s) => Box::new(s),
SimpleOptimizerRecord::Rank4(s) => Box::new(s),
SimpleOptimizerRecord::Rank5(s) => Box::new(s),
SimpleOptimizerRecord::Rank6(s) => Box::new(s),
SimpleOptimizerRecord::Rank7(s) => Box::new(s),
SimpleOptimizerRecord::Rank8(s) => Box::new(s),
};
let state = boxed_state
.downcast::<O::State<D>>()
.expect("Unsupported state dimension");
*state
}
pub fn from_state<const D: usize>(state: O::State<D>) -> Self {
let state: Box<dyn Any> = Box::new(state);

match D {
1 => SimpleOptimizerRecord::Rank1(*state.downcast().unwrap()),
2 => SimpleOptimizerRecord::Rank2(*state.downcast().unwrap()),
3 => SimpleOptimizerRecord::Rank3(*state.downcast().unwrap()),
4 => SimpleOptimizerRecord::Rank4(*state.downcast().unwrap()),
5 => SimpleOptimizerRecord::Rank5(*state.downcast().unwrap()),
6 => SimpleOptimizerRecord::Rank6(*state.downcast().unwrap()),
7 => SimpleOptimizerRecord::Rank7(*state.downcast().unwrap()),
8 => SimpleOptimizerRecord::Rank8(*state.downcast().unwrap()),
_ => panic!("Unsupported state dimension"),
}
}
}

impl<O, B> Record for SimpleOptimizerRecord<O, B>
where
O: SimpleOptimizer<B>,
B: Backend,
{
type Item<S: RecordSettings> = SimpleOptimizerRecordItem<O, B, S>;

fn into_item<S: RecordSettings>(self) -> Self::Item<S> {
match self {
SimpleOptimizerRecord::Rank1(record) => {
SimpleOptimizerRecordItem::Rank1(record.into_item())
}
SimpleOptimizerRecord::Rank2(record) => {
SimpleOptimizerRecordItem::Rank2(record.into_item())
}
SimpleOptimizerRecord::Rank3(record) => {
SimpleOptimizerRecordItem::Rank3(record.into_item())
}
SimpleOptimizerRecord::Rank4(record) => {
SimpleOptimizerRecordItem::Rank4(record.into_item())
}
SimpleOptimizerRecord::Rank5(record) => {
SimpleOptimizerRecordItem::Rank5(record.into_item())
}
SimpleOptimizerRecord::Rank6(record) => {
SimpleOptimizerRecordItem::Rank6(record.into_item())
}
SimpleOptimizerRecord::Rank7(record) => {
SimpleOptimizerRecordItem::Rank7(record.into_item())
}
SimpleOptimizerRecord::Rank8(record) => {
SimpleOptimizerRecordItem::Rank8(record.into_item())
}
}
}

fn from_item<S: RecordSettings>(item: Self::Item<S>) -> Self {
match item {
SimpleOptimizerRecordItem::Rank1(item) => {
SimpleOptimizerRecord::Rank1(<O::State<1> as Record>::from_item(item))
}
SimpleOptimizerRecordItem::Rank2(item) => {
SimpleOptimizerRecord::Rank2(<O::State<2> as Record>::from_item(item))
}
SimpleOptimizerRecordItem::Rank3(item) => {
SimpleOptimizerRecord::Rank3(<O::State<3> as Record>::from_item(item))
}
SimpleOptimizerRecordItem::Rank4(item) => {
SimpleOptimizerRecord::Rank4(<O::State<4> as Record>::from_item(item))
}
SimpleOptimizerRecordItem::Rank5(item) => {
SimpleOptimizerRecord::Rank5(<O::State<5> as Record>::from_item(item))
}
SimpleOptimizerRecordItem::Rank6(item) => {
SimpleOptimizerRecord::Rank6(<O::State<6> as Record>::from_item(item))
}
SimpleOptimizerRecordItem::Rank7(item) => {
SimpleOptimizerRecord::Rank7(<O::State<7> as Record>::from_item(item))
}
SimpleOptimizerRecordItem::Rank8(item) => {
SimpleOptimizerRecord::Rank8(<O::State<8> as Record>::from_item(item))
}
}
}
}
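
The from_state/into_state pair above relies on a rank-erasure trick: a const-generic state is boxed as dyn Any, tagged with its rank in an enum variant, and recovered later by downcasting. Below is a standalone sketch of the same mechanism, reduced to two ranks and a stand-in State type — nothing here is burn API.

use std::any::Any;

// Stand-in for a rank-D optimizer state.
#[derive(Debug, PartialEq)]
struct State<const D: usize>([usize; D]);

enum RankRecord {
    Rank1(State<1>),
    Rank2(State<2>),
}

impl RankRecord {
    // Tag the state with its rank; the downcast target is fixed by the variant.
    fn from_state<const D: usize>(state: State<D>) -> Self {
        let state: Box<dyn Any> = Box::new(state);
        match D {
            1 => RankRecord::Rank1(*state.downcast().unwrap()),
            2 => RankRecord::Rank2(*state.downcast().unwrap()),
            _ => panic!("Unsupported state dimension"),
        }
    }

    // Recover the state; fails loudly if the caller asks for the wrong rank.
    fn into_state<const D: usize>(self) -> State<D> {
        let boxed: Box<dyn Any> = match self {
            RankRecord::Rank1(s) => Box::new(s),
            RankRecord::Rank2(s) => Box::new(s),
        };
        *boxed.downcast::<State<D>>().expect("Unsupported state dimension")
    }
}

fn main() {
    let record = RankRecord::from_state(State::<2>([3, 4]));
    let state: State<2> = record.into_state();
    assert_eq!(state, State([3, 4]));
}
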
23 changes: 22 additions & 1 deletion burn-core/src/record/primitive.rs
@@ -1,7 +1,8 @@
use super::{Record, RecordSettings};
-use crate::module::{Param, State};
+use crate::module::{Param, ParamId, State};
use alloc::vec::Vec;
use burn_tensor::{DataSerialize, Element};
use hashbrown::HashMap;

impl Record for () {
type Item<S: RecordSettings> = ();
@@ -47,6 +48,26 @@ impl<const N: usize, T: Record> Record for [T; N] {
}
}

impl<T: Record> Record for HashMap<ParamId, T> {
type Item<S: RecordSettings> = HashMap<ParamId, T::Item<S>>;

fn into_item<S: RecordSettings>(self) -> Self::Item<S> {
let mut items = HashMap::with_capacity(self.len());
self.into_iter().for_each(|(id, record)| {
items.insert(id, record.into_item());
});
items
}

fn from_item<S: RecordSettings>(item: Self::Item<S>) -> Self {
let mut record = HashMap::with_capacity(item.len());
item.into_iter().for_each(|(id, item)| {
record.insert(id, T::from_item(item));
});
record
}
}

impl<T: Element> Record for State<T> {
type Item<S: RecordSettings> = State<S::FloatElem>;

…
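
The HashMap impl above converts records entry by entry, keeping the ParamId keys intact in both directions. The same round trip in miniature, with String keys and a toy record standing in for ParamId and burn's Record machinery (all names hypothetical):

use std::collections::HashMap;

#[derive(Clone, Debug, PartialEq)]
struct ToyRecord(f64);

struct ToyItem(String); // pretend serialized form

impl ToyRecord {
    fn into_item(self) -> ToyItem {
        ToyItem(self.0.to_string())
    }

    fn from_item(item: ToyItem) -> Self {
        ToyRecord(item.0.parse().unwrap())
    }
}

fn main() {
    let mut records = HashMap::new();
    records.insert("layer.weight".to_string(), ToyRecord(0.5));

    // into_item: convert every value, keys unchanged.
    let items: HashMap<String, ToyItem> = records
        .into_iter()
        .map(|(id, record)| (id, record.into_item()))
        .collect();

    // from_item: rebuild the records from the items.
    let records: HashMap<String, ToyRecord> = items
        .into_iter()
        .map(|(id, item)| (id, ToyRecord::from_item(item)))
        .collect();

    assert_eq!(records["layer.weight"], ToyRecord(0.5));
}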