-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathrm_es.py
338 lines (303 loc) · 10 KB
/
rm_es.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
from typing import Tuple, Optional, Union
import jax
import jax.numpy as jnp
import chex
from flax import struct
from ..strategy import Strategy
@struct.dataclass
class EvoState:
    """Search state of the Rank-m ES that is threaded through `ask` and `tell`."""

    # Evolution path vector of length num_dims; starts at zero.
    p_sigma: chex.Array
    # Mean of the search distribution (centre of the sampled population).
    mean: chex.Array
    # Global step size multiplying every sampled perturbation.
    sigma: float
    # Matrix of shape (num_dims, memory_size); each column is a stored
    # evolution path.
    P: chex.Array
    # Per-column generation counter at which the matching column of P was
    # recorded.
    t_gap: chex.Array
    # Cumulative rank rate driving the step-size update.
    s_rank_rate: float
    # Fitness values of the previous generation, compared against the current
    # ones by the rank-based success rule.
    fitness_archive: chex.Array
    # Recombination weights applied to the elite members.
    weights: chex.Array
    # Best solution found so far; presumably maintained by the base strategy
    # class - TODO confirm.
    best_member: chex.Array
    # Defaults to the largest float32 so that any real fitness value is lower.
    best_fitness: float = jnp.finfo(jnp.float32).max
    # Generation counter; presumably incremented by the base strategy class -
    # TODO confirm.
    gen_counter: int = 0
@struct.dataclass
class EvoParams:
    """Hyperparameters of the Rank-m ES."""

    # Mixing rate between isotropic noise and the stored evolution paths when
    # sampling.
    c_cov: float
    # NOTE(review): set by the defaults but not read by any code in this
    # module - confirm whether it is still needed.
    c_c: float
    # Decay/learning rate of the evolution path update.
    c_sigma: float
    # Variance-effective selection mass, 1 / sum(weights ** 2).
    mu_eff: float
    # Learning rate of the mean update (1.0 replaces the mean entirely).
    c_m: float = 1.0
    # Initial step size.
    sigma_init: float = 1.0
    # Lower bound enforced on the step size after every update.
    sigma_limit: float = 0.001
    # Generation-gap threshold above which stored paths count as uncorrelated.
    t_uncorr: int = 20
    # Target value subtracted from the rank difference in the success rule.
    q_star: float = 0.325
    # Smoothing rate of the cumulative rank rate.
    c_s: float = 0.3
    # Damping of the step-size update.
    d_sigma: float = 1.0
    # Bounds of the uniform distribution the initial mean is drawn from.
    init_min: float = 0.0
    init_max: float = 0.0
    # Not read in this module; presumably clipping bounds applied by the base
    # strategy class - TODO confirm.
    clip_min: float = -jnp.finfo(jnp.float32).max
    clip_max: float = jnp.finfo(jnp.float32).max
def get_elite_weights(elite_popsize: int) -> chex.Array:
    """Create the truncated, log-rank-based elite weights for the mean update.

    With ``mu = elite_popsize`` the weight of the elite member with zero-based
    rank ``i`` is::

        (log(mu + 1) - log(i + 1)) / (mu * log(mu + 1) - sum_{k=1..mu} log(k))

    The weights are positive, strictly decreasing in rank and sum to one.

    Args:
        elite_popsize: Number of elite members; expected to be at least 1.

    Returns:
        A single 1-D array of shape ``(elite_popsize,)`` (not a tuple).
    """
    log_mu_plus_one = jnp.log(elite_popsize + 1)
    # log(1), ..., log(mu): one vectorised op instead of a Python-level loop.
    log_ranks = jnp.log(jnp.arange(1, elite_popsize + 1))
    normalizer = elite_popsize * log_mu_plus_one - jnp.sum(log_ranks)
    return (log_mu_plus_one - log_ranks) / normalizer
class RmES(Strategy):
    """Rank-m evolution strategy exposed through the `Strategy` interface."""

    def __init__(
        self,
        popsize: int,
        num_dims: Optional[int] = None,
        pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
        elite_ratio: float = 0.5,
        memory_size: int = 10,
        sigma_init: float = 1.0,
        mean_decay: float = 0.0,
        n_devices: Optional[int] = None,
        **fitness_kwargs: Union[bool, int, float]
    ):
        """Rank-m ES (Li & Zhang, 2017).

        Reference: https://ieeexplore.ieee.org/document/8080257

        Args:
            popsize: Number of candidates sampled per generation.
            num_dims: Search-space dimensionality (forwarded to the base class).
            pholder_params: Placeholder parameters (forwarded to the base class).
            elite_ratio: Fraction of the population used as elite, in [0, 1].
            memory_size: Number of evolution paths kept in memory (the "m").
            sigma_init: Initial step size.
            mean_decay: Forwarded to the base class.
            n_devices: Forwarded to the base class.
            **fitness_kwargs: Forwarded to the base class.
        """
        super().__init__(
            popsize,
            num_dims,
            pholder_params,
            mean_decay,
            n_devices,
            **fitness_kwargs
        )
        assert 0 <= elite_ratio <= 1
        self.elite_ratio = elite_ratio
        # Always keep at least one elite member, even for tiny ratios.
        self.elite_popsize = max(1, int(self.popsize * self.elite_ratio))
        self.memory_size = memory_size  # number of stored paths (ranks)
        self.strategy_name = "RmES"

        # Core setting that is copied into the default EvoParams.
        self.sigma_init = sigma_init

    @property
    def params_strategy(self) -> EvoParams:
        """Build the default hyperparameters for the current problem size."""
        elite_weights = get_elite_weights(self.elite_popsize)
        return EvoParams(
            c_cov=1 / (3 * jnp.sqrt(self.num_dims) + 5),
            c_c=2 / (self.num_dims + 7),
            c_sigma=jnp.minimum(2 / (self.num_dims + 7), 0.05),
            mu_eff=1 / jnp.sum(elite_weights ** 2),
            sigma_init=self.sigma_init,
        )

    def initialize_strategy(
        self, rng: chex.PRNGKey, params: EvoParams
    ) -> EvoState:
        """Create the initial search state.

        The mean is drawn uniformly from [init_min, init_max); the evolution
        path, the path memory and the cumulative rank rate all start at zero.
        """
        start_mean = jax.random.uniform(
            rng,
            (self.num_dims,),
            minval=params.init_min,
            maxval=params.init_max,
        )
        return EvoState(
            p_sigma=jnp.zeros(self.num_dims),
            sigma=params.sigma_init,
            mean=start_mean,
            P=jnp.zeros((self.num_dims, self.memory_size)),
            t_gap=jnp.zeros(self.memory_size),
            s_rank_rate=0.0,
            weights=get_elite_weights(self.elite_popsize),
            # Huge placeholder fitness so the first real generation outranks
            # the (non-existent) previous one in the rank-based success rule.
            fitness_archive=jnp.zeros(self.popsize) + 1e20,
            best_member=start_mean,
        )

    def ask_strategy(
        self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
    ) -> Tuple[chex.Array, EvoState]:
        """Sample a new population of candidates; the state is returned as is."""
        candidates = sample(
            rng,
            state.mean,
            state.sigma,
            state.P,
            self.num_dims,
            self.popsize,
            params.c_cov,
            state.gen_counter,
        )
        return candidates, state

    def tell_strategy(
        self,
        x: chex.Array,
        fitness: chex.Array,
        state: EvoState,
        params: EvoParams,
    ) -> EvoState:
        """Update the search state from the evaluated population.

        Args:
            x: Evaluated candidates, one row per population member.
            fitness: Fitness value per candidate (lower ranks first).
            state: State from before this generation.
            params: Strategy hyperparameters.

        Returns:
            The state with updated mean, evolution path, path memory, step
            size, cumulative rank rate and fitness archive.
        """
        # Column 0 holds the fitness, the remaining columns the candidate.
        scored = jnp.hstack([jnp.expand_dims(fitness, 1), x])
        ranking = scored[:, 0].argsort()
        elite = scored[ranking][: self.elite_popsize]

        new_mean = update_mean(state.mean, elite, params.c_m, state.weights)
        new_path = update_p_sigma(
            state.mean,
            new_mean,
            state.sigma,
            state.p_sigma,
            params.c_sigma,
            params.mu_eff,
        )
        # NOTE(review): the path stored in memory is the one from before this
        # generation's update (state.p_sigma), not new_path - confirm intended.
        new_P, new_t_gap = update_P_matrix(
            state.P,
            state.p_sigma,
            state.t_gap,
            params.t_uncorr,
            state.gen_counter,
        )
        new_rank_rate = rank_success_rule(
            fitness,
            state.fitness_archive,
            state.s_rank_rate,
            params.q_star,
            state.weights,
            params.c_s,
        )
        # Adapt the step size, but never let it drop below the configured floor.
        new_sigma = jnp.maximum(
            update_sigma(state.sigma, new_rank_rate, params.d_sigma),
            params.sigma_limit,
        )
        return state.replace(
            mean=new_mean,
            p_sigma=new_path,
            P=new_P,
            t_gap=new_t_gap,
            sigma=new_sigma,
            s_rank_rate=new_rank_rate,
            fitness_archive=fitness,
        )
def update_mean(
    mean: chex.Array,
    sorted_solutions: chex.Array,
    c_m: float,
    weights: chex.Array,
) -> chex.Array:
    """Move the mean towards the weighted recombination of the elite.

    Args:
        mean: Current mean of the search distribution.
        sorted_solutions: Elite rows sorted by fitness; column 0 is the
            fitness and the remaining columns are the candidate itself.
        c_m: Learning rate; 1.0 replaces the mean by the recombination.
        weights: One recombination weight per elite row.

    Returns:
        The updated mean.
    """
    elite_members = sorted_solutions[:, 1:]  # strip the fitness column
    recombination = jnp.sum(elite_members.T * weights, axis=1)
    return (1 - c_m) * mean + c_m * recombination
def update_p_sigma(
    mean_old: chex.Array,
    mean: chex.Array,
    sigma: float,
    p_sigma: chex.Array,
    c_sigma: float,
    mu_eff: float,
) -> chex.Array:
    """Update the evolution path with the latest step of the mean.

    The old path is decayed by ``1 - c_sigma`` and the mean displacement,
    normalised by the step size, is added with gain
    ``sqrt(c_sigma * (2 - c_sigma) * mu_eff)``.

    Args:
        mean_old: Mean before the update.
        mean: Mean after the update.
        sigma: Step size used for the generation that produced ``mean``.
        p_sigma: Previous evolution path.
        c_sigma: Learning rate of the path.
        mu_eff: Variance-effective selection mass.

    Returns:
        The new evolution path.
    """
    gain = jnp.sqrt(c_sigma * (2 - c_sigma) * mu_eff)
    scaled_step = gain * (mean - mean_old) / sigma
    return (1 - c_sigma) * p_sigma + scaled_step
def update_P_matrix(
    P: chex.Array,
    p_sigma: chex.Array,
    t_gap: chex.Array,
    t_uncorr: int,
    gen_counter: int,
) -> Tuple[chex.Array, chex.Array]:
    """Update the P matrix storing m evolution paths.

    One stored path is dropped, the paths to its right move one slot to the
    left and ``p_sigma`` is written into the last slot:

    * While the memory is still filling up (``gen_counter < m``), or when all
      neighbouring paths are more than ``t_uncorr`` generations apart, the
      oldest path (slot 0) is dropped.
    * Otherwise the earlier path of the closest pair of neighbours is dropped,
      so that the remaining paths stay as far apart in time as possible.

    Args:
        P: Matrix of shape ``(num_dims, m)``; one stored path per column.
        p_sigma: Evolution path to store in the last column.
        t_gap: Vector of length ``m`` with the generation counter at which
            each column of ``P`` was recorded.
        t_uncorr: Gap (in generations) above which paths count as uncorrelated.
        gen_counter: Current generation counter, stored alongside ``p_sigma``.

    Returns:
        The updated ``(P, t_gap)`` pair.
    """
    memory_size = P.shape[1]
    last = memory_size - 1

    # With a single slot there is no neighbouring pair to compare (and the
    # minimum over an empty array is undefined): just overwrite the slot.
    if memory_size > 1:
        gaps = t_gap[1:] - t_gap[:-1]
        all_uncorrelated = jnp.min(gaps) > t_uncorr
        still_filling = gen_counter < memory_size
        drop_oldest = jnp.logical_or(all_uncorrelated, still_filling)

        # Use the same forward differences as the minimum-gap test above, so
        # the selected slot really belongs to the closest pair of paths.
        closest_pair = jnp.argmin(gaps)
        drop_idx = jnp.where(drop_oldest, 0, closest_pair)

        # Gather formulation of "shift everything right of drop_idx one slot
        # to the left"; works with a traced (dynamic) drop_idx. The last slot
        # keeps its own value here and is overwritten right below.
        slots = jnp.arange(memory_size)
        source = jnp.where(
            slots >= drop_idx, jnp.minimum(slots + 1, last), slots
        )
        P = P[:, source]
        t_gap = t_gap[source]

    # Finally store the most recent evolution path and its time stamp.
    P = P.at[:, last].set(p_sigma)
    t_gap = t_gap.at[last].set(gen_counter)
    return P, t_gap
def update_sigma(sigma: float, s_rank_rate: float, d_sigma: float) -> float:
    """Scale the step size by ``exp(s_rank_rate / d_sigma)``.

    A positive cumulative rank rate grows the step size, a negative one
    shrinks it; ``d_sigma`` damps the change.
    """
    return sigma * jnp.exp(s_rank_rate / d_sigma)
def sample(
    rng: chex.PRNGKey,
    mean: chex.Array,
    sigma: float,
    P: chex.Array,
    n_dim: int,
    pop_size: int,
    c_cov: chex.Array,
    gen_counter: int,
) -> chex.Array:
    """Jittable Gaussian sample helper.

    Starts from isotropic noise and, for every stored path ``j`` with
    ``gen_counter > j``, mixes it as
    ``sqrt(1 - c_cov) * z + sqrt(c_cov) * P[:, j] * r[:, j]``.

    Args:
        rng: PRNG key; split internally so both noise draws are independent.
        mean: Mean of the search distribution, shape ``(n_dim,)``.
        sigma: Step size.
        P: Stored evolution paths, shape ``(n_dim, m)``.
        n_dim: Search-space dimensionality.
        pop_size: Number of candidates to draw.
        c_cov: Mixing rate between isotropic noise and the stored paths.
        gen_counter: Current generation counter.

    Returns:
        Array of shape ``(pop_size, n_dim)`` with one candidate per row.
    """
    # Never feed the same key to two draws: the resulting samples are not
    # independent (they are identical when pop_size equals the memory size).
    rng_z, rng_r = jax.random.split(rng)
    memory_size = P.shape[1]
    z = jax.random.normal(rng_z, (n_dim, pop_size))  # ~ N(0, I)
    # NOTE(review): r is drawn per dimension and shared by all population
    # members - confirm this matches the reference algorithm.
    r = jax.random.normal(rng_r, (n_dim, memory_size))

    # Loop-invariant mixing factors.
    keep_scale = jnp.sqrt(1 - c_cov)
    path_scale = jnp.sqrt(c_cov)
    for j in range(memory_size):
        path_term = (path_scale * P[:, j] * r[:, j])[:, jnp.newaxis]
        mixed = keep_scale * z + path_term
        # jnp.where also accepts a plain Python bool when gen_counter is not
        # a traced value.
        z = jnp.where(gen_counter > j, mixed, z)

    return mean + sigma * jnp.swapaxes(z, 1, 0)  # ~ N(m, sigma^2 C)
def rank_success_rule(
    fitness: chex.Array,
    fitness_archive: chex.Array,
    s_rank_rate: float,
    q_star: float,
    weights: chex.Array,
    c_s: float,
) -> float:
    """Compute the rank-based success rule (cumulative rank rate).

    Current and previous fitness values are ranked jointly (ascending, so
    minimisation is assumed). The weighted rank advantage of the current
    elite over the previous elite is compared with the target ``q_star`` and
    folded into an exponentially decaying average.

    Args:
        fitness: Fitness values of the current generation.
        fitness_archive: Fitness values of the previous generation.
        s_rank_rate: Previous cumulative rank rate.
        q_star: Target rank-difference ratio.
        weights: Elite weights; their length defines the elite size.
        c_s: Smoothing rate of the cumulative rank rate.

    Returns:
        The updated cumulative rank rate.
    """
    popsize = fitness.shape[0]
    elite_popsize = weights.shape[0]

    # Joint ranking: position k of the sorted pool receives rank k.
    pooled = jnp.concatenate([fitness, fitness_archive])
    order = jnp.argsort(pooled)
    ranks = jnp.zeros(pooled.shape[0]).at[order].set(jnp.arange(2 * popsize))

    # Best (lowest) joint ranks reached by each generation's elite.
    elite_now = jnp.sort(ranks[:popsize])[:elite_popsize]
    elite_prev = jnp.sort(ranks[popsize:])[:elite_popsize]

    # Positive when the current generation ranks ahead of the previous one.
    q = 1 / elite_popsize * jnp.sum(weights * (elite_prev - elite_now))
    return (1 - c_s) * s_rank_rate + c_s * (q - q_star)