Refactor/burn core (tracel-ai#143)
nathanielsimard authored Jan 2, 2023
1 parent 2f179f1 commit 54c30ab
Showing 95 changed files with 212 additions and 340 deletions.
24 changes: 21 additions & 3 deletions .github/workflows/publish.yml
@@ -59,14 +59,32 @@ jobs:
crate: burn-ndarray
secrets: inherit

publish-burn:

publish-burn-core:
uses: burn-rs/burn/.github/workflows/publish-template.yml@main
needs:
- publish-burn-derive
- publish-burn-dataset
- publish-burn-tensor
- publish-burn-autodiff
- publish-burn-dataset
- publish-burn-derive
- publish-burn-ndarray
with:
crate: burn-core
secrets: inherit

publish-burn-train:
uses: burn-rs/burn/.github/workflows/publish-template.yml@main
needs:
- publish-burn-core
with:
crate: burn-train
secrets: inherit

publish-burn:
uses: burn-rs/burn/.github/workflows/publish-template.yml@main
needs:
- publish-burn-core
- publish-burn-train
with:
crate: burn
secrets: inherit
9 changes: 7 additions & 2 deletions .github/workflows/test.yml
@@ -35,7 +35,12 @@ jobs:
with:
crate: burn-autodiff

test-burn:
test-burn-core:
uses: burn-rs/burn/.github/workflows/test-template.yml@main
with:
crate: burn
crate: burn-core

test-burn-train:
uses: burn-rs/burn/.github/workflows/test-template.yml@main
with:
crate: burn-train
2 changes: 2 additions & 0 deletions Cargo.toml
@@ -1,6 +1,8 @@
[workspace]
members = [
"burn",
"burn-core",
"burn-train",
"burn-derive",
"burn-tensor",
"burn-tensor-testgen",
40 changes: 40 additions & 0 deletions burn-core/Cargo.toml
@@ -0,0 +1,40 @@
[package]
name = "burn-core"
version = "0.4.0"
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
description = "BURN: Burn Unstoppable Rusty Neurons"
repository = "https://github.com/burn-rs/burn-core"
readme = "README.md"
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
categories = ["science"]
license = "MIT/Apache-2.0"
edition = "2021"

[features]
default = []
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]

[dependencies]
burn-tensor = { version = "0.4.0", path = "../burn-tensor" }
burn-autodiff = { version = "0.4.0", path = "../burn-autodiff" }
burn-dataset = { version = "0.4.0", path = "../burn-dataset", default-features = false }
burn-derive = { version = "0.4.0", path = "../burn-derive" }

# Utilities
derive-new = "0.5.9"
rand = "0.8.5"
log = "0.4.17"

# Serialize Deserialize
serde = { version = "1.0.151", features = ["derive"] }
serde_json = "1.0.91"
flate2 = "1.0.25"

# Parameter & Optimization
nanoid = "0.4.0"

[dev-dependencies]
burn-dataset = { version = "0.4.0", path = "../burn-dataset", features = [
"fake",
] }
burn-ndarray = { version = "0.4.0", path = "../burn-ndarray" }
1 change: 1 addition & 0 deletions burn-core/LICENSE-APACHE
1 change: 1 addition & 0 deletions burn-core/LICENSE-MIT
6 changes: 6 additions & 0 deletions burn-core/README.md
@@ -0,0 +1,6 @@
# Burn Core

This crate should be used with [burn](https://github.com/burn-rs/burn).

[![Current Crates.io Version](https://img.shields.io/crates/v/burn-core.svg)](https://crates.io/crates/burn-core)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn-core/blob/master/README.md)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
14 changes: 14 additions & 0 deletions burn-core/src/lib.rs
@@ -0,0 +1,14 @@
#[macro_use]
extern crate derive_new;

pub mod config;
pub mod data;
pub mod module;
pub mod nn;
pub mod optim;
pub mod tensor;

#[cfg(test)]
pub type TestBackend = burn_ndarray::NdArrayBackend<f32>;
#[cfg(test)]
pub type TestADBackend = burn_autodiff::ADBackendDecorator<TestBackend>;
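
The new burn-core crate exposes the same top-level modules that previously lived under burn (config, data, module, nn, optim, tensor), plus test-only backend aliases. Below is a minimal downstream sketch of the aliasing pattern this commit relies on; the re-exports and the choice of burn-ndarray/burn-autodiff backends are illustrative (they mirror the test aliases above), not part of the diff.

```rust
// Sketch only: alias burn_core as burn so code written against `burn::` paths
// keeps compiling, the same trick the derive tests in this commit use.
use burn_core as burn;

// Concrete backends for illustration, mirroring TestBackend/TestADBackend above.
pub type Backend = burn_ndarray::NdArrayBackend<f32>;
pub type ADBackend = burn_autodiff::ADBackendDecorator<Backend>;

// Re-export the facade modules under this crate's namespace.
pub use burn::{config, data, module, nn, optim, tensor};
```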
13 changes: 9 additions & 4 deletions burn/src/module/base.rs → burn-core/src/module/base.rs
@@ -19,10 +19,15 @@ use burn_tensor::Tensor;
/// necessary to optimize and train the module on any backend.
///
/// ```rust
/// use burn::nn;
/// use burn::module::{Param, Module};
/// use burn::tensor::Tensor;
/// use burn::tensor::backend::Backend;
/// // Not necessary when using the burn crate directly.
/// use burn_core as burn;
///
/// use burn::{
/// nn,
/// module::{Param, Module},
/// tensor::Tensor,
/// tensor::backend::Backend,
/// };
///
/// #[derive(Module, Debug)]
/// struct MyModule<B: Backend> {
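
The doc example above is cut off at the struct definition. A minimal illustrative completion under the 0.4-era API, where learnable tensors are declared as Param fields; the field names here are hypothetical and not taken from the commit.

```rust
// Illustrative sketch, not part of the diff: a module whose learnable state
// is declared as Param<Tensor<...>> fields, which the Module derive tracks.
use burn_core as burn;

use burn::{
    module::{Module, Param},
    tensor::backend::Backend,
    tensor::Tensor,
};

#[derive(Module, Debug)]
struct MyModule<B: Backend> {
    // Hypothetical parameter fields.
    weight: Param<Tensor<B, 2>>,
    bias: Param<Tensor<B, 1>>,
}
```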
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions burn/src/optim/mod.rs → burn-core/src/optim/mod.rs
@@ -1,14 +1,14 @@
pub(super) mod visitor;

pub mod decay;
pub mod momentum;

mod adam;
mod base;
mod grad_accum;
mod sgd;
mod visitor;

pub use adam::*;
pub use base::*;
pub use grad_accum::*;
pub use sgd::*;
pub use visitor::*;
File renamed without changes.
File renamed without changes.
12 changes: 7 additions & 5 deletions burn/src/optim/visitor.rs → burn-core/src/optim/visitor.rs
@@ -6,34 +6,35 @@ use burn_tensor::{
Tensor,
};

/// Data type that contains gradients for a given backend.
pub type GradientsParams<B> = TensorContainer<<B as ADBackend>::InnerBackend, ParamId>;

#[derive(new)]
pub struct GradientsRegister<'a, B: ADBackend, O> {
pub(crate) struct GradientsRegister<'a, B: ADBackend, O> {
optimizer: &'a O,
state: &'a mut StateNamed<B::Elem>,
}

#[derive(new)]
pub struct GradientsLoader<'a, B: ADBackend, O> {
pub(crate) struct GradientsLoader<'a, B: ADBackend, O> {
optimizer: &'a mut O,
state: &'a StateNamed<B::Elem>,
}

#[derive(new)]
pub struct GradientsParamsConverter<'a, B: ADBackend> {
pub(crate) struct GradientsParamsConverter<'a, B: ADBackend> {
grads: B::Gradients,
grads_params: &'a mut TensorContainer<B::InnerBackend, ParamId>,
}

#[derive(new)]
pub struct ModuleTensorUpdater<'a, B: ADBackend, O> {
pub(crate) struct ModuleTensorUpdater<'a, B: ADBackend, O> {
optimizer: &'a mut O,
grads: GradientsParams<B>,
}

#[derive(new)]
pub struct GradientsParamsChangeDevice<'a, B: ADBackend> {
pub(crate) struct GradientsParamsChangeDevice<'a, B: ADBackend> {
device: B::Device,
grads: &'a mut GradientsParams<B>,
}
@@ -77,6 +78,7 @@ impl<'a, B: ADBackend> ModuleVisitor<B> for GradientsParamsChangeDevice<'a, B> {
}
}

/// Update the device of each tensor gradients.
pub fn to_device_grads<M: ADModule>(
grads: &mut GradientsParams<M::ADBackend>,
device: <M::Backend as Backend>::Device,
File renamed without changes.
@@ -1,4 +1,5 @@
use burn::config::Config;
use burn_core as burn;

#[derive(Config, Debug, PartialEq, Eq)]
pub struct TestEmptyStructConfig {}
@@ -1,6 +1,7 @@
use burn::module::{Module, Param};
use burn::tensor::backend::Backend;
use burn::tensor::{Distribution, Shape, Tensor};
use burn_core as burn;

pub type TestBackend = burn_ndarray::NdArrayBackend<f32>;

29 changes: 29 additions & 0 deletions burn-train/Cargo.toml
@@ -0,0 +1,29 @@
[package]
name = "burn-train"
version = "0.4.0"
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
description = "Training crate for burn"
repository = "https://github.com/burn-rs/burn"
readme = "README.md"
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
categories = ["science"]
license = "MIT/Apache-2.0"
edition = "2021"

[dependencies]
burn-core = { version = "0.4.0", path = "../burn-core" }

# Console
indicatif = "0.17.2"
log4rs = "1.2.0"
log = "0.4.17"

# Metrics
nvml-wrapper = "0.8.0"
textplots = "0.8.0"
rgb = "0.8.34"
terminal_size = "0.2.3"

# Utilities
derive-new = "0.5.9"
serde = { version = "1.0.151", features = ["derive"] }
1 change: 1 addition & 0 deletions burn-train/LICENSE-APACHE
1 change: 1 addition & 0 deletions burn-train/LICENSE-MIT
6 changes: 6 additions & 0 deletions burn-train/README.md
@@ -0,0 +1,6 @@
# Burn Train

This crate should be used with [burn](https://github.com/burn-rs/burn).

[![Current Crates.io Version](https://img.shields.io/crates/v/burn-train.svg)](https://crates.io/crates/burn-train)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn-train/blob/master/README.md)
File renamed without changes.
@@ -1,4 +1,4 @@
use crate::data::dataloader::Progress;
use burn_core::data::dataloader::Progress;

pub trait LearnerCallback<T, V>: Send {
fn on_train_item(&mut self, _item: LearnerItem<T>) {}
File renamed without changes.
@@ -1,6 +1,6 @@
use super::{Checkpointer, CheckpointerError};
use crate::module::State;
use burn_tensor::Element;
use burn_core::module::State;
use burn_core::tensor::Element;
use std::sync::{mpsc, Arc};

enum Message<E> {
@@ -1,4 +1,4 @@
use crate::module::{State, StateError};
use burn_core::module::{State, StateError};

#[derive(Debug)]
pub enum CheckpointerError {
@@ -1,6 +1,6 @@
use super::{Checkpointer, CheckpointerError};
use crate::module::State;
use burn_tensor::Element;
use burn_core::module::State;
use burn_core::tensor::Element;

pub struct FileCheckpointer<P> {
directory: String,
File renamed without changes.
@@ -1,8 +1,8 @@
use crate::module::ADModule;
use crate::optim::Optimizer;
use crate::tensor::backend::Backend;
use crate::train::checkpoint::Checkpointer;
use crate::train::LearnerCallback;
use crate::checkpoint::Checkpointer;
use crate::LearnerCallback;
use burn_core::module::ADModule;
use burn_core::optim::Optimizer;
use burn_core::tensor::backend::Backend;

/// Learner struct encapsulating all components necessary to train a Neural Network model.
///
@@ -1,15 +1,15 @@
use super::log::update_log_file;
use super::Learner;
use crate::module::ADModule;
use crate::optim::Optimizer;
use crate::train::checkpoint::{AsyncCheckpointer, Checkpointer, FileCheckpointer};
use crate::train::logger::FileMetricLogger;
use crate::train::metric::dashboard::cli::CLIDashboardRenderer;
use crate::train::metric::dashboard::Dashboard;
use crate::train::metric::{Adaptor, Metric, Numeric};
use crate::train::AsyncTrainerCallback;
use burn_tensor::backend::ADBackend;
use burn_tensor::Element;
use crate::checkpoint::{AsyncCheckpointer, Checkpointer, FileCheckpointer};
use crate::logger::FileMetricLogger;
use crate::metric::dashboard::cli::CLIDashboardRenderer;
use crate::metric::dashboard::Dashboard;
use crate::metric::{Adaptor, Metric, Numeric};
use crate::AsyncTrainerCallback;
use burn_core::module::ADModule;
use burn_core::optim::Optimizer;
use burn_core::tensor::backend::ADBackend;
use burn_core::tensor::Element;
use std::sync::Arc;

/// Struct to configure and create a [learner](Learner).
@@ -1,6 +1,6 @@
use crate::tensor::backend::Backend;
use crate::train::metric::{AccuracyInput, Adaptor, LossInput};
use burn_tensor::Tensor;
use crate::metric::{AccuracyInput, Adaptor, LossInput};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::Tensor;

/// Simple classification output adapted for multiple metrics.
#[derive(new)]
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -1,9 +1,7 @@
use crate::{
data::dataloader::DataLoaderIterator,
module::ADModule,
train::{TrainOutput, TrainStep},
use crate::{TrainOutput, TrainStep};
use burn_core::{
data::dataloader::DataLoaderIterator, module::ADModule, tensor::backend::ADBackend,
};
use burn_tensor::backend::ADBackend;
use std::sync::mpsc::{Receiver, Sender};
use std::thread::spawn;

@@ -1,11 +1,12 @@
use super::Learner;
use crate::data::dataloader::DataLoader;
use crate::module::ADModule;
use crate::optim::visitor::{convert_grads, to_device_grads, GradientsParams};
use crate::optim::{GradientsAccumulator, Optimizer};
use crate::train::train::MultiDevicesTrainStep;
use crate::train::LearnerItem;
use burn_tensor::backend::ADBackend;
use crate::train::MultiDevicesTrainStep;
use crate::LearnerItem;
use burn_core::data::dataloader::DataLoader;
use burn_core::module::ADModule;
use burn_core::optim::{
convert_grads, to_device_grads, GradientsAccumulator, GradientsParams, Optimizer,
};
use burn_core::tensor::backend::ADBackend;
use std::sync::Arc;

pub struct TrainOutput<B: ADBackend, TO> {
3 changes: 3 additions & 0 deletions burn/src/train/mod.rs → burn-train/src/lib.rs
@@ -1,3 +1,6 @@
#[macro_use]
extern crate derive_new;

pub mod checkpoint;
pub mod logger;
pub mod metric;
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -1,6 +1,5 @@
use crate::train::metric::MetricEntry;

use super::{AsyncLogger, FileLogger, Logger};
use crate::metric::MetricEntry;
use std::collections::HashMap;

pub trait MetricLogger: Send {
File renamed without changes.