From 5191678b7a6c60bff3bc90b50a8f6cd64c62c4aa Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Fri, 28 Jun 2024 16:02:25 -0500 Subject: [PATCH 1/2] Update book example guide --- burn-book/src/basic-workflow/model.md | 36 +++++++++++++++++++- examples/guide/README.md | 19 +++++++++-- examples/guide/examples/guide.rs | 2 +- examples/guide/src/bin/infer.rs | 20 +++++++++++ examples/guide/src/bin/print.rs | 11 ++++++ examples/guide/src/{main.rs => bin/train.rs} | 22 +++++++----- examples/guide/src/inference.rs | 4 +-- 7 files changed, 100 insertions(+), 14 deletions(-) create mode 100644 examples/guide/src/bin/infer.rs create mode 100644 examples/guide/src/bin/print.rs rename examples/guide/src/{main.rs => bin/train.rs} (64%) diff --git a/burn-book/src/basic-workflow/model.md b/burn-book/src/basic-workflow/model.md index 9868bc8145..28d0682a4b 100644 --- a/burn-book/src/basic-workflow/model.md +++ b/burn-book/src/basic-workflow/model.md @@ -192,7 +192,7 @@ Next, we need to instantiate the model for training. # linear2: Linear, # activation: Relu, # } -# +# #[derive(Config, Debug)] pub struct ModelConfig { num_classes: usize, @@ -217,6 +217,40 @@ impl ModelConfig { } ``` + +At a glance, you can view the model configuration by printing the model instance: + +```rust , ignore +use burn::backend::Wgpu; +use guide::model::ModelConfig; + +fn main() { + type MyBackend = Wgpu; + + let device = Default::default(); + let model = ModelConfig::new(10, 512).init::(&device); + + println!("{}", model); +} +``` + +Output: + +```rust , ignore +Model { + conv1: Conv2d {stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 80} + conv2: Conv2d {stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 1168} + pool: AdaptiveAvgPool2d {output_size: [8, 8]} + dropout: Dropout {prob: 0.5} + linear1: Linear {d_input: 1024, d_output: 512, bias: true, params: 524800} + linear2: Linear {d_input: 512, d_output: 10, bias: true, params: 5130} + activation: Relu + params: 531178 +} +``` + + +
🦀 References diff --git a/examples/guide/README.md b/examples/guide/README.md index 06fedaa8eb..0a6e6392a8 100644 --- a/examples/guide/README.md +++ b/examples/guide/README.md @@ -4,6 +4,21 @@ This example corresponds to the [book's guide](https://burn.dev/book/basic-workf ## Example Usage + +### Training + +```sh +cargo run --bin train --release +``` + +### Inference + +```sh +cargo run --bin infer --release +``` + +### Print the model + ```sh -cargo run --example guide -``` \ No newline at end of file +cargo run --bin print --release +``` diff --git a/examples/guide/examples/guide.rs b/examples/guide/examples/guide.rs index 0ddc70e460..bbac49635c 100644 --- a/examples/guide/examples/guide.rs +++ b/examples/guide/examples/guide.rs @@ -10,7 +10,7 @@ use std::process::Command; fn main() { Command::new("cargo") - .args(["run", "--bin", "guide"]) + .args(["run", "--bin", "train", "--release"]) .status() .expect("guide example should run"); } diff --git a/examples/guide/src/bin/infer.rs b/examples/guide/src/bin/infer.rs new file mode 100644 index 0000000000..3c64879bc5 --- /dev/null +++ b/examples/guide/src/bin/infer.rs @@ -0,0 +1,20 @@ +use burn::{backend::Wgpu, data::dataset::Dataset}; +use guide::inference; + +fn main() { + type MyBackend = Wgpu; + + let device = burn::backend::wgpu::WgpuDevice::default(); + + // All the training artifacts are saved in this directory + let artifact_dir = "/tmp/guide"; + + // Infer the model + inference::infer::( + artifact_dir, + device, + burn::data::dataset::vision::MnistDataset::test() + .get(42) + .unwrap(), + ); +} diff --git a/examples/guide/src/bin/print.rs b/examples/guide/src/bin/print.rs new file mode 100644 index 0000000000..9432aa93a4 --- /dev/null +++ b/examples/guide/src/bin/print.rs @@ -0,0 +1,11 @@ +use burn::backend::Wgpu; +use guide::model::ModelConfig; + +fn main() { + type MyBackend = Wgpu; + + let device = Default::default(); + let model = ModelConfig::new(10, 512).init::(&device); + + println!("{}", model); +} diff --git a/examples/guide/src/main.rs b/examples/guide/src/bin/train.rs similarity index 64% rename from examples/guide/src/main.rs rename to examples/guide/src/bin/train.rs index a89e857212..04f1f44146 100644 --- a/examples/guide/src/main.rs +++ b/examples/guide/src/bin/train.rs @@ -1,27 +1,33 @@ -mod data; -mod inference; -mod model; -mod training; - -use crate::{model::ModelConfig, training::TrainingConfig}; use burn::{ backend::{Autodiff, Wgpu}, data::dataset::Dataset, optim::AdamConfig, }; +use guide::{ + inference, + model::ModelConfig, + training::{self, TrainingConfig}, +}; fn main() { type MyBackend = Wgpu; type MyAutodiffBackend = Autodiff; + // Create a default Wgpu device let device = burn::backend::wgpu::WgpuDevice::default(); + + // All the training artifacts will be saved in this directory let artifact_dir = "/tmp/guide"; - crate::training::train::( + + // Train the model + training::train::( artifact_dir, TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()), device.clone(), ); - crate::inference::infer::( + + // Infer the model + inference::infer::( artifact_dir, device, burn::data::dataset::vision::MnistDataset::test() diff --git a/examples/guide/src/inference.rs b/examples/guide/src/inference.rs index 23c41e1181..9c7477f3f4 100644 --- a/examples/guide/src/inference.rs +++ b/examples/guide/src/inference.rs @@ -7,10 +7,10 @@ use burn::{ pub fn infer(artifact_dir: &str, device: B::Device, item: MnistItem) { let config = TrainingConfig::load(format!("{artifact_dir}/config.json")) - .expect("Config should exist for the model"); + .expect("Config should exist for the model; run train first"); let record = CompactRecorder::new() .load(format!("{artifact_dir}/model").into(), &device) - .expect("Trained model should exist"); + .expect("Trained model should exist; run train first"); let model: Model = config.model.init(&device).load_record(record); From a583539e610a9b67fc197e498626713b9433893d Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Fri, 28 Jun 2024 17:03:11 -0500 Subject: [PATCH 2/2] Update Module book section on module display --- burn-book/src/building-blocks/module.md | 73 ++++++++++++++++++++++--- 1 file changed, 64 insertions(+), 9 deletions(-) diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md index 12c3e6bdbc..fe94810be6 100644 --- a/burn-book/src/building-blocks/module.md +++ b/burn-book/src/building-blocks/module.md @@ -52,7 +52,7 @@ the `Module` derive, you need to be careful to achieve the behavior you want. These methods are available for all modules. | Burn API | PyTorch Equivalent | -|-----------------------------------------|------------------------------------------| +| --------------------------------------- | ---------------------------------------- | | `module.devices()` | N/A | | `module.fork(device)` | Similar to `module.to(device).detach()` | | `module.to_device(device)` | `module.to(device)` | @@ -69,7 +69,7 @@ Similar to the backend trait, there is also the `AutodiffModule` trait to signif autodiff support. | Burn API | PyTorch Equivalent | -|------------------|--------------------| +| ---------------- | ------------------ | | `module.valid()` | `module.eval()` | ## Visitor & Mapper @@ -96,7 +96,62 @@ pub trait ModuleVisitor { /// Module mapper trait. pub trait ModuleMapper { /// Map a tensor in the module. - fn map(&mut self, id: &ParamId, tensor: Tensor) -> Tensor; + fn map(&mut self, id: &ParamId, tensor: Tensor) -> + Tensor; +} +``` + +## Module Display + +Burn provides a simple way to display the structure of a module and its configuration at a glance. +You can print the module to see its structure, which is useful for debugging and tracking changes +across different versions of a module. (See the print output of the +[Basic Workflow Model](../basic-workflow/model.md) example.) + +To customize the display of a module, you can implement the `ModuleDisplay` trait for your module. +This will change the default display settings for the module and its children. Note that +`ModuleDisplay` is automatically implemented for all modules, but you can override it to customize +the display by annotating the module with `#[module(custom_display)]`. + +```rust +#[derive(Module, Debug)] +#[module(custom_display)] +pub struct PositionWiseFeedForward { + linear_inner: Linear, + linear_outer: Linear, + dropout: Dropout, + gelu: Gelu, +} + +impl ModuleDisplay for PositionWiseFeedForward { + /// Custom settings for the display of the module. + /// If `None` is returned, the default settings will be used. + fn custom_settings(&self) -> Option { + DisplaySettings::new() + // Will show all attributes (default is false) + .with_show_all_attributes(false) + // Will show each attribute on a new line (default is true) + .with_new_line_after_attribute(true) + // Will show the number of parameters (default is true) + .with_show_num_parameters(true) + // Will indent by 2 spaces (default is 2) + .with_indentation_size(2) + // Will show the parameter ID (default is false) + .with_show_param_id(false) + // Convenience method to wrap settings in Some() + .optional() + } + + /// Custom content to be displayed. + /// If `None` is returned, the default content will be used + /// (all attributes of the module) + fn custom_content(&self, content: Content) -> Option { + content + .add("linear_inner", &self.linear_inner) + .add("linear_outer", &self.linear_outer) + .add("anything", "anything_else") + .optional() + } } ``` @@ -107,7 +162,7 @@ Burn comes with built-in modules that you can use to build your own modules. ### General | Burn API | PyTorch Equivalent | -|----------------|-----------------------------------------------| +| -------------- | --------------------------------------------- | | `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. | | `Dropout` | `nn.Dropout` | | `Embedding` | `nn.Embedding` | @@ -125,7 +180,7 @@ Burn comes with built-in modules that you can use to build your own modules. ### Convolutions | Burn API | PyTorch Equivalent | -|-------------------|----------------------| +| ----------------- | -------------------- | | `Conv1d` | `nn.Conv1d` | | `Conv2d` | `nn.Conv2d` | | `ConvTranspose1d` | `nn.ConvTranspose1d` | @@ -134,7 +189,7 @@ Burn comes with built-in modules that you can use to build your own modules. ### Pooling | Burn API | PyTorch Equivalent | -|---------------------|------------------------| +| ------------------- | ---------------------- | | `AdaptiveAvgPool1d` | `nn.AdaptiveAvgPool1d` | | `AdaptiveAvgPool2d` | `nn.AdaptiveAvgPool2d` | | `AvgPool1d` | `nn.AvgPool1d` | @@ -145,7 +200,7 @@ Burn comes with built-in modules that you can use to build your own modules. ### RNNs | Burn API | PyTorch Equivalent | -|------------------|------------------------| +| ---------------- | ---------------------- | | `Gru` | `nn.GRU` | | `Lstm`/`BiLstm` | `nn.LSTM` | | `GateController` | _No direct equivalent_ | @@ -153,7 +208,7 @@ Burn comes with built-in modules that you can use to build your own modules. ### Transformer | Burn API | PyTorch Equivalent | -|----------------------|-------------------------| +| -------------------- | ----------------------- | | `MultiHeadAttention` | `nn.MultiheadAttention` | | `TransformerDecoder` | `nn.TransformerDecoder` | | `TransformerEncoder` | `nn.TransformerEncoder` | @@ -163,7 +218,7 @@ Burn comes with built-in modules that you can use to build your own modules. ### Loss | Burn API | PyTorch Equivalent | -|--------------------|-----------------------| +| ------------------ | --------------------- | | `CrossEntropyLoss` | `nn.CrossEntropyLoss` | | `MseLoss` | `nn.MSELoss` | | `HuberLoss` | `nn.HuberLoss` |