Book: Add Config, Dataset and Record building blocks (tracel-ai#770)

nathanielsimard authored Sep 5, 2023
1 parent a20a1a8 commit ab4d2f8
Showing 4 changed files with 214 additions and 3 deletions.
6 changes: 3 additions & 3 deletions burn-book/src/SUMMARY.md
@@ -15,8 +15,9 @@
- [Module](./building-blocks/module.md)
- [Learner](./building-blocks/learner.md)
- [Metric](./building-blocks/metric.md)
- [Record]()
- [Dataset]()
- [Config](./building-blocks/config.md)
- [Record](./building-blocks/record.md)
- [Dataset](./building-blocks/dataset.md)
- [Custom Training Loops](./custom-training-loop.md)
- [Import ONNX Model](./import/onnx-model.md)
- [Advanced](./advanced/README.md)
@@ -25,4 +26,3 @@
- [Custom Optimizer]()
- [WebAssembly]()
- [No-Std]()
- [Terms and Concepts]()
74 changes: 74 additions & 0 deletions burn-book/src/building-blocks/config.md
@@ -0,0 +1,74 @@
# Config

When writing scientific code, you normally have a lot of values that need to be set, and Deep
Learning is no exception. Python lets you define default parameter values for functions, which helps
improve the developer experience. However, this has the downside of potentially breaking your code
when upgrading to a new version, as the default values might change without your knowledge, making
debugging very challenging.

With that in mind, we came up with the Config system. It's a simple Rust derive that you can apply
to your types, allowing you to define default values with ease. Additionally, all configs can be
serialized, reducing potential bugs when upgrading versions and improving reproducibility.

```rust, ignore
use burn::config::Config;

#[derive(Config)]
pub struct MyModuleConfig {
d_model: usize,
d_ff: usize,
#[config(default = 0.1)]
dropout: f64,
}
```

The derive also adds useful `with_` methods for every attribute of your config, similar to a builder
pattern, along with a `save` method.

```rust, ignore
fn main() {
let config = MyModuleConfig::new(512, 2048);
println!("{}", config.d_model); // 512
println!("{}", config.d_ff); // 2048
println!("{}", config.dropout); // 0.1
let config = config.with_dropout(0.2);
println!("{}", config.dropout); // 0.2

config.save("config.json").unwrap();
}
```
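
You can then load the configuration back from the saved file. A minimal sketch, assuming the
`Config` derive also provides a `load` counterpart to `save`:

```rust, ignore
// Assumed counterpart to `save`; check the burn::config documentation for the exact signature.
let config = MyModuleConfig::load("config.json").unwrap();
println!("{}", config.dropout); // 0.2
```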

## Good practices

The purpose of the Config pattern is to make it easy to create instances from the configuration.
With that in mind, initialization methods should be implemented on the config struct.

```rust, ignore
impl MyModuleConfig {
    /// Create a module with random weights.
    pub fn init<B: Backend>(&self) -> MyModule<B> {
        MyModule {
            linear: LinearConfig::new(self.d_model, self.d_ff).init(),
            dropout: DropoutConfig::new(self.dropout).init(),
        }
    }

    /// Create a module from a record, for inference and fine-tuning.
    pub fn init_with<B: Backend>(&self, record: MyModuleRecord<B>) -> MyModule<B> {
        MyModule {
            linear: LinearConfig::new(self.d_model, self.d_ff).init_with(record.linear),
            dropout: DropoutConfig::new(self.dropout).init(),
        }
    }
}
```

Then we could add this line to the above `main`:

```rust, ignore
let my_module = config.init();
```
80 changes: 80 additions & 0 deletions burn-book/src/building-blocks/dataset.md
@@ -0,0 +1,80 @@
# Dataset

Since most deep learning training is done on datasets (with perhaps the exception of reinforcement learning), it is essential to provide a convenient and performant dataset API.
The dataset trait is quite similar to the dataset abstract class in PyTorch:

```rust, ignore
pub trait Dataset<I>: Send + Sync {
fn get(&self, index: usize) -> Option<I>;
fn len(&self) -> usize;
}
```

The dataset trait assumes a fixed-length set of items that can be randomly accessed in constant
time. This is a major difference from datasets that use Apache Arrow underneath to improve streaming
performance. Datasets in Burn don't assume _how_ they are going to be accessed; they are simply
collections of items.

However, you can compose multiple dataset transformations to lazily obtain what you want with zero
pre-processing, so that your training can start instantly!
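
For illustration, here is a minimal sketch of a custom dataset implementing the trait above. The
`SquaredDataset` type is hypothetical and only shows that any random-access collection can serve as
a dataset; it assumes the trait is re-exported at `burn::data::dataset::Dataset`:

```rust, ignore
use burn::data::dataset::Dataset;

/// A hypothetical dataset that lazily produces the squares 0, 1, 4, 9, ...
pub struct SquaredDataset {
    size: usize,
}

impl Dataset<u64> for SquaredDataset {
    fn get(&self, index: usize) -> Option<u64> {
        if index < self.size {
            Some((index as u64) * (index as u64))
        } else {
            None
        }
    }

    fn len(&self) -> usize {
        self.size
    }
}
```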

## Transformation

Transformations in Burn are all lazy and modify one or multiple input datasets. The goal of these
transformations is to provide you with the necessary tools so that you can model complex data
distributions.

| Transformation | Description |
| ----------------- | ------------------------------------------------------------------------------------------------------------------------ |
| `SamplerDataset` | Samples items from a dataset. This is a convenient way to model a dataset as a probability distribution of a fixed size. |
| `ShuffledDataset` | Maps each input index to a random index, similar to a dataset sampled without replacement. |
| `PartialDataset` | Returns a view of the input dataset with a specified range. |
| `MapperDataset` | Computes a transformation lazily on the input dataset. |
| `ComposedDataset` | Composes multiple datasets together to create a larger one without copying any data. |
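
As a rough sketch of how these transformations compose, the snippet below shuffles an existing
`dataset` and then keeps a view of the first 10,000 items. The constructor names
(`ShuffledDataset::with_seed`, `PartialDataset::new`) are assumptions to verify against the
burn-dataset API documentation:

```rust, ignore
use burn::data::dataset::transform::{PartialDataset, ShuffledDataset};

// Assumed constructors; check the burn-dataset documentation for the exact API.
let shuffled = ShuffledDataset::with_seed(dataset, 42);
let train = PartialDataset::new(shuffled, 0, 10_000);
```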

## Storage

There are multiple dataset storage options available for you to choose from. The choice of the
dataset to use should be based on the dataset's size as well as its intended purpose.

| Storage | Description |
| --------------- | ------------------------------------------------------------------------------------------------------------------------- |
| `InMemDataset` | In-memory dataset that uses a vector to store items. Well-suited for smaller datasets. |
| `SqliteDataset` | Dataset that uses SQLite to index items that can be saved in a simple SQL database file. Well-suited for larger datasets. |
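
For example, a small in-memory dataset can be built directly from a vector of items. A minimal
sketch, assuming `InMemDataset::new` takes ownership of a `Vec` of items:

```rust, ignore
use burn::data::dataset::{Dataset, InMemDataset};

// Build a small in-memory dataset from a vector of items.
let items = vec![1.0_f32, 2.0, 3.0];
let dataset = InMemDataset::new(items);

assert_eq!(dataset.len(), 3);
assert_eq!(dataset.get(1), Some(2.0));
```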

## Sources

For now, there is only one dataset source available with Burn, but more are to come!

### Hugging Face

You can easily import any Hugging Face dataset with Burn. We use SQLite as the storage to avoid
downloading the dataset each time or starting a Python process. You need to know the format of each
item in the dataset beforehand. Here's an example with the
[dbpedia dataset](https://huggingface.co/datasets/dbpedia_14).

```rust, ignore
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DbPediaItem {
pub title: String,
pub content: String,
pub label: usize,
}
fn main() {
let dataset: SqliteDataset<DbPediaItem> = HuggingfaceDatasetLoader::new("dbpedia_14")
.dataset("train") // The training split.
.unwrap();
}
```

We see that items must derive `serde::Serialize`, `serde::Deserialize`, `Clone`, and `Debug`, but
those are the only requirements.

**What about streaming datasets?**

There is no streaming dataset API with Burn, and this is by design! The learner struct will iterate
multiple times over the dataset and only checkpoint when done. You can consider the length of the
dataset as the number of iterations before performing checkpointing and running the validation.
There is nothing stopping you from returning different items even when called with the same `index`
multiple times.
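
To illustrate, here is a minimal sketch of a dataset whose `get` ignores the index and samples a
fresh item on every call. The `RandomNoiseDataset` type and the use of the `rand` crate are
illustrative assumptions, not part of Burn:

```rust, ignore
use burn::data::dataset::Dataset;
use rand::Rng;

/// A hypothetical dataset that returns a fresh random value on every call,
/// regardless of the requested index.
pub struct RandomNoiseDataset {
    epoch_size: usize,
}

impl Dataset<f32> for RandomNoiseDataset {
    fn get(&self, _index: usize) -> Option<f32> {
        Some(rand::thread_rng().gen_range(-1.0..1.0))
    }

    fn len(&self) -> usize {
        // Interpreted by the learner as the number of iterations per epoch.
        self.epoch_size
    }
}
```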
57 changes: 57 additions & 0 deletions burn-book/src/building-blocks/record.md
@@ -0,0 +1,57 @@
# Record

Records are how states are saved with Burn. Compared to most other frameworks, Burn has its own
advanced saving mechanism that allows interoperability between backends with as few runtime errors
as possible. There are multiple reasons why Burn decided to create its own saving formats.

First, Rust has [serde](https://serde.rs/), an extremely well-developed serialization and
deserialization library that also powers the `safetensors` format developed by Hugging Face. When
used properly, all validations are performed during deserialization, which removes the need to write
validation code. Since modules in Burn are created with configurations, they can't implement
serialization and deserialization directly. That's why the record system was created: it allows you
to save the state of modules independently of the backend in use, extremely fast, while still giving
you all the flexibility to include any non-serializable field within your module.

**Why not use safetensors?**

`Safetensors` uses serde with the JSON file format and only supports serializing and deserializing
tensors. The record system in Burn gives you the possibility to serialize any type, which is very
useful for optimizers that save their state, but also for any non-standard, cutting-edge modeling
needs you may have. Additionally, the record system performs automatic precision conversion by using
Rust types, making it more reliable with fewer manual manipulations.

It is important to note that the `safetensors` format uses the word *safe* to distinguish itself from Pickle, which is vulnerable to Python code injection. The safety comes from a checksum mechanism that guarantees that the data suffered no corruption. On our end, the simple fact that we use Rust already ensures that no code injection is possible. To prevent any data corruption, using a recorder with Gzip compression is recommended, as it includes a checksum mechanism.

## Recorder

Recorders are independent of the backend and serialize records with precision and a format. Note
that the format can also be in-memory, allowing you to save the records directly into bytes.

| Recorder | Format | Compression |
| ---------------------- | ------------------------- | ----------- |
| DefaultFileRecorder    | File - Named MessagePack  | Gzip        |
| NamedMpkFileRecorder   | File - Named MessagePack  | None        |
| NamedMpkGzFileRecorder | File - Named MessagePack  | Gzip        |
| BinFileRecorder        | File - Binary             | None        |
| BinGzFileRecorder      | File - Binary             | Gzip        |
| JsonGzFileRecorder     | File - JSON               | Gzip        |
| PrettyJsonFileRecorder | File - Pretty JSON        | Gzip        |
| BinBytesRecorder | In Memory - Binary | None |

Each recorder supports precision settings decoupled from the precision used for training or
inference. These settings allow you to define the floating-point and integer types that will be used
for serialization and deserialization. Note that when loading a record into a module, the type
conversion is handled automatically, so you won't run into type-mismatch errors. The only crucial
aspect is using the same recorder for both serialization and deserialization; otherwise, you will
encounter loading errors.
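
As a rough sketch of how a recorder is used in practice, the snippet below saves a module's record
to a file and loads it back. The exact method names (`new`, `record`, `load`, `into_record`) and the
`my_module`, `config`, and `MyBackend` values are assumptions to check against the `burn::record`
API documentation:

```rust, ignore
use burn::record::{FullPrecisionSettings, NamedMpkGzFileRecorder, Recorder};

// Assumed API; check the burn::record documentation for the exact signatures.
let recorder = NamedMpkGzFileRecorder::<FullPrecisionSettings>::new();

// Save the module's state (the recorder appends its own file extension).
recorder
    .record(my_module.into_record(), "my_module".into())
    .unwrap();

// Load the record back and apply it to a freshly initialized module.
// `MyBackend` stands for whichever backend type you are using.
let record: MyModuleRecord<MyBackend> = recorder.load("my_module".into()).unwrap();
let my_module = config.init_with(record);
```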

**Which one should you use?**

- If you want fast serialization and deserialization, choose a recorder without compression. The one
  with the lowest file size without compression is the binary format; otherwise, the named
  MessagePack format could be used.
- If you want to save models for storage, you can use compression, but avoid using the binary
format, as it may not be backward compatible.
- If you want to debug your model's weights, you can use the pretty JSON format.
- If you want to deploy with `no-std`, use the in-memory binary format and include the bytes with
the compiled code.
