Skip to content

Commit

Permalink
feat: basic setup
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Aug 21, 2022
1 parent 0f38096 commit 39be42b
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 9 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ tch = ["burn-tensor/tch"]
ndarray = ["burn-tensor/ndarray"]

[dependencies]
num-traits = "0.2"
burn-tensor = { path = "./burn-tensor", version = "0.1.0", default-features = false }
burn-derive = { path = "./burn-derive", version = "0.1.0" }
rand = "0.8"
Expand Down
113 changes: 113 additions & 0 deletions examples/mnist.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use burn::module::{Forward, Module, Param};
use burn::nn;
use burn::optim::SGDOptimizer;
use burn::tensor::af::relu;
use burn::tensor::back::{ad, Backend};
use burn::tensor::losses::cross_entropy_with_logits;
use burn::tensor::{Distribution, Shape, Tensor};

#[derive(Module, Debug)]
struct Model<B>
where
B: Backend,
{
mlp: Param<MLP<B>>,
input: Param<nn::Linear<B>>,
output: Param<nn::Linear<B>>,
}

#[derive(Module, Debug)]
struct MLP<B>
where
B: Backend,
{
linears: Param<Vec<nn::Linear<B>>>,
}

impl<B: Backend> Forward<Tensor<B, 2>, Tensor<B, 2>> for MLP<B> {
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let mut x = input;

for linear in self.linears.iter() {
x = linear.forward(x);
x = relu(&x);
}

x
}
}

impl<B: Backend> Forward<Tensor<B, 2>, Tensor<B, 2>> for Model<B> {
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let mut x = input;

x = self.input.forward(x);
x = self.mlp.forward(x);
x = self.output.forward(x);

x
}
}

impl<B: Backend> MLP<B> {
fn new(dim: usize, num_layers: usize) -> Self {
let mut linears = Vec::with_capacity(num_layers);

for _ in 0..num_layers {
let config = nn::LinearConfig {
d_input: dim,
d_output: dim,
bias: true,
};
let linear = nn::Linear::new(&config);
linears.push(linear);
}

Self {
linears: Param::new(linears),
}
}
}
impl<B: Backend> Model<B> {
fn new(d_input: usize, d_hidden: usize, num_layers: usize, num_classes: usize) -> Self {
let mlp = MLP::new(d_hidden, num_layers);
let config_input = nn::LinearConfig {
d_input,
d_output: d_hidden,
bias: true,
};
let config_output = nn::LinearConfig {
d_input: d_hidden,
d_output: num_classes,
bias: true,
};
let output = nn::Linear::new(&config_output);
let input = nn::Linear::new(&config_input);

Self {
mlp: Param::new(mlp),
output: Param::new(output),
input: Param::new(input),
}
}
}

fn run<B: ad::Backend>() {
let mut model: Model<B> = Model::new(768, 256, 2, 10);
let mut optim: SGDOptimizer<B> = SGDOptimizer::new(5.0e-4);

let input = Tensor::<B, 2>::random(Shape::new([32, 768]), Distribution::Standard);
let targets = Tensor::<B, 2>::random(Shape::new([32, 10]), Distribution::Standard);

let output = model.forward(input);
let loss = cross_entropy_with_logits(&output, &targets);
let grads = loss.backward();

model.update_params(&grads, &mut optim);

println!("loss {:?}", loss.to_data());
}

fn main() {
run::<ad::Tch<f32>>();
}
50 changes: 50 additions & 0 deletions src/module/param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,53 @@ impl<M: Module> Param<M> {
self.value.load_from_parent(name, state);
}
}

impl<M: Module> Param<Vec<M>> {
pub fn num_params(&self) -> usize {
let mut num_params = 0;
for module in self.value.iter() {
num_params += module.num_params();
}

num_params
}

pub fn update_params<O: Optimizer<M::Backend>>(&mut self, grads: &Gradients, optim: &mut O)
where
M::Backend: back::ad::Backend,
{
for module in self.value.iter_mut() {
module.update_params(grads, optim);
}
}

pub fn devices(&self) -> Vec<<M::Backend as back::Backend>::Device> {
let mut devices = Vec::new();
for module in self.value.iter() {
devices.append(&mut module.devices());
}
devices
}

pub fn to_device(&mut self, device: <M::Backend as back::Backend>::Device) {
for module in self.value.iter_mut() {
module.to_device(device);
}
}

pub fn state(&self, name: &str) -> State<M::Backend>
where
<M::Backend as back::Backend>::Elem: Serialize,
<M::Backend as back::Backend>::Elem: DeserializeOwned,
{
todo!();
}

pub fn load_from_parent(&mut self, name: &str, state: &State<M::Backend>)
where
<M::Backend as back::Backend>::Elem: Serialize,
<M::Backend as back::Backend>::Elem: DeserializeOwned,
{
todo!();
}
}
21 changes: 12 additions & 9 deletions src/nn/linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ use crate::module::Module;
use crate::module::{Forward, Param};
use crate::tensor::back::Backend;
use crate::tensor::{Distribution, Shape, Tensor};
use num_traits::FromPrimitive;
use std::ops::Deref;

pub struct LinearConfig {
d_input: usize,
d_output: usize,
bias: bool,
pub d_input: usize,
pub d_output: usize,
pub bias: bool,
}

#[derive(Module, Debug)]
Expand All @@ -23,10 +24,12 @@ where

impl<B: Backend> Linear<B> {
pub fn new(config: &LinearConfig) -> Self {
let weight = Tensor::random(
Shape::new([config.d_input, config.d_output]),
Distribution::Standard,
// Glorot init
let distribution = Distribution::Uniform(
<B as Backend>::Elem::from_f64(-1.0 / f64::sqrt(config.d_input as f64)).unwrap(),
<B as Backend>::Elem::from_f64(1.0 / f64::sqrt(config.d_input as f64)).unwrap(),
);
let weight = Tensor::random(Shape::new([config.d_input, config.d_output]), distribution);
let bias = match config.bias {
true => Some(Tensor::zeros(Shape::new([config.d_output]))),
false => None,
Expand All @@ -39,9 +42,9 @@ impl<B: Backend> Linear<B> {
}
}

impl<B: Backend, const D: usize> Forward<&Tensor<B, D>, Tensor<B, D>> for Linear<B> {
fn forward(&self, input: &Tensor<B, D>) -> Tensor<B, D> {
let output = self.weight.unsqueeze().matmul(input);
impl<B: Backend, const D: usize> Forward<Tensor<B, D>, Tensor<B, D>> for Linear<B> {
fn forward(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let output = input.matmul(&self.weight.unsqueeze());

match self.bias.deref() {
Some(bias) => output + bias.unsqueeze(),
Expand Down
8 changes: 8 additions & 0 deletions src/optim/sgd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,19 @@ use crate::optim::Optimizer;
use crate::tensor::back;
use crate::tensor::Gradients;
use crate::tensor::Tensor;
use num_traits::cast::FromPrimitive;

pub struct SGDOptimizer<B: back::ad::Backend> {
learning_rate: <B::InnerBackend as back::Backend>::Elem,
}

impl<B: back::ad::Backend> SGDOptimizer<B> {
pub fn new(learning_rate: f64) -> Self {
let learning_rate =
<B::InnerBackend as back::Backend>::Elem::from_f64(learning_rate).unwrap();
Self { learning_rate }
}
}
impl<B: back::ad::Backend> Optimizer<B> for SGDOptimizer<B> {
fn update<const D: usize>(&mut self, tensor: &mut Tensor<B, D>, grads: &Gradients) {
let grad = tensor.grad(&grads).unwrap();
Expand Down

0 comments on commit 39be42b

Please sign in to comment.