Commit e823338

Add Clone trait to the OptimizerAdaptor and Clone implementations to the optimizers (tracel-ai#1770)

getumen authored May 15, 2024
1 parent f8a1356 · commit e823338

Showing 10 changed files with 14 additions and 1 deletion.
crates/burn-core/src/grad_clipping/base.rs (1 change: 1 addition & 0 deletions)

@@ -30,6 +30,7 @@ impl GradientClippingConfig {
 /// Gradient Clipping provides a way to mitigate exploding gradients
 /// by clipping every component of the gradient by value or by norm during
 /// backpropagation.
+#[derive(Clone)]
 pub enum GradientClipping {
     /// Clip the gradient by value.
     Value(f32),
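
For readers unfamiliar with the clipping the doc comment above describes, the standalone Rust sketch below shows the clip-by-value case on a plain f32 slice. It is illustrative only and does not use the burn-core GradientClipping type, which operates on tensors; the function name clip_by_value and the threshold of 1.0 are made up for the example.

fn clip_by_value(grad: &mut [f32], threshold: f32) {
    // Clamp every gradient component into [-threshold, threshold].
    for g in grad.iter_mut() {
        *g = g.clamp(-threshold, threshold);
    }
}

fn main() {
    let mut grad = vec![0.5, -3.0, 12.0];
    clip_by_value(&mut grad, 1.0);
    assert_eq!(grad, vec![0.5, -1.0, 1.0]);
}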
crates/burn-core/src/optim/adagrad.rs (2 changes: 2 additions & 0 deletions)

@@ -26,6 +26,7 @@ pub struct AdaGradConfig {
 }

 /// AdaGrad optimizer
+#[derive(Clone)]
 pub struct AdaGrad<B: Backend> {
     lr_decay: LrDecay,
     weight_decay: Option<WeightDecay<B>>,
@@ -105,6 +106,7 @@ pub struct LrDecayState<B: Backend, const D: usize> {
     sum: Tensor<B, D>,
 }

+#[derive(Clone)]
 struct LrDecay {
     lr_decay: f64,
     epsilon: f32,
crates/burn-core/src/optim/adam.rs (2 changes: 2 additions & 0 deletions)

@@ -31,6 +31,7 @@ pub struct AdamConfig {
 }

 /// Adam optimizer as described in the paper [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf).
+#[derive(Clone)]
 pub struct Adam<B: Backend> {
     momentum: AdaptiveMomentum,
     weight_decay: Option<WeightDecay<B>>,
@@ -113,6 +114,7 @@ pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
     moment_2: Tensor<B, D>,
 }

+#[derive(Clone)]
 struct AdaptiveMomentum {
     beta_1: f32,
     beta_2: f32,
crates/burn-core/src/optim/adamw.rs (2 changes: 2 additions & 0 deletions)

@@ -30,6 +30,7 @@ pub struct AdamWConfig {
 }

 /// AdamW optimizer as described in the paper [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101).
+#[derive(Clone)]
 pub struct AdamW<B: Backend> {
     momentum: AdaptiveMomentumW,
     weight_decay: f32,
@@ -112,6 +113,7 @@ pub struct AdaptiveMomentumWState<B: Backend, const D: usize> {
     moment_2: Tensor<B, D>,
 }

+#[derive(Clone)]
 struct AdaptiveMomentumW {
     beta_1: f32,
     beta_2: f32,
crates/burn-core/src/optim/decay.rs (1 change: 1 addition & 0 deletions)

@@ -20,6 +20,7 @@ pub struct WeightDecayState<B: Backend, const D: usize> {
 }

 /// Weight decay implementation that transforms gradients.
+#[derive(Clone)]
 pub struct WeightDecay<B: Backend> {
     penalty: B::FloatElem,
 }
crates/burn-core/src/optim/momentum.rs (1 change: 1 addition & 0 deletions)

@@ -27,6 +27,7 @@ pub struct MomentumState<B: Backend, const D: usize> {
 }

 /// Momemtum implementation that transforms gradients.
+#[derive(Clone)]
 pub struct Momentum<B: Backend> {
     momentum: B::FloatElem,
     dampening: f64,
crates/burn-core/src/optim/rmsprop.rs (2 changes: 2 additions & 0 deletions)

@@ -64,6 +64,7 @@ impl RmsPropConfig {

 /// Optimizer that implements stochastic gradient descent with momentum.
 /// The optimizer can be configured with [RmsPropConfig](RmsPropConfig).
+#[derive(Clone)]
 pub struct RmsProp<B: Backend> {
     alpha: f32,
     // epsilon: f32,
@@ -251,6 +252,7 @@ impl<B: Backend, const D: usize> CenteredState<B, D> {

 /// [RmsPropMomentum](RmsPropMomentum) is to store config status for optimizer.
 /// (, which is stored in [optimizer](RmsProp) itself and not passed in during `step()` calculation)
+#[derive(Clone)]
 pub struct RmsPropMomentum {
     momentum: f32,
     epsilon: f32,
crates/burn-core/src/optim/sgd.rs (1 change: 1 addition & 0 deletions)

@@ -25,6 +25,7 @@ pub struct SgdConfig {
 /// Optimizer that implements stochastic gradient descent with momentum.
 ///
 /// The optimizer can be configured with [SgdConfig](SgdConfig).
+#[derive(Clone)]
 pub struct Sgd<B: Backend> {
     momentum: Option<Momentum<B>>,
     weight_decay: Option<WeightDecay<B>>,
crates/burn-core/src/optim/simple/adaptor.rs (1 change: 1 addition & 0 deletions)

@@ -11,6 +11,7 @@ use hashbrown::HashMap;

 /// Wrapper struct that adapts any [simple optimizer](SimpleOptimizer) into
 /// an [optimizer](Optimizer).
+#[derive(Clone)]
 pub struct OptimizerAdaptor<O, M, B>
 where
     O: SimpleOptimizer<B::InnerBackend>,
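
The derive above is what makes the whole adaptor cloneable. The following standalone sketch (the names SimpleSgd and Adaptor are made up; it is not the burn-core OptimizerAdaptor) shows why the inner optimizer must itself be Clone for a derived Clone on the wrapper to work, which is exactly what the trait-bound change in simple/base.rs below guarantees.

#[derive(Clone)]
struct SimpleSgd {
    lr: f64,
}

// A derived Clone on a generic wrapper only holds when the inner type is Clone,
// mirroring why OptimizerAdaptor needs Clone on SimpleOptimizer implementations.
#[derive(Clone)]
struct Adaptor<O: Clone> {
    optim: O,
}

fn main() {
    let adaptor = Adaptor { optim: SimpleSgd { lr: 1e-2 } };
    // Each training run or model replica can now own its own optimizer copy.
    let second_run = adaptor.clone();
    println!("cloned optimizer lr: {}", second_run.optim.lr);
}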
crates/burn-core/src/optim/simple/base.rs (2 changes: 1 addition & 1 deletion)

@@ -6,7 +6,7 @@ use burn_tensor::{backend::Backend, Tensor};
 ///
 /// Implementations don't have to handle missing gradients, loading and exporting records, navigate the
 /// module parameter structure, handle tracked and untracked tensors, and the likes.
-pub trait SimpleOptimizer<B>: Send + Sync
+pub trait SimpleOptimizer<B>: Send + Sync + Clone
 where
     B: Backend,
 {
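
With the Clone supertrait added above, user-defined simple optimizers now have to provide a Clone implementation, usually via the same derive added throughout this commit. Below is a minimal sketch with a stand-in trait: the real SimpleOptimizer also has an associated state type and a step method, omitted here, and MyCustomOptimizer is a hypothetical example.

trait SimpleOptimizer: Send + Sync + Clone {
    fn name(&self) -> &'static str;
}

// Without #[derive(Clone)], this impl would stop compiling after the bound change.
#[derive(Clone)]
struct MyCustomOptimizer {
    lr_decay: f64,
}

impl SimpleOptimizer for MyCustomOptimizer {
    fn name(&self) -> &'static str {
        "my-custom-optimizer"
    }
}

fn main() {
    let optim = MyCustomOptimizer { lr_decay: 0.99 };
    let copy = optim.clone(); // satisfied by the derive, as for AdaGrad, Adam, AdamW, etc.
    println!("{}: lr_decay = {}", optim.name(), copy.lr_decay);
}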
