refactor: save and load state (tracel-ai#129)
nathanielsimard authored Dec 24, 2022
1 parent 3a91c2c commit 1a1d86d
Showing 6 changed files with 92 additions and 146 deletions.
burn-tensor/src/tensor/data.rs (1 addition, 1 deletion)
@@ -2,7 +2,7 @@ use super::ops::{Ones, Zeros};
 use crate::{tensor::Shape, Element, ElementConversion};
 use rand::{distributions::Standard, prelude::StdRng, Rng};
 
-#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq)]
+#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, Clone)]
 pub struct DataSerialize<P> {
     pub value: Vec<P>,
     pub shape: Vec<usize>,
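The only change here is the added `Clone` derive, which lets a `DataSerialize` (and therefore a whole `State`, which also gains `Clone` below) be duplicated before a consuming call such as `save`. A minimal round-trip sketch, assuming `DataSerialize` is reachable at the crate root and using only the `pub` fields shown above:

```rust
use burn_tensor::DataSerialize;

fn main() {
    // A 2x2 payload built directly from the pub fields shown in the diff.
    let data = DataSerialize::<f32> {
        value: vec![1.0, 2.0, 3.0, 4.0],
        shape: vec![2, 2],
    };

    let copy = data.clone(); // enabled by the new Clone derive
    let json = serde_json::to_string(&data).unwrap();
    let back: DataSerialize<f32> = serde_json::from_str(&json).unwrap();
    assert_eq!(copy, back);
}
```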
burn/src/module/param/id.rs (3 additions, 1 deletion)
@@ -1,4 +1,6 @@
-#[derive(Debug, Hash, PartialEq, Eq, Clone)]
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Hash, PartialEq, Eq, Clone, Serialize, Deserialize)]
 pub struct ParamId {
     value: String,
 }
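With `Serialize` and `Deserialize` derived, a `ParamId` can be embedded in `State` directly, instead of passing through the string conversions performed by the `serde_json::Value` code removed from burn/src/module/state.rs further down. A sketch of the round trip; `ParamId::from(&str)` is assumed from its use in that removed code, and the import path is illustrative:

```rust
use burn::module::ParamId;

fn main() {
    let id = ParamId::from("some-unique-id");
    // The private `value` field serializes as an ordinary JSON object field.
    let json = serde_json::to_string(&id).unwrap(); // {"value":"some-unique-id"}
    let back: ParamId = serde_json::from_str(&json).unwrap();
    assert_eq!(id, back);
}
```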
burn/src/module/param/mod.rs (2 additions, 0 deletions)
@@ -2,8 +2,10 @@ mod base;
 mod id;
 mod module;
 mod tensor;
+mod visitor;
 
 pub use base::*;
 pub use id::*;
 pub use module::*;
 pub use tensor::*;
+pub use visitor::*;
burn/src/module/param/visitor.rs (new file, 23 additions)
@@ -0,0 +1,23 @@
+use super::ParamId;
+use crate::module::{Module, ModuleVisitor};
+use burn_tensor::{backend::Backend, Tensor};
+
+#[derive(new)]
+struct ParamIdCollector<'a> {
+    param_ids: &'a mut Vec<ParamId>,
+}
+
+impl<'a, B: Backend> ModuleVisitor<B> for ParamIdCollector<'a> {
+    fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
+        self.param_ids.push(id.clone());
+    }
+}
+
+/// List all the parameter ids in a module.
+pub fn list_param_ids<M: Module>(module: &M) -> Vec<ParamId> {
+    let mut params_ids = Vec::new();
+    let mut visitor = ParamIdCollector::new(&mut params_ids);
+    module.visit(&mut visitor);
+
+    params_ids
+}
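`ModuleVisitor` walks every parameter of a module without knowing its structure, and `ParamIdCollector` is its smallest useful instance: one callback per parameter tensor. Any read-only traversal can reuse the same trait. As a hedged sketch under the trait shape shown above (the helper below is illustrative, not part of this commit), here is a visitor that counts parameter tensors instead of collecting ids:

```rust
use burn::module::{Module, ModuleVisitor, ParamId};
use burn_tensor::{backend::Backend, Tensor};

// Counts parameter tensors; mirrors ParamIdCollector but keeps no ids.
struct ParamCounter {
    count: usize,
}

impl<B: Backend> ModuleVisitor<B> for ParamCounter {
    fn visit<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D>) {
        self.count += 1;
    }
}

fn count_params<M: Module>(module: &M) -> usize {
    let mut visitor = ParamCounter { count: 0 };
    module.visit(&mut visitor);
    visitor.count
}
```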
burn/src/module/state.rs (58 additions, 121 deletions)
@@ -3,15 +3,15 @@ use crate::tensor::{DataSerialize, Element};
 use flate2::read::GzDecoder;
 use flate2::write::GzEncoder;
 use flate2::Compression;
-use std::collections::HashMap;
-use std::io::{Read, Write};
+use serde::{Deserialize, Serialize};
+use std::{collections::HashMap, fs::File, path::Path};
 
-#[derive(Debug, PartialEq, Eq, Default)]
+#[derive(Debug, PartialEq, Eq, Clone, Default, Serialize, Deserialize)]
 pub struct StateNamed<E> {
     pub values: HashMap<String, State<E>>,
 }
 
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
 pub enum State<E> {
     StateNamed(StateNamed<E>),
     Data(DataSerialize<E>),
@@ -40,90 +40,8 @@ impl std::fmt::Display for StateError {
         f.write_str(message.as_str())
     }
 }
+impl std::error::Error for StateError {}
 
-impl<E: Element> From<State<E>> for serde_json::Value
-where
-    E: serde::de::DeserializeOwned,
-    E: serde::Serialize,
-{
-    fn from(state: State<E>) -> serde_json::Value {
-        match state {
-            State::StateNamed(state) => state.into(),
-            State::Data(data) => serde_json::to_value(data).unwrap(),
-            State::ParamId(id) => serde_json::to_value(id.to_string()).unwrap(),
-        }
-    }
-}
-
-impl<E: Element> From<StateNamed<E>> for serde_json::Value
-where
-    E: serde::de::DeserializeOwned,
-    E: serde::Serialize,
-{
-    fn from(state: StateNamed<E>) -> serde_json::Value {
-        let mut map = serde_json::Map::new();
-
-        for (key, state) in state.values {
-            map.insert(key, state.into());
-        }
-
-        serde_json::Value::Object(map)
-    }
-}
-
-impl<E> TryFrom<serde_json::Value> for State<E>
-where
-    E: serde::de::DeserializeOwned,
-    E: serde::Serialize,
-{
-    type Error = StateError;
-
-    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
-        if let Ok(data) = serde_json::from_value(value.clone()) {
-            return Ok(State::Data(data));
-        };
-
-        if let Ok(state) = StateNamed::<E>::try_from(value.clone()) {
-            return Ok(State::StateNamed(state));
-        };
-
-        match serde_json::from_value::<String>(value.clone()) {
-            Ok(id) => Ok(State::ParamId(ParamId::from(id.as_str()))),
-            Err(_) => Err(StateError::InvalidFormat(format!(
-                "Invalid value {:?}",
-                value
-            ))),
-        }
-    }
-}
-
-impl<E> TryFrom<serde_json::Value> for StateNamed<E>
-where
-    E: serde::de::DeserializeOwned,
-    E: serde::Serialize,
-{
-    type Error = StateError;
-
-    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
-        let map = match value {
-            serde_json::Value::Object(map) => map,
-            _ => {
-                return Err(StateError::InvalidFormat(format!(
-                    "Invalid value {:?}",
-                    value
-                )))
-            }
-        };
-
-        let mut values = HashMap::new();
-        for (key, value) in map {
-            values.insert(key, State::try_from(value)?);
-        }
-
-        Ok(Self { values })
-    }
-}
-impl std::error::Error for StateError {}
 
 impl<E: Element> StateNamed<E> {
     pub fn new() -> Self {
@@ -188,65 +106,84 @@ where
     E: serde::Serialize,
 {
     pub fn save(self, file: &str) -> std::io::Result<()> {
-        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
-        let value: serde_json::Value = self.into();
+        let path = Path::new(file);
+        if path.exists() {
+            log::info!("File exists, replacing");
+            std::fs::remove_file(path).unwrap();
+        }
 
-        let content = value.to_string();
-        encoder.write_all(content.as_bytes()).unwrap();
-        let content_compressed = encoder.finish().unwrap();
+        let writer = File::create(path)?;
+        let writer = GzEncoder::new(writer, Compression::default());
+        serde_json::to_writer(writer, &self).unwrap();
 
-        std::fs::write(file, content_compressed)
+        Ok(())
     }
 
     pub fn load(file: &str) -> Result<Self, StateError> {
-        let content_compressed =
-            std::fs::read(file).map_err(|err| StateError::FileNotFound(format!("{:?}", err)))?;
+        let path = Path::new(file);
+        let reader =
+            File::open(path).map_err(|err| StateError::FileNotFound(format!("{:?}", err)))?;
+        let reader = GzDecoder::new(reader);
+        let state = serde_json::from_reader(reader).unwrap();
 
-        let mut decoder = GzDecoder::new(content_compressed.as_slice());
-        let mut content = String::new();
-        decoder.read_to_string(&mut content).unwrap();
-
-        let value: serde_json::Value = serde_json::from_str(&content)
-            .map_err(|err| StateError::InvalidFormat(format!("{:?}", err)))?;
-        Self::try_from(value)
+        Ok(state)
    }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::module::Module;
-    use crate::nn;
+    use crate::module::{list_param_ids, Module};
     use crate::tensor::backend::Backend;
+    use crate::{nn, TestBackend};
 
     #[test]
     fn test_state_to_from_value() {
-        let linear = nn::Linear::<crate::TestBackend>::new(&nn::LinearConfig {
-            d_input: 32,
-            d_output: 32,
-            bias: true,
-        });
+        let model = create_model();
+        let state = model.state();
+        let bytes = serde_json::to_vec(&state).unwrap();
 
-        let state = linear.state();
-        let value: serde_json::Value = state.into();
-        println!("{:?}", value);
         let state_from: State<<crate::TestBackend as Backend>::Elem> =
-            State::try_from(value.clone()).unwrap();
-        let value_from: serde_json::Value = state_from.into();
+            serde_json::from_slice(&bytes).unwrap();
 
-        assert_eq!(value, value_from);
         assert_eq!(state, state_from);
     }
 
     #[test]
     fn test_can_save_and_load_from_file() {
-        let mut linear = nn::Linear::<crate::TestBackend>::new(&nn::LinearConfig {
+        let model_before = create_model();
+        let state_before = model_before.state();
+        state_before.clone().save("/tmp/test.json").unwrap();
+
+        let mut model_after = create_model();
+        model_after
+            .load(&State::load("/tmp/test.json").unwrap())
+            .unwrap();
+
+        let state_after = model_after.state();
+        assert_eq!(state_before, state_after);
+    }
+
+    #[test]
+    fn test_parameter_ids_are_loaded() {
+        let model_1 = create_model();
+        let mut model_2 = create_model();
+        let params_before_1 = list_param_ids(&model_1);
+        let params_before_2 = list_param_ids(&model_2);
+
+        let state = model_1.state();
+        model_2.load(&state).unwrap();
+        let params_after_2 = list_param_ids(&model_2);
+
+        assert_ne!(params_before_1, params_before_2);
+        assert_eq!(params_before_1, params_after_2);
+    }
+
+    fn create_model() -> nn::Linear<TestBackend> {
+        nn::Linear::<crate::TestBackend>::new(&nn::LinearConfig {
             d_input: 32,
             d_output: 32,
             bias: true,
-        });
-        linear.state().save("/tmp/test.json").unwrap();
-        linear
-            .load(&State::load("/tmp/test.json").unwrap())
-            .unwrap();
+        })
     }
 }
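This file carries the core of the refactor: because `State` and `StateNamed` now derive `Serialize` and `Deserialize`, roughly 120 lines of hand-written `serde_json::Value` conversions collapse into `serde_json::to_writer` and `from_reader` calls streamed through the gzip codec, so the JSON never has to be buffered as one `String`. The same pattern in isolation, as a sketch for any serde type (the helper names are illustrative, not part of burn):

```rust
use std::{fs::File, path::Path};

use flate2::{read::GzDecoder, write::GzEncoder, Compression};
use serde::{de::DeserializeOwned, Serialize};

// Stream a serde value to gzip-compressed JSON without an intermediate String.
fn save_json_gz<T: Serialize>(value: &T, path: &Path) -> std::io::Result<()> {
    let file = File::create(path)?;
    let mut encoder = GzEncoder::new(file, Compression::default());
    serde_json::to_writer(&mut encoder, value)?;
    encoder.finish()?; // surface gzip trailer errors instead of relying on Drop
    Ok(())
}

// Decompress lazily while parsing; the full file never sits in memory at once.
fn load_json_gz<T: DeserializeOwned>(path: &Path) -> std::io::Result<T> {
    let file = File::open(path)?;
    Ok(serde_json::from_reader(GzDecoder::new(file))?)
}
```

One design note: `save` in the diff above hands the encoder to `serde_json::to_writer` by value and lets `Drop` finish the gzip stream, which silently discards any trailer error; calling `finish()` explicitly, as in the sketch, propagates it.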
burn/src/optim/visitor.rs (5 additions, 23 deletions)
@@ -38,11 +38,6 @@ pub struct GradientsParamsChangeDevice<'a, B: ADBackend> {
     grads: &'a mut GradientsParams<B>,
 }
 
-#[derive(new)]
-pub struct ParamIdCollector<'a> {
-    grads: &'a mut Vec<ParamId>,
-}
-
 impl<'a, B: ADBackend, O: Optimizer<Backend = B>> ModuleVisitor<B> for GradientsRegister<'a, B, O> {
     fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
         self.optimizer.register_param_state::<D>(id, self.state)
@@ -66,12 +61,6 @@ impl<'a, B: ADBackend, O: Optimizer<Backend = B>> ModuleVisitor<B> for Gradients
     }
 }
 
-impl<'a, B: ADBackend> ModuleVisitor<B> for ParamIdCollector<'a> {
-    fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
-        self.grads.push(id.clone());
-    }
-}
-
 impl<'a, B: ADBackend> ModuleVisitor<B> for GradientsParamsConverter<'a, B> {
     fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>) {
         if let Some(grad) = tensor.grad(&self.grads) {
@@ -114,37 +103,30 @@ pub fn convert_grads<M: ADModule>(
 mod tests {
     use super::*;
     use crate::{
-        module::Module,
+        module::{list_param_ids, Module},
         nn::{Linear, LinearConfig},
         TestADBackend,
     };
     use burn_tensor::{backend::Backend, Distribution};
 
     #[test]
-    fn test_convert_grads_to_params_id() {
+    fn test_convert_grads() {
         let layer_1 = layer();
         let mut layer_2 = layer_1.clone();
         layer_2.to_device(<TestADBackend as Backend>::Device::default());
         layer_2.detach();
         let loss_1 = layer_1.forward(random_tensor());
         let loss_2 = layer_2.forward(random_tensor());
-        let mut params_ids_1 = Vec::new();
-        let mut params_ids_2 = Vec::new();
-        let mut visitor_1 = ParamIdCollector::new(&mut params_ids_1);
-        let mut visitor_2 = ParamIdCollector::new(&mut params_ids_2);
         let grads_1 = loss_1.backward();
         let grads_2 = loss_2.backward();
 
-        layer_1.visit(&mut visitor_1);
-        layer_2.visit(&mut visitor_2);
-
         convert_grads(grads_1, &layer_1);
         convert_grads(grads_2, &layer_2);
 
-        layer_1.visit(&mut visitor_1);
-        layer_2.visit(&mut visitor_2);
+        let param_ids_1 = list_param_ids(&layer_1);
+        let params_ids_2 = list_param_ids(&layer_2);
 
-        assert_eq!(params_ids_1, params_ids_2);
+        assert_eq!(param_ids_1, params_ids_2);
     }
 
     fn layer() -> Linear<TestADBackend> {
