# ars.py
from typing import Tuple, Optional, Union
import jax
import jax.numpy as jnp
import chex
from flax import struct
from ..strategy import Strategy
from ..core import GradientOptimizer, OptState, OptParams, exp_decay


@struct.dataclass
class EvoState:
    mean: chex.Array
    sigma: float
    opt_state: OptState
    best_member: chex.Array
    best_fitness: float = jnp.finfo(jnp.float32).max
    gen_counter: int = 0


@struct.dataclass
class EvoParams:
    opt_params: OptParams
    sigma_init: float = 0.03
    sigma_decay: float = 0.999
    sigma_limit: float = 0.01
    init_min: float = 0.0
    init_max: float = 0.0
    clip_min: float = -jnp.finfo(jnp.float32).max
    clip_max: float = jnp.finfo(jnp.float32).max


class ARS(Strategy):
    def __init__(
        self,
        popsize: int,
        num_dims: Optional[int] = None,
        pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
        elite_ratio: float = 0.1,
        opt_name: str = "adam",
        lrate_init: float = 0.05,
        lrate_decay: float = 1.0,
        lrate_limit: float = 0.001,
        sigma_init: float = 0.03,
        sigma_decay: float = 1.0,
        sigma_limit: float = 0.01,
        mean_decay: float = 0.0,
        n_devices: Optional[int] = None,
        **fitness_kwargs: Union[bool, int, float]
    ):
        """Augmented Random Search (Mania et al., 2018)
        Reference: https://arxiv.org/pdf/1803.07055.pdf"""
        super().__init__(
            popsize,
            num_dims,
            pholder_params,
            mean_decay,
            n_devices,
            **fitness_kwargs
        )
        assert not self.popsize & 1, "Population size must be even"
        # ARS performs antithetic sampling & allows you to select
        # "b" elite perturbation directions for the update
        assert 0 <= elite_ratio <= 1
        self.elite_ratio = elite_ratio
        self.elite_popsize = max(1, int(self.popsize / 2 * self.elite_ratio))
        assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"]
        self.optimizer = GradientOptimizer[opt_name](self.num_dims)
        self.strategy_name = "ARS"

        # Set core kwargs for es_params (lrate/sigma schedules)
        self.lrate_init = lrate_init
        self.lrate_decay = lrate_decay
        self.lrate_limit = lrate_limit
        self.sigma_init = sigma_init
        self.sigma_decay = sigma_decay
        self.sigma_limit = sigma_limit

    @property
    def params_strategy(self) -> EvoParams:
        """Return default parameters of evolution strategy."""
        opt_params = self.optimizer.default_params.replace(
            lrate_init=self.lrate_init,
            lrate_decay=self.lrate_decay,
            lrate_limit=self.lrate_limit,
        )
        return EvoParams(
            opt_params=opt_params,
            sigma_init=self.sigma_init,
            sigma_decay=self.sigma_decay,
            sigma_limit=self.sigma_limit,
        )

    def initialize_strategy(
        self, rng: chex.PRNGKey, params: EvoParams
    ) -> EvoState:
        """`initialize` the evolution strategy."""
        initialization = jax.random.uniform(
            rng,
            (self.num_dims,),
            minval=params.init_min,
            maxval=params.init_max,
        )
        state = EvoState(
            mean=initialization,
            sigma=params.sigma_init,
            opt_state=self.optimizer.initialize(params.opt_params),
            best_member=initialization,
        )
        return state

    def ask_strategy(
        self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
    ) -> Tuple[chex.Array, EvoState]:
        """`ask` for new parameter candidates to evaluate next."""
        # Antithetic sampling of noise
        z_plus = jax.random.normal(
            rng,
            (int(self.popsize / 2), self.num_dims),
        )
        z = jnp.concatenate([z_plus, -1.0 * z_plus])
        x = state.mean + state.sigma * z
        return x, state

    def tell_strategy(
        self,
        x: chex.Array,
        fitness: chex.Array,
        state: EvoState,
        params: EvoParams,
    ) -> EvoState:
        """`tell` performance data for strategy state update."""
        # Reconstruct noise from last mean/std estimates
        noise = (x - state.mean) / state.sigma
        noise_1 = noise[: int(self.popsize / 2)]
        fit_1 = fitness[: int(self.popsize / 2)]
        fit_2 = fitness[int(self.popsize / 2) :]
        elite_idx = jnp.minimum(fit_1, fit_2).argsort()[: self.elite_popsize]
        fitness_elite = jnp.concatenate([fit_1[elite_idx], fit_2[elite_idx]])
        # Add small constant to ensure non-zero division stability
        sigma_fitness = jnp.std(fitness_elite) + 1e-05
        fit_diff = fit_1[elite_idx] - fit_2[elite_idx]
        fit_diff_noise = jnp.dot(noise_1[elite_idx].T, fit_diff)
        theta_grad = 1.0 / (self.elite_popsize * sigma_fitness) * fit_diff_noise
        # Grad update using optimizer instance - decay lrate if desired
        mean, opt_state = self.optimizer.step(
            state.mean, theta_grad, state.opt_state, params.opt_params
        )
        opt_state = self.optimizer.update(opt_state, params.opt_params)
        sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit)
        return state.replace(mean=mean, sigma=sigma, opt_state=opt_state)
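

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the strategy implementation).
# It assumes this module is imported as part of the surrounding package and
# that the base `Strategy` class exposes the usual `default_params`,
# `initialize`, `ask`, and `tell` wrappers around the `*_strategy` methods
# above. The objective below (a simple sphere function) and all constants are
# placeholders chosen for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = jax.random.PRNGKey(0)
    strategy = ARS(popsize=32, num_dims=10, elite_ratio=0.2, opt_name="adam")
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)

    def sphere(x: chex.Array) -> chex.Array:
        # Fitness is minimized; the optimum of the sphere function is 0.
        return jnp.sum(x ** 2, axis=-1)

    for _ in range(100):
        rng, rng_ask = jax.random.split(rng)
        # `ask` returns antithetic candidates x = mean +/- sigma * z
        x, state = strategy.ask(rng_ask, state, es_params)
        fitness = sphere(x)
        # `tell` applies the elite-weighted finite-difference update
        state = strategy.tell(x, fitness, state, es_params)

    print("Best fitness so far:", state.best_fitness)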