architect.py
import copy

import torch


class Architect:
    """ Compute gradients of the architecture parameters (alphas) """
    def __init__(self, net, w_momentum, w_weight_decay):
        self.net = net
        # deep copy that holds the virtual (one-step-ahead) weights w'
        self.v_net = copy.deepcopy(net)
        self.w_momentum = w_momentum
        self.w_weight_decay = w_weight_decay
    def virtual_step(self, trn_X, trn_y, xi, w_optim):
        """
        Compute unrolled weights w' = w - xi * dL_trn(w) (a single virtual
        SGD step, applied to the copy rather than the real network).

        Args:
            xi: learning rate of the virtual step
            w_optim: weight optimizer, whose momentum buffers are read so the
                virtual step matches a real optimizer step
        """
        # training loss L_trn(w)
        loss = self.net.loss(trn_X, trn_y)

        # gradients of the training loss w.r.t. the weights
        gradients = torch.autograd.grad(loss, self.net.weights())

        # apply one SGD step with momentum and weight decay to the copy
        with torch.no_grad():
            for w, vw, g in zip(self.net.weights(), self.v_net.weights(), gradients):
                m = w_optim.state[w].get('momentum_buffer', 0.) * self.w_momentum
                vw.copy_(w - xi * (m + g + self.w_weight_decay * w))

            # architecture parameters are copied unchanged
            for a, va in zip(self.net.alphas(), self.v_net.alphas()):
                va.copy_(a)
    def unrolled_backward(self, trn_X, trn_y, val_X, val_y, xi, w_optim):
        """ Compute the unrolled validation loss and backprop it to the alphas """
        # do the virtual step to obtain the unrolled weights w'
        self.virtual_step(trn_X, trn_y, xi, w_optim)

        # unrolled validation loss L_val(w', alpha)
        loss = self.v_net.loss(val_X, val_y)

        # gradients w.r.t. the alphas and w.r.t. the unrolled weights w'
        v_alphas = tuple(self.v_net.alphas())
        v_weights = tuple(self.v_net.weights())
        v_grads = torch.autograd.grad(loss, v_alphas + v_weights)

        # clip the gradients to norm 5; clip_grad_norm_ operates on parameters'
        # .grad attributes, so raw gradient tensors are rescaled by hand
        total_norm = torch.cat([g.reshape(-1) for g in v_grads]).norm()
        if total_norm > 5.:
            v_grads = tuple(g * (5. / total_norm) for g in v_grads)

        dalpha = v_grads[:len(v_alphas)]
        dw = v_grads[len(v_alphas):]

        hessian = self.compute_hessian(dw, trn_X, trn_y)

        # second-order update: grad alpha = dalpha - xi * hessian
        with torch.no_grad():
            for alpha, da, h in zip(self.net.alphas(), dalpha, hessian):
                alpha.grad = da - xi * h
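    # Math note: with w' = w - xi * dL_trn(w), the chain rule on
    # L_val(w'(alpha), alpha) gives
    #     dalpha L_val(w', alpha) - xi * d^2_{alpha,w} L_trn(w, alpha) @ dw' L_val(w', alpha),
    # which is the `da - xi * h` update above. compute_hessian() below
    # approximates the second-order (Hessian-vector) term by central finite
    # differences instead of forming the Hessian explicitly.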
    def compute_hessian(self, dw, trn_X, trn_y):
        """
        Central finite-difference approximation of the Hessian-vector product:
            dw = dw' L_val(w', alpha)
            w+ = w + eps * dw,  w- = w - eps * dw
            hessian = (dalpha L_trn(w+) - dalpha L_trn(w-)) / (2 * eps)
            eps = 0.01 / ||dw||
        """
        norm = torch.cat([w.view(-1) for w in dw]).norm()
        eps = 0.01 / norm

        # w+ = w + eps * dw
        with torch.no_grad():
            for p, d in zip(self.net.weights(), dw):
                p += eps * d
        loss = self.net.loss(trn_X, trn_y)
        dalpha_pos = torch.autograd.grad(loss, self.net.alphas())  # dalpha L_trn(w+)

        # w- = w - eps * dw (subtract 2 * eps * dw from w+)
        with torch.no_grad():
            for p, d in zip(self.net.weights(), dw):
                p -= 2. * eps * d
        loss = self.net.loss(trn_X, trn_y)
        dalpha_neg = torch.autograd.grad(loss, self.net.alphas())  # dalpha L_trn(w-)

        # restore the original weights
        with torch.no_grad():
            for p, d in zip(self.net.weights(), dw):
                p += eps * d

        # divide by (2 * eps); note the parentheses, `/ 2. * eps` would
        # multiply by eps / 2 instead
        hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
        return hessian
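

# A minimal sketch of the bilevel search loop this class is meant to drive,
# in the spirit of DARTS (Liu et al., 2019). The names `model`, `w_optim`,
# `alpha_optim`, and `lr` are illustrative assumptions; `model` must expose
# the weights()/alphas()/loss() interface used above.
def search_step(architect, model, w_optim, alpha_optim, lr,
                trn_X, trn_y, val_X, val_y):
    # alpha step: unrolled gradient w.r.t. the architecture parameters
    alpha_optim.zero_grad()
    architect.unrolled_backward(trn_X, trn_y, val_X, val_y, lr, w_optim)
    alpha_optim.step()

    # w step: ordinary training step on the weights
    w_optim.zero_grad()
    loss = model.loss(trn_X, trn_y)
    loss.backward()
    w_optim.step()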