Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/optim #272

Merged
merged 13 commits into from
Apr 5, 2023
Prev Previous commit
Next Next commit
Remove state
  • Loading branch information
nathanielsimard committed Apr 5, 2023
commit 083ee957d9d05aa5618ff1b25df537882bdef78c
2 changes: 0 additions & 2 deletions burn-core/src/module/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
mod base;
mod param;
mod state;

pub use base::*;
pub use param::*;
pub use state::*;
105 changes: 0 additions & 105 deletions burn-core/src/module/state.rs

This file was deleted.

14 changes: 1 addition & 13 deletions burn-core/src/record/primitive.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{Record, RecordSettings};
use crate::module::{Param, ParamId, State};
use crate::module::{Param, ParamId};
use alloc::vec::Vec;
use burn_tensor::{DataSerialize, Element};
use hashbrown::HashMap;
Expand Down Expand Up @@ -72,18 +72,6 @@ impl<T: Record> Record for HashMap<ParamId, T> {
}
}

impl<T: Element> Record for State<T> {
type Item<S: RecordSettings> = State<S::FloatElem>;

fn into_item<S: RecordSettings>(self) -> Self::Item<S> {
self.convert::<S::FloatElem>()
}

fn from_item<S: RecordSettings>(item: Self::Item<S>) -> Self {
item.convert()
}
}

impl<E: Element> Record for DataSerialize<E> {
type Item<S: RecordSettings> = DataSerialize<S::FloatElem>;

Expand Down
6 changes: 1 addition & 5 deletions burn-train/src/checkpoint/base.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
use burn_core::{
module::StateError,
record::{Record, RecorderError},
};
use burn_core::record::{Record, RecorderError};

#[derive(Debug)]
pub enum CheckpointerError {
IOError(std::io::Error),
RecorderError(RecorderError),
StateError(StateError),
Unknown(String),
}

Expand Down