diff --git a/src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py b/src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py index b2bfd29cd2..2b0eeac26d 100644 --- a/src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py +++ b/src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py @@ -38,6 +38,7 @@ _next_parameter_id = 0 _KEY = 'STEPS' +_epsilon = 1e-6 @unique class OptimizeMode(Enum): @@ -141,8 +142,8 @@ def __init__(self, s, s_max, eta, R, optimize_mode): self.bracket_id = s self.s_max = s_max self.eta = eta - self.n = math.ceil((s_max + 1) * (eta**s) / (s + 1)) # pylint: disable=invalid-name - self.r = math.ceil(R / eta**s) # pylint: disable=invalid-name + self.n = math.ceil((s_max + 1) * (eta**s) / (s + 1) - _epsilon) # pylint: disable=invalid-name + self.r = math.ceil(R / eta**s - _epsilon) # pylint: disable=invalid-name self.i = 0 self.hyper_configs = [] # [ {id: params}, {}, ... ] self.configs_perf = [] # [ {id: [seq, acc]}, {}, ... ] @@ -157,7 +158,7 @@ def is_completed(self): def get_n_r(self): """return the values of n and r for the next round""" - return math.floor(self.n / self.eta**self.i), self.r * self.eta**self.i + return math.floor(self.n / self.eta**self.i + _epsilon), self.r * self.eta**self.i def increase_i(self): """i means the ith round. Increase i by 1""" @@ -305,7 +306,7 @@ def __init__(self, R, eta=3, optimize_mode='maximize'): self.brackets = dict() # dict of Bracket self.generated_hyper_configs = [] # all the configs waiting for run self.completed_hyper_configs = [] # all the completed configs - self.s_max = math.floor(math.log(self.R, self.eta)) + self.s_max = math.floor(math.log(self.R, self.eta) + _epsilon) self.curr_s = self.s_max self.searchspace_json = None