Adding support for differentiable lr, weight_decay, and betas in Adam/AdamW (#143726)

Third PR in a series to broaden differentiable optimizer support with @janeyx99 (sorry for pinging over the holidays! I just wanted to put this one out, but I'm definitely not asking for review or anything like that right now).

This will probably also be my last PR before the holidays!

Note: This is a branch of #143710 -- I've never worked on a branch of a branch before and wasn't sure about the protocol, so I thought I'd just make the PR and wait until that one gets merged.

This PR adds support for differentiable lr, weight_decay, and betas to Adam and AdamW (though after refactoring AdamW into an Adam subclass, it really just changes code in torch/optim/adam.py).
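
For context, a minimal sketch of the kind of usage this enables: tensor hyperparameters that require grad, `differentiable=True`, and gradients flowing back into those hyperparameters through the optimizer step. The values and the `.clone()` / manual `.grad` setup are illustrative only (they mirror the pattern of the new tests below), not code from this PR:
```py
import torch
from torch.optim import Adam

# Hyperparameters as tensors that require grad (illustrative values).
lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
weight_decay = torch.tensor(0.1, requires_grad=True, dtype=torch.float64)
betas = (
    torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
    torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
)

# Clone so the parameter is a non-leaf tensor; autograd disallows in-place
# optimizer updates on a leaf tensor that requires grad.
params = torch.rand(10, requires_grad=True, dtype=torch.float64).clone()
params.grad = torch.rand_like(params)

opt = Adam([params], lr=lr, betas=betas, weight_decay=weight_decay, differentiable=True)
opt.step()

# The updated params now depend on lr, weight_decay, and betas, so a loss on
# them can be backpropagated into those hyperparameters.
meta_loss = params.sum()
meta_loss.backward(inputs=(lr, weight_decay, *betas))
print(lr.grad, weight_decay.grad, betas[0].grad, betas[1].grad)
```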

There was one main thing I was wondering about: Adam already has a `differentiable` flag built in, so I have code like this:
```py
if differentiable and isinstance(beta2, Tensor):
    if beta2.requires_grad:
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
    else:
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
else:
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
I could definitely simplify that to just:
```py
if differentiable and isinstance(beta2, Tensor):
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
else:
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```

The simplified version would be a little slower in the case where `differentiable` is set but `beta2` doesn't require a grad, but the code would also be a lot clearer, so I'm debating speed vs. future code usability.

Also, this line from the examples above:
```py
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
```
was concerning to me because it is considerably more expensive than `value=1 - beta2`, but I couldn't think of a better way to do it.
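
One option, and the one the merged diff below ends up using for the second-moment update, is to fold the whole expression into a single `lerp_`, since `lerp`'s weight is allowed to be a tensor. A sketch (assuming `grad` is already real-valued at this point, as it is after the `view_as_real` handling in adam.py):
```py
# Equivalent to exp_avg_sq * beta2 + grad^2 * (1 - beta2), with the tensor-valued
# weight handled by lerp rather than addcmul's scalar-only `value` argument.
exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2)
```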

Further work on #141832

Pull Request resolved: #143726
Approved by: https://github.com/janeyx99
EmmettBicker authored and pytorchmergebot committed Dec 30, 2024
1 parent a7915c5 commit 92d8965
Showing 2 changed files with 331 additions and 11 deletions.
292 changes: 286 additions & 6 deletions test/optim/test_optim.py
@@ -76,12 +76,22 @@ def _multistep_backprop_diff_hyperparams_fn(

# This copy is necessary so the update on line 78 doesn't overwrite the original kwargs values
kwargs = kwargs.copy()

# Have to pass in beta1 and beta2 separately
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck
if "beta1" in kwargs or "beta2" in kwargs:
# Prevent just one beta kwarg from being passed in
assert (
"beta1" in kwargs and "beta2" in kwargs
), "Both betas should be defined in kwargs"
kwargs.update({"betas": (kwargs.pop("beta1"), kwargs.pop("beta2"))})

kwargs.update(
{k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
)
differentiable_kwargs = [
v for v in kwargs.values() if isinstance(v, torch.Tensor) and v.requires_grad
]
] + (list(kwargs["betas"]) if "betas" in kwargs else [])

criterion = nn.MSELoss()

@@ -104,18 +114,18 @@ def _multistep_backprop_diff_hyperparams_fn(
meta_loss = loss
meta_loss.backward(inputs=(*differentiable_kwargs,), create_graph=True)

# Extra check to make sure the test properly computed a gradient for all kwargs
for kwarg in differentiable_kwargs:
assert kwarg.grad is not None

return (
(meta_loss,)
+ tuple(
v
for v in optimizer.state[params].values()
if isinstance(v, torch.Tensor) and v.requires_grad
)
+ tuple(
v
for v in kwargs.values()
if isinstance(v, torch.Tensor) and v.requires_grad
)
+ tuple(differentiable_kwargs)
)


@@ -404,6 +414,276 @@ def test_radam(self):
),
)

def test_adam_differentiable_lr(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)

state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
kwargs: dict[str, Any] = {"lr": lr, "differentiable": True}

gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
Adam,
kwargs, # includes lr
*state.values(),
*kwargs.values(),
),
)

def test_adam_differentiable_weight_decay(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)

state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}

gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
Adam,
kwargs, # includes weight_decay
*state.values(),
*kwargs.values(),
),
)

def test_adam_differentiable_betas(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)

lr = torch.tensor([0.001], requires_grad=True, dtype=torch.float64)
betas = (
torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
)
state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)

# Have to pass in beta1 and beta2 separately
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
# In the test, this is called: kwargs.update({betas: (beta1, beta2)})
kwargs: dict[str, Any] = {
"beta1": betas[0],
"beta2": betas[1],
"lr": lr,
"differentiable": True,
}

gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
Adam,
kwargs, # includes betas
*state.values(),
*kwargs.values(),
),
)

def test_adam_differentiable_all_hyperparams(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)

lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
betas = (
torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
)
state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)

# Have to pass in beta1 and beta2 separately
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
# In the test, this is called: kwargs.update({betas: (beta1, beta2)})
kwargs: dict[str, Any] = {
"lr": lr,
"weight_decay": weight_decay,
"beta1": betas[0],
"beta2": betas[1],
"differentiable": True,
}

gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
Adam,
kwargs, # includes betas
*state.values(),
*kwargs.values(),
),
)

def test_adamw_differentiable_lr(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)

state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
kwargs: dict[str, Any] = {"lr": lr, "differentiable": True}

gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
AdamW,
kwargs, # includes lr
*state.values(),
*kwargs.values(),
),
)

def test_adamw_differentiable_weight_decay(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)

state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}

gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
AdamW,
kwargs, # includes weight_decay
*state.values(),
*kwargs.values(),
),
)

def test_adamw_differentiable_betas(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)

betas = (
torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
)
state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)

# Have to pass in beta1 and beta2 separately
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
# In the test, this is called: kwargs.update({betas: (beta1, beta2)})
kwargs: dict[str, Any] = {
"beta1": betas[0],
"beta2": betas[1],
"differentiable": True,
}

gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
AdamW,
kwargs, # includes betas
*state.values(),
*kwargs.values(),
),
)

def test_adamw_differentiable_all_hyperparams(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)

lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
betas = (
torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
)
state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)

# Have to pass in beta1 and beta2 separately
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
# In the test, this is called: kwargs.update({betas: (beta1, beta2)})
kwargs: dict[str, Any] = {
"lr": lr,
"weight_decay": weight_decay,
"beta1": betas[0],
"beta2": betas[1],
"differentiable": True,
}

gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
AdamW,
kwargs, # includes betas
*state.values(),
*kwargs.values(),
),
)

def test_differentiable_lr(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
50 changes: 45 additions & 5 deletions torch/optim/adam.py
@@ -402,7 +402,14 @@ def _single_tensor_adam(
# Perform stepweight decay
param.mul_(1 - lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
# Nested if is necessary to bypass jitscript rules
if differentiable and isinstance(weight_decay, Tensor):
if weight_decay.requires_grad:
grad = grad.addcmul_(param.clone(), weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)

if torch.is_complex(param):
grad = torch.view_as_real(grad)
@@ -429,13 +436,43 @@ def _single_tensor_adam(
# Decay the first and second moment running average coefficient
exp_avg.lerp_(grad, 1 - device_beta1)

exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
# Nested if is necessary to bypass jitscript rules
if differentiable and isinstance(beta2, Tensor):
if beta2.requires_grad:
# Use lerp so only two ops are needed, because addcmul's `value` cannot be a tensor.
# Equivalence of the differentiable and nondifferentiable paths:
#   expavg * b2 + grad^2 * (1-b2)
#   (add expavg * (1-b2) - expavg * (1-b2) = 0)
#   = expavg * b2 + expavg * (1-b2) - expavg * (1-b2) + grad^2 * (1-b2)
#   = expavg - expavg * (1-b2) + grad^2 * (1-b2)
#   = expavg + (grad^2 - expavg) * (1-b2)
#   = expavg.lerp(grad^2, 1-b2)
exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2)
else:
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
else:
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

if capturable or differentiable:
step = step_t

bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
# Nested if is necessary to bypass jitscript rules
if differentiable and isinstance(beta1, Tensor):
if beta1.requires_grad:
bias_correction1 = 1 - beta1 ** step.clone()
else:
bias_correction1 = 1 - beta1**step
else:
bias_correction1 = 1 - beta1**step

# Nested if is necessary to bypass jitscript rules
if differentiable and isinstance(beta2, Tensor):
if beta2.requires_grad:
bias_correction2 = 1 - beta2 ** step.clone()
else:
bias_correction2 = 1 - beta2**step
else:
bias_correction2 = 1 - beta2**step

step_size = lr / bias_correction1
step_size_neg = step_size.neg()
@@ -462,7 +499,10 @@ def _single_tensor_adam(
exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
).add_(eps / step_size_neg)

param.addcdiv_(exp_avg, denom)
if differentiable:
param.addcdiv_(exp_avg.clone(), denom)
else:
param.addcdiv_(exp_avg, denom)
else:
step = _get_value(step_t)

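
As a side note on the hunk above: the `lerp_` form of the second-moment update is just the standard update rearranged, which is what the in-code comment derives:

$$
v_t = \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^2
    = v_{t-1} + (1-\beta_2)\,(g_t^2 - v_{t-1})
    = \mathrm{lerp}\left(v_{t-1},\, g_t^2,\, 1-\beta_2\right),
$$

where $v$ is `exp_avg_sq` and $g$ is `grad`.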
