Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/optim #272

Merged
merged 13 commits into from
Apr 5, 2023
Prev Previous commit
Next Next commit
Fix old checkpoints were not removed
  • Loading branch information
nathanielsimard committed Apr 4, 2023
commit 382c83e70fd15872f03926f6ad7cba31efbc0317
17 changes: 17 additions & 0 deletions burn-core/src/optim/adam.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ impl<B: Backend> SimpleOptimizer<B> for Adam<B> {

(tensor - delta, Some(state))
}

fn to_device<const D: usize>(
mut state: Self::State<D>,
device: &<B as Backend>::Device,
) -> Self::State<D> {
state.weight_decay = state.weight_decay.map(|state| state.to_device(device));
state.momentum = state.momentum.to_device(device);
state
}
}

impl AdamConfig {
Expand Down Expand Up @@ -153,6 +162,14 @@ impl AdaptiveMomentum {
}
}

impl<B: Backend, const D: usize> AdaptiveMomentumState<B, D> {
pub fn to_device(mut self, device: &B::Device) -> Self {
self.moment_1 = self.moment_1.to_device(device);
self.moment_2 = self.moment_2.to_device(device);
self
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
7 changes: 7 additions & 0 deletions burn-core/src/optim/decay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,10 @@ impl<B: Backend> WeightDecay<B> {
(grad, WeightDecayState::new(grad_last_step))
}
}

impl<B: Backend, const D: usize> WeightDecayState<B, D> {
pub fn to_device(mut self, device: &B::Device) -> Self {
self.grad_last_step = self.grad_last_step.to_device(device);
self
}
}
7 changes: 7 additions & 0 deletions burn-core/src/optim/momentum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,10 @@ impl<B: Backend> Momentum<B> {
(grad, MomemtumState::new(velocity))
}
}

impl<B: Backend, const D: usize> MomemtumState<B, D> {
pub fn to_device(mut self, device: &B::Device) -> Self {
self.velocity = self.velocity.to_device(device);
self
}
}
6 changes: 6 additions & 0 deletions burn-core/src/optim/sgd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ impl<B: Backend> SimpleOptimizer<B> for Sgd<B> {

(tensor - delta, Some(state))
}

fn to_device<const D: usize>(mut state: Self::State<D>, device: &B::Device) -> Self::State<D> {
state.weight_decay = state.weight_decay.map(|state| state.to_device(device));
state.momentum = state.momentum.map(|state| state.to_device(device));
state
}
}

#[cfg(test)]
Expand Down
5 changes: 4 additions & 1 deletion burn-core/src/optim/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ where
grad: Tensor<B, D>,
state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>);

fn to_device<const D: usize>(state: Self::State<D>, device: &B::Device) -> Self::State<D>;
}

pub struct SimpleModuleOptimizer<O, M, B>
Expand Down Expand Up @@ -103,11 +105,12 @@ where
let grad = self.grads.remove(id);

if let Some(grad) = grad {
let device = grad.device();
let (key, record) = self.records.remove_entry(id).unzip();
let (tensor, state) = self.optimizer.step(
tensor.inner(),
grad,
record.map(|record| record.into_state()),
record.map(|record| O::to_device(record.into_state(), &device)),
);

if let Some(state) = state {
Expand Down
47 changes: 34 additions & 13 deletions burn-core/src/record/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{fs::File, path::PathBuf};
pub trait FileRecorder:
Recorder<RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
{
fn file_extension() -> &'static str;
}

/// File recorder using the [bincode format](bincode).
Expand All @@ -26,13 +27,33 @@ pub struct FilePrettyJsonRecorder;
/// File recorder using the [message pack](rmp_serde) format compressed with gzip.
pub struct FileMpkGzRecorder;

impl FileRecorder for FileBinGzRecorder {}
impl FileRecorder for FileBinRecorder {}
impl FileRecorder for FileJsonGzRecorder {}
impl FileRecorder for FilePrettyJsonRecorder {}
impl FileRecorder for FileBinGzRecorder {
fn file_extension() -> &'static str {
"bin.gz"
}
}
impl FileRecorder for FileBinRecorder {
fn file_extension() -> &'static str {
"bin"
}
}
impl FileRecorder for FileJsonGzRecorder {
fn file_extension() -> &'static str {
"json.gz"
}
}
impl FileRecorder for FilePrettyJsonRecorder {
fn file_extension() -> &'static str {
"json"
}
}

#[cfg(feature = "msgpack")]
impl FileRecorder for FileMpkGzRecorder {}
impl FileRecorder for FileMpkGzRecorder {
fn file_extension() -> &'static str {
"mpk.gz"
}
}

macro_rules! str2reader {
(
Expand Down Expand Up @@ -79,7 +100,7 @@ impl Recorder for FileBinGzRecorder {
mut file: PathBuf,
) -> Result<(), RecorderError> {
let config = bin_config();
let writer = str2writer!(file, "bin.gz")?;
let writer = str2writer!(file, Self::file_extension())?;
let mut writer = GzEncoder::new(writer, Compression::default());

bincode::serde::encode_into_std_write(&obj, &mut writer, config)
Expand All @@ -89,7 +110,7 @@ impl Recorder for FileBinGzRecorder {
}

fn load<Obj: Serialize + DeserializeOwned>(mut file: PathBuf) -> Result<Obj, RecorderError> {
let reader = str2reader!(file, "bin.gz")?;
let reader = str2reader!(file, Self::file_extension())?;
let mut reader = GzDecoder::new(reader);
let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Expand Down Expand Up @@ -131,7 +152,7 @@ impl Recorder for FileJsonGzRecorder {
obj: Obj,
mut file: PathBuf,
) -> Result<(), RecorderError> {
let writer = str2writer!(file, "json.gz")?;
let writer = str2writer!(file, Self::file_extension())?;
let writer = GzEncoder::new(writer, Compression::default());
serde_json::to_writer(writer, &obj)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Expand All @@ -140,7 +161,7 @@ impl Recorder for FileJsonGzRecorder {
}

fn load<Obj: Serialize + DeserializeOwned>(mut file: PathBuf) -> Result<Obj, RecorderError> {
let reader = str2reader!(file, "json.gz")?;
let reader = str2reader!(file, Self::file_extension())?;
let reader = GzDecoder::new(reader);
let state = serde_json::from_reader(reader)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Expand All @@ -158,14 +179,14 @@ impl Recorder for FilePrettyJsonRecorder {
obj: Obj,
mut file: PathBuf,
) -> Result<(), RecorderError> {
let writer = str2writer!(file, "json")?;
let writer = str2writer!(file, Self::file_extension())?;
serde_json::to_writer_pretty(writer, &obj)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(())
}

fn load<Obj: Serialize + DeserializeOwned>(mut file: PathBuf) -> Result<Obj, RecorderError> {
let reader = str2reader!(file, "json")?;
let reader = str2reader!(file, Self::file_extension())?;
let state = serde_json::from_reader(reader)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;

Expand All @@ -183,7 +204,7 @@ impl Recorder for FileMpkGzRecorder {
obj: Obj,
mut file: PathBuf,
) -> Result<(), RecorderError> {
let writer = str2writer!(file, "mpk.gz")?;
let writer = str2writer!(file, Self::file_extension())?;
let mut writer = GzEncoder::new(writer, Compression::default());
rmp_serde::encode::write(&mut writer, &obj)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Expand All @@ -192,7 +213,7 @@ impl Recorder for FileMpkGzRecorder {
}

fn load<Obj: Serialize + DeserializeOwned>(mut file: PathBuf) -> Result<Obj, RecorderError> {
let reader = str2reader!(file, "mpk.gz")?;
let reader = str2reader!(file, Self::file_extension())?;
let reader = GzDecoder::new(reader);
let state = rmp_serde::decode::from_read(reader)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Expand Down
16 changes: 13 additions & 3 deletions burn-train/src/checkpoint/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,20 @@ where
}

let file_path_old_checkpoint = self.path_for_epoch(epoch - self.num_keep);
let file_to_remove = format!(
"{}.{}",
file_path_old_checkpoint,
<S::Recorder as FileRecorder>::file_extension()
);

if std::path::Path::new(&file_path_old_checkpoint).exists() {
log::info!("Removing checkpoint {}", file_path_old_checkpoint);
std::fs::remove_file(file_path_old_checkpoint).map_err(CheckpointerError::IOError)?;
match std::fs::remove_file(file_to_remove) {
Ok(_) => log::info!("Removed checkpoint {}", file_path_old_checkpoint),
Err(err) => {
match err.kind() {
std::io::ErrorKind::NotFound => (), // Ignoring missing old checkpoints,
_ => return Err(CheckpointerError::IOError(err)),
}
}
}

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub fn run<B: ADBackend>(device: B::Device) {
.metric_valid_plot(AccuracyMetric::new())
.metric_train_plot(LossMetric::new())
.metric_valid_plot(LossMetric::new())
.with_file_checkpointer::<DefaultRecordSettings>(2)
.with_file_checkpointer::<DefaultRecordSettings>(1)
.devices(vec![device])
.num_epochs(config.num_epochs)
.build(model, optim);
Expand Down