Add missing docs and enable missing_docs warn lint (tracel-ai#420)
antimora authored Jun 21, 2023
1 parent c4e4c25 commit eda241f
Showing 73 changed files with 696 additions and 26 deletions.
1 change: 1 addition & 0 deletions burn-autodiff/src/backend.rs
@@ -1,6 +1,7 @@
use crate::{grads::Gradients, graph::backward::backward, tensor::ADTensor};
use burn_tensor::backend::{ADBackend, Backend};

/// A decorator for a backend that enables automatic differentiation.
#[derive(Clone, Copy, Debug, Default)]
pub struct ADBackendDecorator<B> {
_b: B,
9 changes: 7 additions & 2 deletions burn-autodiff/src/grads.rs
@@ -5,6 +5,7 @@ use crate::{
tensor::ADTensor,
};

/// Gradient identifier.
pub type GradID = String;

/// Gradients container used during the backward pass.
@@ -15,6 +16,7 @@ pub struct Gradients {
type TensorPrimitive<B, const D: usize> = <B as Backend>::TensorPrimitive<D>;

impl Gradients {
/// Creates a new gradients container.
pub fn new<B: Backend, const D: usize>(
root_node: NodeRef,
root_tensor: TensorPrimitive<B, D>,
@@ -28,7 +30,8 @@
);
gradients
}
/// Consume the gradients for a given tensor.

/// Consumes the gradients for a given tensor.
///
/// Each tensor should be consumed exactly once if its gradients are only required during the
/// backward pass; otherwise, it may be consumed multiple times.
@@ -48,7 +51,7 @@
}
}

/// Remove a grad tensor from the container.
/// Removes a grad tensor from the container.
pub fn remove<B: Backend, const D: usize>(
&mut self,
tensor: &ADTensor<B, D>,
@@ -58,6 +61,7 @@
.map(|tensor| tensor.into_primitive())
}

/// Gets a grad tensor from the container.
pub fn get<B: Backend, const D: usize>(
&self,
tensor: &ADTensor<B, D>,
@@ -67,6 +71,7 @@
.map(|tensor| tensor.into_primitive())
}

/// Registers a grad tensor in the container.
pub fn register<B: Backend, const D: usize>(
&mut self,
node: NodeRef,
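As a quick illustration of how this container is used from the outside, here is a minimal sketch. It assumes `burn-ndarray` as the inner backend and the tensor-level helpers `grad`/`grad_remove`, which wrap `get`/`remove`; depending on the Burn version, leaf tensors may also need an explicit `require_grad()` call to be tracked.

```rust
use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;
use burn_tensor::Tensor;

type AD = ADBackendDecorator<NdArrayBackend<f32>>;

fn main() {
    let x = Tensor::<AD, 1>::ones([3]);
    let y = x.clone().sum();

    // backward() builds the container via Gradients::new, and each op's
    // backward step register()s entries keyed by graph node.
    let mut grads = y.backward();

    // get() is non-destructive; remove() takes the entry out of the container.
    let peeked = x.grad(&grads);           // wraps Gradients::get
    let taken = x.grad_remove(&mut grads); // wraps Gradients::remove
    println!("{peeked:?} {taken:?}");
}
```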
9 changes: 9 additions & 0 deletions burn-autodiff/src/lib.rs
@@ -1,3 +1,12 @@
#![warn(missing_docs)]

//! # Burn Autodiff
//!
//! This autodiff library is part of the Burn project. It is a standalone crate
//! that can be used to perform automatic differentiation on tensors. It is
//! designed to be used with the Burn Tensor crate, but it can be used with any
//! tensor library that implements the `Backend` trait.
#[macro_use]
extern crate derive_new;

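To make the crate doc concrete, a sketch of decorating a backend (assuming `burn-ndarray` is available as the inner backend; any `Backend` implementor works, and some Burn versions additionally require `require_grad()` on leaf tensors):

```rust
use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;
use burn_tensor::Tensor;

// Decorating the backend is all that is needed: tensors created on the
// decorated type record the graph that backward() later traverses.
type AD = ADBackendDecorator<NdArrayBackend<f32>>;

fn main() {
    let x = Tensor::<AD, 2>::ones([2, 2]);
    let y = (x.clone() * x.clone()).sum(); // y = sum of element-wise x^2

    let grads = y.backward();
    let dx = x.grad(&grads); // dy/dx = 2x for tracked tensors
    println!("{dx:?}");
}
```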
2 changes: 2 additions & 0 deletions burn-autodiff/src/tests/mod.rs
@@ -1,3 +1,5 @@
#![allow(missing_docs)]

mod add;
mod aggregation;
mod avgpool1d;
5 changes: 4 additions & 1 deletion burn-common/README.md
@@ -1,3 +1,6 @@
The `burn-common` package hosts code that _must_ be shared between burn packages (with `std` or `no_std` enabled). No other code should be placed in this package unless unavoidable.
# Burn Common

The `burn-common` package hosts code that _must_ be shared between burn packages (with `std` or
`no_std` enabled). No other code should be placed in this package unless unavoidable.

The package must build with `cargo build --no-default-features` as well.
2 changes: 2 additions & 0 deletions burn-common/src/id.rs
@@ -4,9 +4,11 @@ use crate::rand::{get_seeded_rng, Rng, SEED};

use uuid::{Builder, Bytes};

/// Simple ID generator.
pub struct IdGenerator {}

impl IdGenerator {
/// Generates a new ID in the form of a UUID.
pub fn generate() -> String {
let mut seed = SEED.lock().unwrap();
let mut rng = if let Some(rng_seeded) = seed.as_ref() {
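Usage is a single call; each invocation yields a fresh UUID string (a sketch using only the API added in this diff):

```rust
use burn_common::id::IdGenerator;

fn main() {
    let a = IdGenerator::generate();
    let b = IdGenerator::generate();
    assert_ne!(a, b); // random UUIDs are distinct for all practical purposes
    println!("{a} {b}");
}
```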
11 changes: 11 additions & 0 deletions burn-common/src/lib.rs
@@ -1,7 +1,18 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]

//! # Burn Common Library
//!
//! This library contains common types used by other Burn crates that must be shared.
/// Id module contains types for unique identifiers.
pub mod id;

/// Rand module contains types for random number generation, for both `std` and
/// `no_std` environments.
pub mod rand;

/// Stub module contains `std`-compatible stub types that work in both `std` and `no_std` environments.
pub mod stub;

extern crate alloc;
2 changes: 2 additions & 0 deletions burn-common/src/rand.rs
@@ -9,12 +9,14 @@ use crate::stub::Mutex;
#[cfg(not(feature = "std"))]
use const_random::const_random;

/// Returns a seeded random number generator using entropy.
#[cfg(feature = "std")]
#[inline(always)]
pub fn get_seeded_rng() -> StdRng {
StdRng::from_entropy()
}

/// Returns a seeded random number generator using a pre-generated seed.
#[cfg(not(feature = "std"))]
#[inline(always)]
pub fn get_seeded_rng() -> StdRng {
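The point of the two cfg-gated variants is that call sites stay identical under `std` and `no_std`. A sketch, assuming the `Rng` re-export from `burn_common::rand` seen in `id.rs` above:

```rust
use burn_common::rand::{get_seeded_rng, Rng};

fn main() {
    // Entropy-seeded under `std`; seeded from a pre-generated compile-time
    // constant under `no_std`. The caller does not need to know which.
    let mut rng = get_seeded_rng();
    let x: f32 = rng.gen();
    println!("{x}");
}
```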
23 changes: 20 additions & 3 deletions burn-common/src/stub.rs
@@ -2,51 +2,68 @@ use spin::{
Mutex as MutexImported, MutexGuard, RwLock as RwLockImported, RwLockReadGuard, RwLockWriteGuard,
};

// Mutex wrapper to make spin::Mutex API compatible with std::sync::Mutex to swap
/// A mutual exclusion primitive useful for protecting shared data.
///
/// This mutex will block threads waiting for the lock to become available. The
/// mutex can also be statically initialized or created via [Mutex::new].
///
/// [Mutex] wrapper that makes the `spin::Mutex` API compatible with `std::sync::Mutex` so the two can be swapped.
#[derive(Debug)]
pub struct Mutex<T> {
inner: MutexImported<T>,
}

impl<T> Mutex<T> {
/// Creates a new mutex in an unlocked state ready for use.
#[inline(always)]
pub const fn new(value: T) -> Self {
Self {
inner: MutexImported::new(value),
}
}

/// Locks the mutex, blocking the current thread until it is able to do so.
#[inline(always)]
pub fn lock(&self) -> Result<MutexGuard<T>, alloc::string::String> {
Ok(self.inner.lock())
}
}

// Mutex wrapper to make spin::Mutex API compatible with std::sync::Mutex to swap
/// A reader-writer lock which is exclusively locked for writing or shared for reading.
///
/// This reader-writer lock will block threads waiting for the lock to become available.
/// The lock can also be statically initialized or created via [RwLock::new].
///
/// [RwLock] wrapper that makes the `spin::RwLock` API compatible with `std::sync::RwLock` so the two can be swapped.
#[derive(Debug)]
pub struct RwLock<T> {
inner: RwLockImported<T>,
}

impl<T> RwLock<T> {
/// Creates a new reader-writer lock in an unlocked state ready for use.
#[inline(always)]
pub const fn new(value: T) -> Self {
Self {
inner: RwLockImported::new(value),
}
}

/// Locks this rwlock with shared read access, blocking the current thread
/// until it can be acquired.
#[inline(always)]
pub fn read(&self) -> Result<RwLockReadGuard<T>, alloc::string::String> {
Ok(self.inner.read())
}

/// Locks this rwlock with exclusive write access, blocking the current thread
/// until it can be acquired.
#[inline(always)]
pub fn write(&self) -> Result<RwLockWriteGuard<T>, alloc::string::String> {
Ok(self.inner.write())
}
}

// ThreadId stub when no std is available
/// A unique identifier for a running thread.
///
/// This is a stub for when no std is available, to swap with `std::thread::ThreadId`.
#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)]
pub struct ThreadId(core::num::NonZeroU64);
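A sketch of the drop-in compatibility these wrappers aim for: the same code compiles against `burn_common::stub` or `std::sync` with only the import changed, since both expose `Result`-returning `lock`/`read`/`write` (here the stub's error type is a `String`, and the spin-backed calls never actually fail):

```rust
use burn_common::stub::{Mutex, RwLock};

fn main() {
    // Same call shape as std::sync::Mutex.
    let counter = Mutex::new(0);
    *counter.lock().unwrap() += 1;
    assert_eq!(*counter.lock().unwrap(), 1);

    let shared = RwLock::new([1, 2, 3]);
    {
        let guard = shared.read().unwrap(); // shared read access
        println!("first = {}", guard[0]);
    }
    shared.write().unwrap()[0] = 10; // exclusive write access
}
```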
41 changes: 41 additions & 0 deletions burn-core/src/config.rs
@@ -1,9 +1,13 @@
use alloc::{format, string::String, string::ToString};
pub use burn_derive::Config;

/// Configuration IO error.
#[derive(Debug)]
pub enum ConfigError {
/// Invalid format.
InvalidFormat(String),

/// File not found.
FileNotFound(String),
}

@@ -28,19 +32,47 @@ impl core::fmt::Display for ConfigError {
#[cfg(feature = "std")]
impl std::error::Error for ConfigError {}

/// Configuration trait.
pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
/// Saves the configuration to a file.
///
/// # Arguments
///
/// * `file` - File to save the configuration to.
///
/// # Returns
///
/// The output of the save operation.
#[cfg(feature = "std")]
fn save(&self, file: &str) -> std::io::Result<()> {
std::fs::write(file, config_to_json(self))
}

/// Loads the configuration from a file.
///
/// # Arguments
///
/// * `file` - File to load the configuration from.
///
/// # Returns
///
/// The loaded configuration.
#[cfg(feature = "std")]
fn load(file: &str) -> Result<Self, ConfigError> {
let content = std::fs::read_to_string(file)
.map_err(|_| ConfigError::FileNotFound(file.to_string()))?;
config_from_str(&content)
}

/// Loads the configuration from a binary buffer.
///
/// # Arguments
///
/// * `data` - Binary buffer to load the configuration from.
///
/// # Returns
///
/// The loaded configuration.
fn load_binary(data: &[u8]) -> Result<Self, ConfigError> {
let content = core::str::from_utf8(data).map_err(|_| {
ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string())
@@ -49,6 +81,15 @@ pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
}
}

/// Converts a configuration to a JSON string.
///
/// # Arguments
///
/// * `config` - Configuration to convert.
///
/// # Returns
///
/// The JSON string.
pub fn config_to_json<C: Config>(config: &C) -> String {
serde_json::to_string_pretty(config).unwrap()
}
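A hedged sketch of the trait in use, assuming the `Config` derive re-exported at the top of this file generates the serde impls plus a `new()` constructor honoring `#[config(default = ...)]` (the struct, field names, and file name are illustrative):

```rust
use burn_core::config::Config;

#[derive(Config, Debug)]
struct TrainingConfig {
    #[config(default = 64)]
    batch_size: usize,
    #[config(default = 1.0e-4)]
    learning_rate: f64,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let config = TrainingConfig::new();
    config.save("training.json")?; // std-only: writes pretty JSON via config_to_json

    // Missing files surface as ConfigError::FileNotFound,
    // malformed contents as ConfigError::InvalidFormat.
    let loaded = TrainingConfig::load("training.json")?;
    println!("{loaded:?}");
    Ok(())
}
```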
8 changes: 8 additions & 0 deletions burn-core/src/data/dataloader/base.rs
@@ -1,16 +1,24 @@
pub use crate::data::dataset::{Dataset, DatasetIterator};
use core::iter::Iterator;

/// A progress struct that can be used to track the progress of a data loader.
#[derive(Clone, Debug)]
pub struct Progress {
/// The number of items that have been processed.
pub items_processed: usize,

/// The total number of items that need to be processed.
pub items_total: usize,
}

/// A data loader iterator that can be used to iterate over a data loader.
pub trait DataLoaderIterator<O>: Iterator<Item = O> {
/// Returns the progress of the data loader.
fn progress(&self) -> Progress;
}

/// A data loader that can be used to iterate over a dataset.
pub trait DataLoader<O> {
/// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader.
fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a>;
}
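To make the contract concrete, a sketch implementing both traits for a trivial in-memory loader; everything named `Vec*` here is hypothetical, and the import path is assumed from this file's module layout:

```rust
use burn_core::data::dataloader::{DataLoader, DataLoaderIterator, Progress};

struct VecLoader {
    items: Vec<i32>,
}

struct VecLoaderIterator<'a> {
    items: &'a [i32],
    index: usize,
}

impl<'a> Iterator for VecLoaderIterator<'a> {
    type Item = i32;

    fn next(&mut self) -> Option<i32> {
        let item = self.items.get(self.index).copied()?;
        self.index += 1;
        Some(item)
    }
}

impl<'a> DataLoaderIterator<i32> for VecLoaderIterator<'a> {
    fn progress(&self) -> Progress {
        Progress {
            items_processed: self.index,
            items_total: self.items.len(),
        }
    }
}

impl DataLoader<i32> for VecLoader {
    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<i32> + 'a> {
        Box::new(VecLoaderIterator {
            items: &self.items,
            index: 0,
        })
    }
}

fn main() {
    let loader = VecLoader { items: vec![10, 20, 30] };
    let mut iter = loader.iter();
    iter.next();
    let p = iter.progress();
    println!("{}/{}", p.items_processed, p.items_total); // 1/3
}
```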
37 changes: 37 additions & 0 deletions burn-core/src/data/dataloader/batch.rs
@@ -5,12 +5,14 @@ use super::{
use burn_dataset::{transform::PartialDataset, Dataset};
use std::sync::Arc;

/// A data loader that can be used to iterate over a dataset in batches.
pub struct BatchDataLoader<I, O> {
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
batcher: Arc<dyn Batcher<I, O>>,
}

/// A data loader iterator that can be used to iterate over a data loader.
struct BatchDataloaderIterator<I, O> {
current_index: usize,
strategy: Box<dyn BatchStrategy<I>>,
@@ -19,6 +21,17 @@ struct BatchDataloaderIterator<I, O> {
}

impl<I, O> BatchDataLoader<I, O> {
/// Creates a new batch data loader.
///
/// # Arguments
///
/// * `strategy` - The batch strategy.
/// * `dataset` - The dataset.
/// * `batcher` - The batcher.
///
/// # Returns
///
/// The batch data loader.
pub fn new(
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
@@ -31,11 +44,24 @@ impl<I, O> BatchDataLoader<I, O> {
}
}
}

impl<I, O> BatchDataLoader<I, O>
where
I: Send + Sync + Clone + 'static,
O: Send + Sync + Clone + 'static,
{
/// Creates a new multi-threaded batch data loader.
///
/// # Arguments
///
/// * `strategy` - The batch strategy.
/// * `dataset` - The dataset.
/// * `batcher` - The batcher.
/// * `num_threads` - The number of threads.
///
/// # Returns
///
/// The multi-threaded batch data loader.
pub fn multi_thread(
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
@@ -65,6 +91,17 @@ impl<I, O> DataLoader<O> for BatchDataLoader<I, O> {
}

impl<I, O> BatchDataloaderIterator<I, O> {
/// Creates a new batch data loader iterator.
///
/// # Arguments
///
/// * `strategy` - The batch strategy.
/// * `dataset` - The dataset.
/// * `batcher` - The batcher.
///
/// # Returns
///
/// The batch data loader iterator.
pub fn new(
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
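Tying the pieces together, a hedged sketch of constructing and draining a loader. `FixBatchStrategy` and `InMemDataset` exist elsewhere in Burn at this point, but their exact paths are assumptions here, and `SumBatcher` is purely illustrative:

```rust
use std::sync::Arc;

use burn_core::data::dataloader::{batcher::Batcher, BatchDataLoader, DataLoader, FixBatchStrategy};
use burn_dataset::InMemDataset;

// Illustrative batcher: collapses a batch of i32 items into their sum.
struct SumBatcher;

impl Batcher<i32, i32> for SumBatcher {
    fn batch(&self, items: Vec<i32>) -> i32 {
        items.into_iter().sum()
    }
}

fn main() {
    let dataset = Arc::new(InMemDataset::new((0..10).collect::<Vec<i32>>()));
    let loader = BatchDataLoader::new(
        Box::new(FixBatchStrategy::new(4)), // up to 4 items per batch
        dataset,
        Arc::new(SumBatcher),
    );

    // Yields one output per batch: 0+1+2+3, 4+5+6+7, 8+9.
    for batch in loader.iter() {
        println!("{batch}");
    }
}
```

Swapping `new` for `multi_thread` with a thread count gives the multi-threaded variant without changing the iteration code.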
10 changes: 10 additions & 0 deletions burn-core/src/data/dataloader/batcher.rs
@@ -1,4 +1,14 @@
/// A trait for batching items of type `I` into items of type `O`.
pub trait Batcher<I, O>: Send + Sync {
/// Batches the given items.
///
/// # Arguments
///
/// * `items` - The items to batch.
///
/// # Returns
///
/// The batched items.
fn batch(&self, items: Vec<I>) -> O;
}

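A minimal sketch of implementing the trait; the item and output types are arbitrary choices for illustration:

```rust
use burn_core::data::dataloader::batcher::Batcher;

// Collapses a batch of strings into their lengths.
struct LenBatcher;

impl Batcher<String, Vec<usize>> for LenBatcher {
    fn batch(&self, items: Vec<String>) -> Vec<usize> {
        items.iter().map(|s| s.len()).collect()
    }
}

fn main() {
    let out = LenBatcher.batch(vec!["ab".to_string(), "cde".to_string()]);
    assert_eq!(out, vec![2, 3]);
}
```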