Skip to content

Commit

Permalink
Path type (tracel-ai#699)
Browse files Browse the repository at this point in the history
  • Loading branch information
asukaminato0721 authored Aug 27, 2023
1 parent 7f558bd commit 0f7864f
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion burn-book/src/basic-workflow/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Now let's create a simple `infer` method in which we will load our trained model
```rust , ignore
pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem) {
let config =
TrainingConfig::load(&format!("{artifact_dir}/config.json")).expect("A config exists");
TrainingConfig::load(format!("{artifact_dir}/config.json")).expect("A config exists");
let record = CompactRecorder::new()
.load(format!("{artifact_dir}/model").into())
.expect("Failed to load trained model");
Expand Down
2 changes: 1 addition & 1 deletion burn-book/src/basic-workflow/training.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ pub struct TrainingConfig {
pub fn train<B: ADBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
std::fs::create_dir_all(artifact_dir).ok();
config
.save(&format!("{artifact_dir}/config.json"))
.save(format!("{artifact_dir}/config.json"))
.expect("Save without error");

B::seed(config.seed);
Expand Down
8 changes: 4 additions & 4 deletions burn-core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
///
/// The output of the save operation.
#[cfg(feature = "std")]
fn save(&self, file: &str) -> std::io::Result<()> {
fn save<P: AsRef<std::path::Path>>(&self, file: P) -> std::io::Result<()> {
std::fs::write(file, config_to_json(self))
}

Expand All @@ -58,9 +58,9 @@ pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
///
/// The loaded configuration.
#[cfg(feature = "std")]
fn load(file: &str) -> Result<Self, ConfigError> {
let content = std::fs::read_to_string(file)
.map_err(|_| ConfigError::FileNotFound(file.to_string()))?;
fn load<P: AsRef<std::path::Path>>(file: P) -> Result<Self, ConfigError> {
let content = std::fs::read_to_string(file.as_ref().clone())
.map_err(|_| ConfigError::FileNotFound(file.as_ref().to_string_lossy().to_string()))?;
config_from_str(&content)
}

Expand Down
2 changes: 1 addition & 1 deletion examples/guide/src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use burn_dataset::source::huggingface::MNISTItem;

pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem) {
let config =
TrainingConfig::load(&format!("{artifact_dir}/config.json")).expect("A config exists");
TrainingConfig::load(format!("{artifact_dir}/config.json")).expect("A config exists");
let record = CompactRecorder::new()
.load(format!("{artifact_dir}/model").into())
.expect("Failed to load trained model");
Expand Down
2 changes: 1 addition & 1 deletion examples/guide/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub struct TrainingConfig {
pub fn train<B: ADBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
std::fs::create_dir_all(artifact_dir).ok();
config
.save(&format!("{artifact_dir}/config.json"))
.save(format!("{artifact_dir}/config.json"))
.expect("Save without error");

B::seed(config.seed);
Expand Down
2 changes: 1 addition & 1 deletion examples/text-classification/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
let model_trained = learner.fit(dataloader_train, dataloader_test);

// Save the configuration and the trained model
config.save(&format!("{artifact_dir}/config.json")).unwrap();
config.save(format!("{artifact_dir}/config.json")).unwrap();
CompactRecorder::new()
.record(
model_trained.into_record(),
Expand Down
2 changes: 1 addition & 1 deletion examples/text-generation/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(

let model_trained = learner.fit(dataloader_train, dataloader_test);

config.save(&format!("{artifact_dir}/config.json")).unwrap();
config.save(format!("{artifact_dir}/config.json")).unwrap();

DefaultRecorder::new()
.record(
Expand Down

0 comments on commit 0f7864f

Please sign in to comment.