Feat/optim (tracel-ai#37)
nathanielsimard authored Sep 18, 2022
1 parent 2e29e82 commit b9f8337
Showing 25 changed files with 957 additions and 188 deletions.
4 changes: 4 additions & 0 deletions burn-derive/src/lib.rs
@@ -24,6 +24,8 @@ fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
let param = Param::from_ast(ast);
let num_params_fn = param.gen_num_params_fn();
let update_params_fn = param.gen_update_params_fn();
let load_optim_state = param.gen_load_optim_state_fn();
let register_optim_state = param.gen_register_optim_state_fn();
let devices_fn = param.gen_devices_fn();
let to_device_fn = param.gen_to_device_fn();
let state_fn = param.gen_state_fn();
@@ -36,6 +38,8 @@ fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
#name_fn
#num_params_fn
#update_params_fn
#load_optim_state
#register_optim_state
#devices_fn
#to_device_fn

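For orientation, the user-facing side stays the same: any struct that derives Module (as the MNIST example further down already does) now also gets the two optimizer-state methods generated for it. A minimal sketch with a hypothetical two-field module (MyBlock is illustrative, not part of this commit):

use burn::module::{Module, Param};
use burn::nn;
use burn::tensor::backend::Backend;

#[derive(Module, Debug)]
struct MyBlock<B: Backend> {
    input: Param<nn::Linear<B>>,
    output: Param<nn::Linear<B>>,
}
// Besides num_params/update_params/devices/..., the derive now also emits
// load_optim_state and register_optim_state for MyBlock, one forwarding call
// per Param field (see burn-derive/src/param.rs below).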
36 changes: 36 additions & 0 deletions burn-derive/src/param.rs
@@ -66,6 +66,42 @@ impl Param {
}
}

pub fn gen_load_optim_state_fn(&self) -> TokenStream {
let mut body = quote! {};
for field in self.fields_param.iter() {
let name = field.ident();
body.extend(quote! {
self.#name.load_optim_state(optim, state_optim);
});
}

quote! {
fn load_optim_state<O: burn::optim::Optimizer<Backend = B>>(&self, optim: &mut O, state_optim: &burn::module::StateNamed<B::Elem>)
where
B: burn::tensor::backend::ADBackend {
#body
}
}
}

pub fn gen_register_optim_state_fn(&self) -> TokenStream {
let mut body = quote! {};
for field in self.fields_param.iter() {
let name = field.ident();
body.extend(quote! {
self.#name.register_optim_state(optim, state_optim);
});
}

quote! {
fn register_optim_state<O: burn::optim::Optimizer<Backend = B>>(&self, optim: &O, state_optim: &mut burn::module::StateNamed<B::Elem>)
where
B: burn::tensor::backend::ADBackend {
#body
}
}
}

pub fn gen_devices_fn(&self) -> TokenStream {
let mut body = quote! {
let mut devices = Vec::new();
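Expanded for the hypothetical MyBlock above, the generated load_optim_state would look roughly like this (a sketch of the macro output, not literal compiler expansion); register_optim_state mirrors it with &O and a mutable state map:

fn load_optim_state<O: burn::optim::Optimizer<Backend = B>>(
    &self,
    optim: &mut O,
    state_optim: &burn::module::StateNamed<B::Elem>,
) where
    B: burn::tensor::backend::ADBackend,
{
    // One forwarding call per Param field, in declaration order.
    self.input.load_optim_state(optim, state_optim);
    self.output.load_optim_state(optim, state_optim);
}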
42 changes: 29 additions & 13 deletions burn-tensor/src/graph/grad.rs
@@ -7,35 +7,38 @@ use crate::{
};
use std::{any::Any, collections::HashMap, ops::Add};

#[derive(Default)]
pub struct Gradients {
grads: HashMap<String, Box<dyn Any>>,
grads: HashMap<String, Box<dyn Any + Send + Sync>>,
}

impl Gradients {
pub fn empty() -> Self {
Self {
grads: HashMap::new(),
}
}

pub fn register<T>(&mut self, node: &BackwardNode<T>)
where
T: Zeros<T> + Clone + Add<Output = T>,
T: std::fmt::Debug + 'static,
T: std::fmt::Debug + 'static + Send + Sync,
{
let grad = node.state.grad();
self.grads.insert(node.id.clone(), Box::new(grad));
}
fn empty() -> Self {
Self {
grads: HashMap::new(),
}
}
}

pub trait AsNode<T> {
fn as_node(&self) -> &ForwardNode<T>;
}
pub fn register_any<V>(&mut self, id: String, value: V)
where
V: std::fmt::Debug + 'static + Send + Sync,
{
self.grads.insert(id, Box::new(value));
}

impl Gradients {
pub fn from<T>(node: &BackwardNode<T>) -> Self
where
T: Zeros<T> + Clone + Add<Output = T>,
T: std::fmt::Debug + 'static,
T: std::fmt::Debug + 'static + Send + Sync,
{
let mut grads = Self::empty();
let traversal = BreadthFirstSearch::new(node);
@@ -57,4 +60,17 @@ impl Gradients {

grad.downcast_ref()
}

pub fn get<V: 'static>(&self, id: &str) -> Option<&V> {
let grad = match self.grads.get(id) {
Some(grad) => grad,
None => return None,
};

grad.downcast_ref()
}
}

pub trait AsNode<T> {
fn as_node(&self) -> &ForwardNode<T>;
}
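A short usage sketch of the new type-erased API (not taken from the commit; the id string is made up): any Send + Sync value can be stored under a string key and read back by concrete type.

let mut grads = Gradients::default();            // Default is newly derived
grads.register_any("node-123".to_string(), vec![0.0f32; 4]);

// Retrieval downcasts back to the concrete type; a mismatched type yields None.
let momentum: Option<&Vec<f32>> = grads.get("node-123");
assert!(momentum.is_some());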
4 changes: 2 additions & 2 deletions burn-tensor/src/graph/node/backward.rs
@@ -36,7 +36,7 @@ impl<Out: Clone + Zeros<Out>> BackwardNode<Out> {
impl<Out> BackwardNode<Out>
where
Out: Zeros<Out> + Ones<Out> + Clone + Add<Output = Out>,
Out: std::fmt::Debug + 'static,
Out: std::fmt::Debug + 'static + Send + Sync,
{
pub fn backward(&mut self) -> Gradients {
let grad = self.state.value().ones();
@@ -75,7 +75,7 @@ where
impl<T> RecordedOpsParent for BackwardNode<T>
where
T: Zeros<T> + Clone + Add<Output = T>,
T: std::fmt::Debug + 'static,
T: std::fmt::Debug + 'static + Send + Sync,
{
fn backward_step(&self) {
self.ops.backward_step(&self.state)
4 changes: 2 additions & 2 deletions burn-tensor/src/graph/ops/binary.rs
@@ -51,8 +51,8 @@ where

impl<Lhs, Rhs, Out, Ops> BackwardRecordedOps<Out> for BackwardBinaryRecordedOps<Lhs, Rhs, Ops>
where
Lhs: Clone + Zeros<Lhs> + Add<Output = Lhs> + std::fmt::Debug + 'static,
Rhs: Clone + Zeros<Rhs> + Add<Output = Rhs> + std::fmt::Debug + 'static,
Lhs: Clone + Zeros<Lhs> + Add<Output = Lhs> + std::fmt::Debug + 'static + Send + Sync,
Rhs: Clone + Zeros<Rhs> + Add<Output = Rhs> + std::fmt::Debug + 'static + Send + Sync,
Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
Ops: BinaryOps<Lhs, Rhs, Out> + std::fmt::Debug + 'static,
{
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/base.rs
@@ -84,7 +84,7 @@ pub(crate) type ADBackendTensorPrimitive<const D: usize, B> =
<<B as ADBackend>::InnerBackend as Backend>::TensorPrimitive<D>;

pub trait ADBackend: Backend {
type InnerBackend: Backend<Device = Self::Device>;
type InnerBackend: Backend<Device = Self::Device, Elem = Self::Elem>;

fn backward<const D: usize>(tensor: &Self::TensorPrimitive<D>) -> Gradients;
fn grad<const D: usize>(
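The extra Elem = Self::Elem bound means the autodiff backend and its inner backend now share one element type. A small sketch of what that buys (assuming ADBackend and Backend are in scope; the helper is hypothetical, not from this diff):

// Element buffers can move between the AD backend and its inner backend
// without conversion, because the two Elem types are provably equal.
fn into_inner_elems<B: ADBackend>(elems: Vec<B::Elem>) -> Vec<<B::InnerBackend as Backend>::Elem> {
    elems
}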
2 changes: 1 addition & 1 deletion burn-tensor/tests/tensor/ops/exp.rs
@@ -8,6 +8,6 @@ fn should_support_exp_ops() {

let data_actual = tensor.exp().into_data();

let data_expected = Data::from([[1.0, 2.7183, 7.3891], [20.0855, 54.5981, 148.4132]]);
let data_expected = Data::from([[1.0, 2.71830, 7.3891], [20.0855, 54.5981, 148.4132]]);
data_expected.assert_approx_eq(&data_actual, 3);
}
3 changes: 3 additions & 0 deletions burn/Cargo.toml
@@ -42,5 +42,8 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
flate2 = "1.0"

# Parameter & Optimization
nanoid = "0.4"

[dev-dependencies]
burn-dataset = { path = "../burn-dataset", version = "0.1.0", features = ["fake"] }
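nanoid is a tiny unique-id generator; given the "# Parameter & Optimization" comment, it is presumably used to give each parameter a stable key for optimizer state (an assumption, the call site is not shown in this excerpt). Its API is a single macro:

use nanoid::nanoid;

// Generates a URL-safe random id such as "V1StGXR8_Z5jdHi6B-myT".
let param_id: String = nanoid!();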
62 changes: 46 additions & 16 deletions burn/examples/mnist.rs
@@ -1,39 +1,41 @@
use burn::data::dataloader::batcher::Batcher;
use burn::data::dataloader::DataLoaderBuilder;
use burn::data::dataset::source::huggingface::{MNISTDataset, MNISTItem};
use burn::module::{Forward, Module, Param};
use burn::module::{Forward, Module, Param, State};
use burn::nn;
use burn::optim::SGDOptimizer;
use burn::tensor::activation::relu;
use burn::optim::decay::WeightDecayConfig;
use burn::optim::momentum::MomentumConfig;
use burn::optim::{Optimizer, Sgd, SgdConfig};
use burn::tensor::backend::{ADBackend, Backend};
use burn::tensor::loss::cross_entropy_with_logits;
use burn::tensor::{Data, ElementConversion, Shape, Tensor};
use burn::train::logger::{AsyncLogger, CLILogger, TextPlot};
use burn::train::logger::{AsyncLogger, CLILogger};
use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric};
use burn::train::{ClassificationLearner, ClassificationOutput, SupervisedTrainer};
use std::sync::Arc;

#[derive(Module, Debug)]
struct Model<B: Backend> {
mlp: Param<MLP<B>>,
mlp: Param<Mlp<B>>,
input: Param<nn::Linear<B>>,
output: Param<nn::Linear<B>>,
}

#[derive(Module, Debug)]
struct MLP<B: Backend> {
struct Mlp<B: Backend> {
linears: Param<Vec<nn::Linear<B>>>,
dropout: nn::Dropout,
activation: nn::ReLU,
}

impl<B: Backend> Forward<Tensor<B, 2>, Tensor<B, 2>> for MLP<B> {
impl<B: Backend> Forward<Tensor<B, 2>, Tensor<B, 2>> for Mlp<B> {
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let mut x = input;

for linear in self.linears.iter() {
x = linear.forward(x);
x = self.dropout.forward(x);
x = relu(&x);
x = self.activation.forward(x);
}

x
@@ -66,7 +68,7 @@ impl<B: Backend> Forward<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
}
}

impl<B: Backend> MLP<B> {
impl<B: Backend> Mlp<B> {
fn new(dim: usize, num_layers: usize) -> Self {
let mut linears = Vec::with_capacity(num_layers);

@@ -83,13 +85,14 @@ impl<B: Backend> MLP<B> {
Self {
linears: Param::new(linears),
dropout: nn::Dropout::new(&nn::DropoutConfig { prob: 0.3 }),
activation: nn::ReLU::new(),
}
}
}

impl<B: Backend> Model<B> {
fn new(d_input: usize, d_hidden: usize, num_layers: usize, num_classes: usize) -> Self {
let mlp = MLP::new(d_hidden, num_layers);
let mlp = Mlp::new(d_hidden, num_layers);
let config_input = nn::LinearConfig {
d_input,
d_output: d_hidden,
@@ -145,15 +148,38 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {

fn run<B: ADBackend>(device: B::Device) {
let batch_size = 128;
let learning_rate = 5.5e-2;
let num_epochs = 10;
let num_epochs = 15;
let num_workers = 8;
let num_layers = 4;
let hidden_dim = 1024;
let seed = 42;

let mut model: Model<B> = Model::new(784, hidden_dim, num_layers, 10);
let state_model = State::<f32>::load("/tmp/mnist_state_model").ok();
let state_optim = State::<f32>::load("/tmp/mnist_state_optim").ok();

let mut model = Model::new(784, hidden_dim, num_layers, 10);
model.to_device(device);

if let Some(state) = state_model {
println!("Loading model state");
model.load(&state.convert()).unwrap();
}

let mut optim = Sgd::new(&SgdConfig {
learning_rate: 2.5e-2,
weight_decay: Some(WeightDecayConfig { penalty: 0.01 }),
momentum: Some(MomentumConfig {
momentum: 0.9,
dampening: 0.1,
nesterov: true,
}),
});

if let Some(state) = state_optim {
println!("Loading optimizer state");
optim.load(&model, &state.convert()).unwrap();
}

println!(
"Training '{}' with {} params on backend {} {:?}",
model.name(),
@@ -162,7 +188,6 @@ fn run<B: ADBackend>(device: B::Device) {
device,
);

let optim: SGDOptimizer<B> = SGDOptimizer::new(learning_rate);
let batcher_train = Arc::new(MNISTBatcher::<B> { device });
let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend> { device });
let dataloader_train = DataLoaderBuilder::new(batcher_train)
@@ -179,7 +204,7 @@ fn run<B: ADBackend>(device: B::Device) {

let logger_train = Box::new(AsyncLogger::new(Box::new(CLILogger::new(
vec![
Box::new(TextPlot::new(LossMetric::new())),
Box::new(LossMetric::new()),
Box::new(AccuracyMetric::new()),
Box::new(CUDAMetric::new()),
],
@@ -202,7 +227,12 @@ fn run<B: ADBackend>(device: B::Device) {
learner,
);

let _learned = trainer.run(num_epochs);
let learned = trainer.run(num_epochs);
let state_model: State<f32> = learned.model.state().convert();
let state_optim: State<f32> = learned.optim.state(&learned.model).convert();

state_model.save("/tmp/mnist_state_model").unwrap();
state_optim.save("/tmp/mnist_state_optim").unwrap();
}

fn main() {
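For reference on the SgdConfig used above: the options presumably map to the standard SGD-with-momentum update (the exact formulation inside burn's Sgd is not shown in this excerpt). Per parameter w with gradient g and momentum buffer v:

g = g + penalty * w                          (weight decay, penalty = 0.01)
v = momentum * v + (1 - dampening) * g       (momentum = 0.9, dampening = 0.1)
g = g + momentum * v    if nesterov, else g = v
w = w - learning_rate * g                    (learning_rate = 2.5e-2)

The momentum buffers v are exactly the per-parameter optimizer state that the new save/load paths persist to /tmp/mnist_state_optim.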
14 changes: 13 additions & 1 deletion burn/src/module/base.rs
@@ -1,4 +1,4 @@
use super::State;
use super::{State, StateNamed};
use crate::optim::Optimizer;
use crate::tensor::{
backend::{ADBackend, Backend},
Expand Down Expand Up @@ -28,6 +28,18 @@ pub trait Module: Send + Sync + std::fmt::Debug + std::fmt::Display {
optim: &mut O,
) where
Self::Backend: ADBackend;
fn load_optim_state<O: Optimizer<Backend = Self::Backend>>(
&self,
optim: &mut O,
state_optim: &StateNamed<<Self::Backend as Backend>::Elem>,
) where
Self::Backend: ADBackend;
fn register_optim_state<O: Optimizer<Backend = Self::Backend>>(
&self,
optim: &O,
state_optim: &mut StateNamed<<Self::Backend as Backend>::Elem>,
) where
Self::Backend: ADBackend;
fn devices(&self) -> Vec<<Self::Backend as Backend>::Device>;
fn to_device(&mut self, device: <Self::Backend as Backend>::Device);
fn name(&self) -> &str;
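Note the asymmetry in the two new Module methods: register_optim_state reads from the optimizer (&O) into a mutable state map, while load_optim_state writes from a state map into the optimizer (&mut O). A minimal hedged sketch of a caller exercising both (the helper name and the pre-built StateNamed are assumptions, not from this diff):

use burn::module::{Module, StateNamed};
use burn::optim::Optimizer;
use burn::tensor::backend::ADBackend;

fn roundtrip_optim_state<B, M, O>(model: &M, optim: &mut O, state: &mut StateNamed<B::Elem>)
where
    B: ADBackend,
    M: Module<Backend = B>,
    O: Optimizer<Backend = B>,
{
    // Snapshot the optimizer's per-parameter state into `state`...
    model.register_optim_state(optim, state);
    // ...and later restore it back into the optimizer (e.g. when resuming training).
    model.load_optim_state(optim, state);
}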
(diffs for the remaining changed files were not loaded)
