-
Notifications
You must be signed in to change notification settings - Fork 23.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding support for differentiable lr, weight_decay, and betas in Adam…
…/AdamW (#143726) Third PR in a series of PRs to broaden differentiable optimizer support w/ @janeyx99 (sorry for pinging over the holidays! I just wanted to put this one out but I am definitely not asking for review or anything like that rn) This is also going to probably be my last PR before the holidays! Note: This is a branch of #143710 -- I've never worked on a branch of a branch before so I wasn't sure about the protocol so I thought I'd just made the PR and wait until that one gets merged. This is adding support for differentiable lr, weight_decay, and betas to Adam and AdamW (but after refactoring AdamW into an Adam subclass, it's really just changing code in torch/optim/adam.py) I had one main thing I was wondering about, which is that adam already has a differentiable flag built in, so I have code like this ```py if differentiable and isinstance(beta2, Tensor): if beta2.requires_grad: exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2)) else: exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) else: exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) ``` That I could definitely simplify to just ```py if differentiable and isinstance(beta2, Tensor): exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2)) else: exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) ``` It would definitely be a little slower in the case that it's differentiable but doesn't need a grad for beta2, but the code would also be a lot more clear and I'm debating speed vs future code usability. Also the line in the above example: ```py exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2)) ``` was concerning to me because it is considerably more expensive than `value=1 - beta2`, but I couldn't think of a better way to do it. Further work on #141832 Pull Request resolved: #143726 Approved by: https://github.com/janeyx99
- Loading branch information
1 parent
a7915c5
commit 92d8965
Showing
2 changed files
with
331 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters