Barlow Twins implementation #230
Conversation
Hi @OlivierDehaene! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA Signed. If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!
Using the following configuration for pretraining, and this one for evaluation, I obtain 85.7 Top-1 accuracy on Imagenette 160. If that is in the expected ballpark, I think the BarlowTwinsLoss and Criterion are ready for review!
I do not know for Barlow Twins specifically (@jzbontar can answer that more precisely), but this is definitely in the same ballpark as my experiments with SimCLR pre-trained and then evaluated on Imagenette 160 (where the typical results I got were between 85% and 88% top-1 accuracy) 👍
Thank you so much @OlivierDehaene, this is looking great. I have left some inline comments. Please take a look at them :) Also, the next steps as I see them:
Can you follow https://github.com/facebookresearch/vissl/blob/master/.github/CONTRIBUTING.md#coding-style and run dev/linter.sh on this PR?
Afterwards, I think the open items in the PR checklist are the next steps. We should aim to provide a config file for at least ImageNet. Further, a model in the model_zoo would be great.
My implementation is based on the SimCLR loss implementation. A lot of the comments you added are relevant for both (for example the style comments). Do you want me to patch both at the same time?
Indeed, I certainly welcome the improvements to SimCLR as well, but in a separate PR :) Thank you so much for this.
@OlivierDehaene, I am pretty excited about this PR. I think we are extremely close to getting this merged, and I appreciate you a lot for this work. :)
The 2nd iteration of the criterion is done. The criterion now normalises the embeddings and combines the cross-correlation matrix across all workers. @prigoyal, do I resolve the comments that are addressed, or do you? What is the process here between maintainers and contributors?
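For context, here is a minimal sketch of what such a criterion computes. The function name and structure are hypothetical (the actual VISSL `BarlowTwinsLoss` is organised differently), and `lambda_=0.0051` is the redundancy-reduction weight from the Barlow Twins paper:

```python
import torch
import torch.distributed as dist

def barlow_twins_loss(z_a, z_b, lambda_=0.0051):
    # z_a, z_b: (local_batch, dim) embeddings of two augmented views.
    # Normalise each embedding dimension to zero mean / unit std
    # (the PR computes these statistics across all workers; local here for brevity).
    z_a = (z_a - z_a.mean(0)) / z_a.std(0)
    z_b = (z_b - z_b.mean(0)) / z_b.std(0)

    # cross-correlation matrix, combined across all workers
    batch_size = z_a.size(0)
    c = z_a.T @ z_b
    if dist.is_available() and dist.is_initialized():
        # plain all_reduce sums the per-worker contributions; its gradient
        # semantics are exactly what is debated later in this thread
        dist.all_reduce(c)
        batch_size *= dist.get_world_size()
    c = c / batch_size

    # invariance term pushes the diagonal towards 1,
    # redundancy-reduction term pushes the off-diagonal towards 0
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()
    off_diag = c.pow(2).sum() - torch.diagonal(c).pow(2).sum()
    return on_diag + lambda_ * off_diag
```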
Thank you @OlivierDehaene. For the comments that pertain to the code changes here, contributors should resolve the comments once they have been addressed. If the comments require further discussion, they can stay open and we can discuss/maintainers can close :)
Hm, good point. I checked whether `all_reduce` handles gradients correctly by running:

```python
import torch

def main_worker(gpu):
    torch.distributed.init_process_group(backend='nccl', init_method='tcp://localhost:12345', world_size=2, rank=gpu)
    torch.cuda.set_device(gpu)
    x = torch.ones(1).cuda().requires_grad_()
    xx = (gpu + 2) * x
    torch.distributed.all_reduce(xx)
    xx.backward()
    print(f'gpu={gpu}, grad={x.grad}')

if __name__ == '__main__':
    torch.multiprocessing.spawn(main_worker, nprocs=2)
```

The output it prints seems okay to me. The PyTorch documentation, however, suggests that the autograd-enabled communication primitives reside in `torch.distributed.nn`. Good catch, @odelalleau! EDIT: whoops! I meant @OlivierDehaene. Sorry.
Giving credit where it's due: good catch @OlivierDehaene! ;)
@OlivierDehaene, thank you for your continued awesome work on this PR :) Just wanted to chime in with an early recommendation for a next step (EDIT: seems like you have planned this already?): in order for us to ensure reproducibility, we should aim for an entry for Barlow Twins in the VISSL Model Zoo. This entry should have a model pre-trained on ImageNet (cc @jzbontar) and we should list the evaluation accuracy of this model in the table. For training on ImageNet, @jzbontar mentioned he might be able to help by running your code, so it would be good to coordinate a plan for that. For the linear evaluation on ImageNet, please do reach out if you run into any issues or have any questions :)
Looking really good. Small inline comments, and thank you @OlivierDehaene for your hard work on this :)
In vissl/losses/barlow_twins_loss.py (outdated):

```python
class SyncNormalizeFunction(Function):
    """
    Adapted from: https://github.com/NVIDIA/apex/blob/master/apex/parallel/sync_batchnorm.py
    """
```
Curious question: why can't we use the function directly from NVIDIA? If the reason is that we want to avoid introducing a dependency on Apex, that's valid; in that case we could introduce vissl/utils/apex_helpers.py and put this function there. This will allow usage beyond this loss function. :)
For me there are several advantages to this:

- I think it is important that the criterion does not depend on `train()` and `eval()` cycles. Using a BatchNorm layer would add hidden running statistics that could lead to weird behaviour. It is possible to disable these statistics using `track_running_stats=False`; however, this argument is bugged in the Apex implementation of SyncBatchNorm. See my issue: [BUG] Apex optimized SyncBatchnormFunction crashes with track_running_stats=False (NVIDIA/apex#1071).
- We do not need additional logic to switch between the PyTorch SyncBatchNorm and the faster Apex SyncBatchNorm depending on whether the user installed Apex.
- One could argue that it makes it clearer to the end user what is actually happening here.
> we could introduce vissl/utils/apex_helpers.py and put this function there. This will allow usage beyond this loss function. :)

100% agree. I wasn't sure where to add it. It also needs to be tested.
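To illustrate the idea behind such a helper, here is a sketch of normalisation with statistics synchronised across workers via `all_reduce`. This is not the Apex or PR code; in particular, a production version (like the Apex function adapted above) implements a custom autograd `Function` so the backward pass also accounts for the cross-worker statistics:

```python
import torch
import torch.distributed as dist

def sync_normalize(x, eps=1e-5):
    """Normalise each column of x to zero mean / unit variance using
    statistics computed over the global batch (all workers)."""
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    n = x.size(0) * world_size

    # global first and second moments via all_reduce of the local sums
    s = x.sum(dim=0)
    ss = (x * x).sum(dim=0)
    if world_size > 1:
        dist.all_reduce(s)
        dist.all_reduce(ss)

    mean = s / n
    var = ss / n - mean * mean  # E[x^2] - E[x]^2
    return (x - mean) / torch.sqrt(var + eps)
```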
I did a rerun of the integration test on Imagenette 160 to check that the new version of the loss was working. The rerun achieved 88.8% Top-1 accuracy.
I saw your post on facebookresearch/barlowtwins#8 and I think there is still a weird interaction between the distributed communication and the gradient computation. For example, the following code doesn't yield the same gradients when using `torch.distributed.all_reduce` as when using `GatherLayer`:

```python
import torch
from vissl.utils.distributed_gradients import GatherLayer

n, m = 8, 4

def main_worker(i):
    torch.distributed.init_process_group(backend='gloo', init_method='tcp://localhost:12345', world_size=4, rank=i)
    x = torch.full((n, m), i + 1, dtype=torch.float32).requires_grad_()

    # compute xTx using all_reduce
    xTx = x.T @ x
    torch.distributed.all_reduce(xTx)
    xTx.sum().backward()
    if i == 0:
        print(x.grad)
        # tensor([[8., 8., 8., 8.],
        #         [8., 8., 8., 8.],
        #         [8., 8., 8., 8.],
        #         [8., 8., 8., 8.],
        #         [8., 8., 8., 8.],
        #         [8., 8., 8., 8.],
        #         [8., 8., 8., 8.],
        #         [8., 8., 8., 8.]])
    x.grad.zero_()

    # compute xTx using GatherLayer
    z = torch.cat(GatherLayer.apply(x))
    zTz = z.T @ z
    zTz.sum().backward()
    if i == 0:
        print(x.grad)
        # tensor([[32., 32., 32., 32.],
        #         [32., 32., 32., 32.],
        #         [32., 32., 32., 32.],
        #         [32., 32., 32., 32.],
        #         [32., 32., 32., 32.],
        #         [32., 32., 32., 32.],
        #         [32., 32., 32., 32.],
        #         [32., 32., 32., 32.]])

    # xTx == zTz
    # tensor([[240., 240., 240., 240.],
    #         [240., 240., 240., 240.],
    #         [240., 240., 240., 240.],
    #         [240., 240., 240., 240.]], grad_fn=<MmBackward>)

if __name__ == '__main__':
    torch.multiprocessing.spawn(main_worker, nprocs=4)
```

Could this also be related to your scaling issue? EDIT: it may actually be an error in GatherLayer.
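For reference, a `GatherLayer`-style autograd function typically looks like the sketch below (an approximation of VISSL's helper, not a verbatim copy). Since every worker evaluates the loss on the full gathered batch, the backward all-reduce sums `world_size` copies of each gradient, which is consistent with the 8 vs. 32 discrepancy above:

```python
import torch
import torch.distributed as dist

class GatherLayer(torch.autograd.Function):
    """All-gather tensors from all workers while keeping the autograd graph."""

    @staticmethod
    def forward(ctx, x):
        out = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(out, x)  # all_gather itself does not track gradients
        return tuple(out)

    @staticmethod
    def backward(ctx, *grads):
        # every worker holds gradients for all slices of the gathered batch;
        # sum them across workers, then return the slice belonging to this rank
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]
```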
Yes, right? I think so too. The following example shows that the plain `torch.distributed.all_reduce` solution matches a single-process baseline, while the `torch.distributed.nn.all_reduce` solution needs the loss divided by the world size to match it:

```python
import torch
import torch.distributed.nn

# compute gradient of xTy**2
# the gradient wrt x is 2 * xTy * y
# the gradient wrt y is 2 * xTy * x

def main_worker(rank):
    torch.distributed.init_process_group(backend='gloo', init_method='tcp://localhost:12345', world_size=4, rank=rank)

    # torch.distributed.all_reduce solution
    x = torch.randn(4).requires_grad_()
    y = torch.randn(4).requires_grad_()
    xTy = torch.dot(x, y)
    torch.distributed.all_reduce(xTy)
    loss = xTy.pow(2)
    loss.backward()
    assert torch.equal(x.grad, 2 * xTy * y)  # this is correct and matches the baseline
    assert torch.equal(y.grad, 2 * xTy * x)

    # torch.distributed.nn.all_reduce solution
    x = torch.randn(4).requires_grad_()
    y = torch.randn(4).requires_grad_()
    xTy_1gpu = torch.dot(x, y)
    xTy = torch.distributed.nn.all_reduce(xTy_1gpu)
    loss = xTy.pow(2)
    loss /= 4  # have to divide loss by world_size to match the baseline solution
    loss.backward()
    assert torch.equal(x.grad, 2 * xTy * y)
    assert torch.equal(y.grad, 2 * xTy * x)

if __name__ == '__main__':
    # single process baseline
    x = torch.randn(16).requires_grad_()
    y = torch.randn(16).requires_grad_()
    xTy = torch.dot(x, y)
    loss = xTy.pow(2)
    loss.backward()
    assert torch.equal(x.grad, 2 * xTy * y)
    assert torch.equal(y.grad, 2 * xTy * x)

    torch.multiprocessing.spawn(main_worker, nprocs=4)
```

Maybe we should really be using `torch.distributed.nn.all_reduce`.
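If the criterion moved to the autograd-aware primitive, the cross-correlation step could look roughly like this sketch (hypothetical helper, not part of the PR; as the example above shows, the loss would then have to be scaled by the world size to match single-process gradients):

```python
import torch
import torch.distributed
import torch.distributed.nn

def global_cross_correlation(z_a, z_b):
    # per-worker contribution to the cross-correlation matrix
    c_local = z_a.T @ z_b
    # autograd-aware sum: gradients are propagated back to every worker
    c = torch.distributed.nn.all_reduce(c_local)
    global_batch = z_a.size(0) * torch.distributed.get_world_size()
    return c / global_batch
```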
@OlivierDehaene, in VISSL, https://github.com/facebookresearch/vissl/blob/master/vissl/trainer/train_steps/standard_train_step.py#L158 is the place where we call `loss.backward()`.
Force-pushed from dfa33e5 to a5ba52b (commits: "… gradients and save one all-to-all communication", "Added 300epochs model").
@OlivierDehaene has updated the pull request. You must reimport the pull request before landing.
I forgot I was on a fork and rebased on my fork instead of this one... Should be good now!
@prigoyal has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
I can't see the details of the failing tests, so I don't know what to patch.
Hi @OlivierDehaene, I will comment on relevant places in the PR shortly :)
I left some last remaining inline comments. Additionally, we need to run the linter on the PR. Follow the instructions at https://github.com/facebookresearch/vissl/blob/master/dev/README.md#practices-for-coding-quality to run the linter provided by VISSL and update the PR :)
Review threads on the following files were marked outdated and resolved:
- configs/config/pretrain/barlow_twins/barlow_twins_4node_resnet.yaml (2 threads)
- configs/config/debugging/pretrain/barlow_twins/barlow_twins_1node_resnet_imagenette_160.yaml (2 threads)
@OlivierDehaene has updated the pull request. You must reimport the pull request before landing.
@OlivierDehaene has updated the pull request. You must reimport the pull request before landing.
Some lints fail on files that were not touched by this PR. I left them as-is.
@OlivierDehaene, that sounds great. I'll work on linting the other files :)
Sorry, I misunderstood what you meant about the linter.
@prigoyal has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:

## Required (TBC)

- [x] BarlowTwinsLoss and Criterion
- [x] Documentation
  - [x] Loss
  - [x] SSL Approaches + Index
  - [x] Model Zoo
  - [x] Project
- [x] Default configs
  - [x] pretrain
  - [x] test/integration
  - [x] debugging/pretrain
- [x] Benchmarks
  - [x] ImageNet: 70.75 for 300 epochs
  - [x] Imagenette 160: 88.8 Top-1 accuracy

closes #229

Pull Request resolved: #230
Reviewed By: iseessel
Differential Revision: D28118605
Pulled By: prigoyal
fbshipit-source-id: 4436d6fd9d115b80ef5c5396318caa3cb26faadb