Skip to content

Commit

Permalink
Performance Enhancements (argmin-rs#112)
Browse files Browse the repository at this point in the history
* feat: make timer optional. Optimise perf when no observers or checkpoints

* feat: ArgminCheckpoint.store filename param changed to &str for flexibility.

* perf: ArgminCheckpoint filename memoized inside ArgminCheckpoint for performance
  • Loading branch information
sdd authored Apr 29, 2021
1 parent 6deb382 commit 16552cd
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 23 deletions.
62 changes: 48 additions & 14 deletions src/core/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ pub struct Executor<O: ArgminOp, S> {
checkpoint: ArgminCheckpoint,
/// Indicates whether Ctrl-C functionality should be active or not
ctrlc: bool,
/// Indicates whether to time execution or not
timer: bool,
}

impl<O, S> Executor<O, S>
Expand All @@ -54,6 +56,7 @@ where
observers: Observer::new(),
checkpoint: ArgminCheckpoint::default(),
ctrlc: true,
timer: true,
}
}

Expand Down Expand Up @@ -104,7 +107,11 @@ where

/// Run the executor
pub fn run(mut self) -> Result<ArgminResult<O>, Error> {
let total_time = instant::Instant::now();
let total_time = if self.timer {
Some(instant::Instant::now())
} else {
None
};

let running = Arc::new(AtomicBool::new(true));

Expand Down Expand Up @@ -132,16 +139,21 @@ where
// let mut op_wrapper = OpWrapper::new(&self.op);
let init_data = self.solver.init(&mut self.op, &self.state)?;

let mut logs = make_kv!("max_iters" => self.state.get_max_iters(););

// If init() returned something, deal with it
if let Some(data) = init_data {
if let Some(data) = &init_data {
self.update(&data)?;
logs = logs.merge(&mut data.get_kv());
}

// Observe after init
self.observers.observe_init(S::NAME, &logs)?;
if !self.observers.is_empty() {
let mut logs = make_kv!("max_iters" => self.state.get_max_iters(););

if let Some(data) = init_data {
logs = logs.merge(&mut data.get_kv());
}

// Observe after init
self.observers.observe_init(S::NAME, &logs)?;
}

self.state.set_func_counts(&self.op);

Expand All @@ -162,29 +174,45 @@ where
}

// Start time measurement
let start = instant::Instant::now();
let start = if self.timer {
Some(instant::Instant::now())
} else {
None
};

let data = self.solver.next_iter(&mut self.op, &self.state)?;

self.state.set_func_counts(&self.op);

// End time measurement
let duration = start.elapsed();
let duration = if self.timer {
Some(start.unwrap().elapsed())
} else {
None
};

self.update(&data)?;

let log = data.get_kv().merge(&mut make_kv!(
"time" => duration.as_secs() as f64 + f64::from(duration.subsec_nanos()) * 1e-9;
));
if !self.observers.is_empty() {
let mut log = data.get_kv();

self.observers.observe_iter(&self.state, &log)?;
if self.timer {
let duration = duration.unwrap();
log = log.merge(&mut make_kv!(
"time" => duration.as_secs() as f64 + f64::from(duration.subsec_nanos()) * 1e-9;
));
}
self.observers.observe_iter(&self.state, &log)?;
}

// increment iteration number
self.state.increment_iter();

self.checkpoint.store_cond(&self, self.state.get_iter())?;

self.state.time(total_time.elapsed());
if self.timer {
total_time.map(|total_time| self.state.time(Some(total_time.elapsed())));
}

// Check if termination occured inside next_iter()
if self.state.terminated() {
Expand Down Expand Up @@ -270,4 +298,10 @@ where
self.ctrlc = ctrlc;
self
}

/// Turn timer on or off (default: on)
pub fn timer(mut self, timer: bool) -> Self {
self.timer = timer;
self
}
}
8 changes: 4 additions & 4 deletions src/core/iterstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub struct IterState<O: ArgminOp> {
/// Number of modify evaluations so far
pub modify_func_count: u64,
/// Time required so far
pub time: instant::Duration,
pub time: Option<instant::Duration>,
/// Reason of termination
pub termination_reason: TerminationReason,
}
Expand Down Expand Up @@ -137,7 +137,7 @@ impl<O: ArgminOp> IterState<O> {
hessian_func_count: 0,
jacobian_func_count: 0,
modify_func_count: 0,
time: instant::Duration::new(0, 0),
time: Some(instant::Duration::new(0, 0)),
termination_reason: TerminationReason::NotTerminated,
}
}
Expand Down Expand Up @@ -213,7 +213,7 @@ impl<O: ArgminOp> IterState<O> {
TerminationReason,
"Set termination_reason"
);
setter!(time, instant::Duration, "Set time required so far");
setter!(time, Option<instant::Duration>, "Set time required so far");
getter!(param, O::Param, "Returns current parameter vector");
getter!(prev_param, O::Param, "Returns previous parameter vector");
getter!(best_param, O::Param, "Returns best parameter vector");
Expand Down Expand Up @@ -270,7 +270,7 @@ impl<O: ArgminOp> IterState<O> {
TerminationReason,
"Get termination_reason"
);
getter!(time, instant::Duration, "Get time required so far");
getter!(time, Option<instant::Duration>, "Get time required so far");
getter_option!(grad, O::Param, "Returns gradient");
getter_option!(prev_grad, O::Param, "Returns previous gradient");
getter_option!(hessian, O::Hessian, "Returns current Hessian");
Expand Down
5 changes: 5 additions & 0 deletions src/core/observers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ impl<O: ArgminOp> Observer<O> {
self.observers.push((Arc::new(Mutex::new(observer)), mode));
self
}

/// Returns true if `observers` is empty
pub fn is_empty(&self) -> bool {
self.observers.is_empty()
}
}

/// By implementing `Observe` for `Observer` we basically allow a set of `Observer`s to be used
Expand Down
15 changes: 10 additions & 5 deletions src/core/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pub struct ArgminCheckpoint {
mode: CheckpointMode,
directory: String,
name: String,
filename: String,
}

impl Default for ArgminCheckpoint {
Expand All @@ -57,6 +58,7 @@ impl Default for ArgminCheckpoint {
mode: CheckpointMode::Never,
directory: ".checkpoints".to_string(),
name: "default".to_string(),
filename: "default.arg".to_string(),
}
}
}
Expand All @@ -71,11 +73,13 @@ impl ArgminCheckpoint {
_ => {}
}
let name = "solver".to_string();
let filename = "solver.arg".to_string();
let directory = directory.to_string();
Ok(ArgminCheckpoint {
mode,
directory,
name,
filename,
})
}

Expand All @@ -95,6 +99,9 @@ impl ArgminCheckpoint {
#[inline]
pub fn set_name(&mut self, name: &str) {
self.name = name.to_string();
let mut filename = self.name();
filename.push_str(".arg");
self.filename = filename;
}

/// Get name of checkpoint
Expand All @@ -111,7 +118,7 @@ impl ArgminCheckpoint {

/// Write checkpoint to disk
#[inline]
pub fn store<T: Serialize>(&self, executor: &T, filename: String) -> Result<(), Error> {
pub fn store<T: Serialize>(&self, executor: &T, filename: &str) -> Result<(), Error> {
let dir = Path::new(&self.directory);
if !dir.exists() {
std::fs::create_dir_all(&dir)?
Expand All @@ -126,11 +133,9 @@ impl ArgminCheckpoint {
/// Write checkpoint based on the desired `CheckpointMode`
#[inline]
pub fn store_cond<T: Serialize>(&self, executor: &T, iter: u64) -> Result<(), Error> {
let mut filename = self.name();
filename.push_str(".arg");
match self.mode {
CheckpointMode::Always => self.store(executor, filename)?,
CheckpointMode::Every(it) if iter % it == 0 => self.store(executor, filename)?,
CheckpointMode::Always => self.store(executor, &self.filename)?,
CheckpointMode::Every(it) if iter % it == 0 => self.store(executor, &self.filename)?,
CheckpointMode::Never | CheckpointMode::Every(_) => {}
};
Ok(())
Expand Down

0 comments on commit 16552cd

Please sign in to comment.