Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Barlow Twins implementation #230

Closed

Conversation

OlivierDehaene
Copy link
Contributor

@OlivierDehaene OlivierDehaene commented Mar 10, 2021

Required (TBC)

  • BarlowTwinsLoss and Criterion
  • Documentation
    • Loss
    • SSL Approaches + Index
    • Model Zoo
    • Project
  • Default configs
    • pretrain
    • test/integration
    • debugging/pretrain
  • Benchmarks
    • ImageNet: 70.75 for 300 epochs
    • Imagenette 160: 88.8 Top1 accuracy

closes #229

@facebook-github-bot
Copy link
Contributor

Hi @OlivierDehaene!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 10, 2021
@facebook-github-bot
Copy link
Contributor

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!

1 similar comment
@facebook-github-bot
Copy link
Contributor

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!

@OlivierDehaene OlivierDehaene mentioned this pull request Mar 11, 2021
3 tasks
@prigoyal prigoyal requested review from jzbontar and prigoyal March 11, 2021 14:27
@OlivierDehaene
Copy link
Contributor Author

@prigoyal, @jzbontar,

Using the following configuration for pretraining, and this one for evaluation, I obtain 85.7 Top 1 Accuracy on Imagenette 160.
Is it in the ballpark of what you would expect on this dataset?

If so, I think the BarlowTwinsLoss and Criterion are ready for review!

@QuentinDuval
Copy link
Contributor

QuentinDuval commented Mar 15, 2021

@prigoyal, @jzbontar,

Using the following configuration for pretraining, and this one for evaluation, I obtain 85.7 Top 1 Accuracy on Imagenette 160.
Is it in the ballpark of what you would expect on this dataset?

If so, I think the BarlowTwinsLoss and Criterion are ready for review!

I do not know for Barlow Twins specifically (@jzbontar can answer on that more precisely), but this is definitely in the same ball park than my experiments with SimCLR pre-trained and then evaluated on Imagenette 160 (where the typical results I got are between 85% to 88% top-1 accuracy) 👍

Copy link
Contributor

@prigoyal prigoyal left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you so much @OlivierDehaene this is looking great. I have left some inline comments. Please take a look at them :) Also a next step I see:
Can you follow the https://github.com/facebookresearch/vissl/blob/master/.github/CONTRIBUTING.md#coding-style to try the dev/linter.sh on this PR ?

Afterwards, I think the open items in the PR checklist are next steps. We should aim to provide a config file for at least ImageNet. Further, a model in the model_zoo would be great .

vissl/losses/barlow_twins_loss.py Outdated Show resolved Hide resolved
vissl/losses/barlow_twins_loss.py Outdated Show resolved Hide resolved
vissl/losses/barlow_twins_loss.py Show resolved Hide resolved
vissl/losses/barlow_twins_loss.py Outdated Show resolved Hide resolved
vissl/losses/barlow_twins_loss.py Show resolved Hide resolved
vissl/losses/barlow_twins_loss.py Outdated Show resolved Hide resolved
vissl/losses/barlow_twins_loss.py Outdated Show resolved Hide resolved
vissl/losses/barlow_twins_loss.py Outdated Show resolved Hide resolved
vissl/losses/barlow_twins_loss.py Outdated Show resolved Hide resolved
vissl/losses/barlow_twins_loss.py Outdated Show resolved Hide resolved
@OlivierDehaene
Copy link
Contributor Author

OlivierDehaene commented Mar 16, 2021

@prigoyal,

My implementation is based on the SimCLR loss implementation. A lot of the comments you added are relevant for both (for example the style comments). Do you want me to patch both at the same time?

@prigoyal
Copy link
Contributor

@prigoyal,

My implementation is based on the SimCLR loss implementation. A lot of the comments you added are relevant for both (for example the style comments). Do you want me to patch both at the same time?

indeed. I certainly welcome the improvements to SimCLR as well but in a separate PR :) thank you so much for this.

@prigoyal
Copy link
Contributor

@OlivierDehaene , I am pretty excited about this PR. I think we are extremely close to getting this merged and I appreciate you a lot for this work. :)

@OlivierDehaene
Copy link
Contributor Author

OlivierDehaene commented Mar 22, 2021

@prigoyal, @QuentinDuval,

The 2nd iteration of the criterion is done. The criterion now normalises the embedding and combines the cross-correlation matrix across all workers.

@prigoyal, do I resolve the comments that are solved? Or do you do it? What is the process here between maintainers and contributors?

@prigoyal
Copy link
Contributor

@prigoyal, @QuentinDuval,

The 2nd iteration of the criterion is done. The criterion now normalises the embedding and combines the cross-correlation matrix across all workers.

I implemented the normalisation in _sync_normalise instead of using torch.nn.SyncBatchNorm(..., affine=False) to not have an unnecessary train/eval state in the criterion (since BatchNorm changes its behaviour).

@prigoyal, do I resolve the comments that are solved? Or do you do it? What is the process here between maintainers and contributors?

thank you @OlivierDehaene , for the comments that pertain to the code changes here, contributors should resolve the comments once they have been addressed. If the comments require further discussion , they can stay open and we can discuss/maintainers can close :)

@OlivierDehaene
Copy link
Contributor Author

OlivierDehaene commented Mar 22, 2021

@jzbontar,

Excuse me if I'm wrong but I thought that torch.distributed.all_reduce cut gradients. However you use it here to sum the cross-correlation matrices. Wouldn't that cause issues in the backward pass?

@jzbontar
Copy link

jzbontar commented Mar 22, 2021

@jzbontar,

Excuse me if I'm wrong but I thought that torch.distributed.all_reduce cut gradients. However you use it here to sum the cross-correlation matrices. Wouldn't that cause issues in the backward pass?

Hm, good point. I checked whether torch.distributed.all_reduce cuts the gradients when I was writing code for Barlow Twins and I found that it doesn't and that the gradients were computed correctly. For example, when I run the following code on a machine with 2 GPUs

import torch

def main_worker(gpu):
    torch.distributed.init_process_group(backend='nccl', init_method='tcp://localhost:12345', world_size=2, rank=gpu)
    torch.cuda.set_device(gpu)
    x = torch.ones(1).cuda().requires_grad_()
    xx = (gpu + 2) * x
    torch.distributed.all_reduce(xx)
    xx.backward()
    print(f'gpu={gpu}, grad={x.grad}')

if __name__ == '__main__':
    torch.multiprocessing.spawn(main_worker, nprocs=2)

it outputs

gpu=1, grad=tensor([3.], device='cuda:1')
gpu=0, grad=tensor([2.], device='cuda:0')

which seems okay to me.

The PyTorch documentation, however, suggests that the autograd-enabled communication primitives reside in torch.distributed.nn, so I probably should have used (and we should use in the VISSL implementation) torch.distributed.nn.all_reduce instead of torch.distributed.all_reduce.

Good catch, @odelalleau!

EDIT: whoops! I meant @OlivierDehaene. Sorry.

@odelalleau
Copy link

Good catch, @odelalleau!

Giving credit where it's due: good catch @OlivierDehaene! ;)

@prigoyal
Copy link
Contributor

prigoyal commented Mar 23, 2021

@OlivierDehaene , thank you for your continued awesome work on this PR :) Just wanted to chime in for an early recommendation for a next step (EDIT: seems like you have planned this already?): in order for us to ensure reproducibility , we should aim for an entry with Barlow Twins in the VISSL Model zoo. This entry should have a pre-trained model on ImageNet (cc @jzbontar) and we should list the evaluation accuracy of this model in the table.

For training on ImageNet, @jzbontar mentioned he might be able to help by running your code so it would be good to coordinate a plan for that. For the linear evaluation on ImageNet, please do reach out if you run into any issues or have any questions :)

Copy link
Contributor

@prigoyal prigoyal left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looking really good. Small inline comments and thank you @OlivierDehaene for your hard work on this :)

docs/source/ssl_approaches/barlow_twins.rst Outdated Show resolved Hide resolved
docs/source/ssl_approaches/barlow_twins.rst Outdated Show resolved Hide resolved
tests/test_losses.py Show resolved Hide resolved
tests/test_losses.py Show resolved Hide resolved
tests/test_losses.py Outdated Show resolved Hide resolved
vissl/losses/barlow_twins_loss.py Show resolved Hide resolved

class SyncNormalizeFunction(Function):
"""
Adapted from: https://github.com/NVIDIA/apex/blob/master/apex/parallel/sync_batchnorm.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious question: why can't we use the function directly from NVIDIA? if we want to not introduce dependency on apex and that's the reason, that's valid and in that case we could introduce vissl/utils/apex_helpers.py and put this function there. This will allow usage beyond this loss function. :)

Copy link
Contributor Author

@OlivierDehaene OlivierDehaene Mar 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me there are several advantages to this:

  • I think it is important that the criterion is not dependant on train() and eval() cycles. Using a BatchNorm layer would add hidden running statistics that could lead to weird behaviour. It is possible to disable said statistics using track_running_stats =False, however this argument is bugged in the Apex implementation of SyncBatchNorm. See my issue here: [BUG] Apex otimized SyncBatchnormFunction crashes with track_running_stats=False NVIDIA/apex#1071.
  • We do not need to add additional logic to switch between PyTorch SyncBatchNorm and the faster Apex SyncBatchNorm depending on if Apex was installed by the user.
  • One could argue that it makes it clearer to the end user what is actually happening here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could introduce vissl/utils/apex_helpers.py and put this function there. This will allow usage beyond this loss function. :)

100 % agree. I wasn't sure where to add it. It also needs to be tested.

vissl/losses/barlow_twins_loss.py Show resolved Hide resolved
vissl/losses/barlow_twins_loss.py Outdated Show resolved Hide resolved
@OlivierDehaene
Copy link
Contributor Author

OlivierDehaene commented Mar 23, 2021

@jzbontar, @QuentinDuval,

I did a rerun of the integration test on Imagenette 160 to check if the new version of the loss was working. The rerun achieved 88.8% Top1 accuracy.
I think we are getting close. :)

@OlivierDehaene
Copy link
Contributor Author

OlivierDehaene commented Mar 23, 2021

@jzbontar,

I saw your post on facebookresearch/barlowtwins#8 and I think there is still a weird interaction between torch.all_reduce and gradients.

For example, the following code doesn't yield the same gradients when using torch.all_reduce and GatherLayer, a specific gather op that doesn't cut gradients:

import torch

from vissl.utils.distributed_gradients import GatherLayer

n, m = 8, 4

def main_worker(i):
    torch.distributed.init_process_group(backend='gloo', init_method='tcp://localhost:12345', world_size=4, rank=i)
    x = torch.full((n, m), i+1, dtype=torch.float32).requires_grad_()

    # compute xTx using all_reduce
    xTx = x.T @ x
    torch.distributed.all_reduce(xTx)
    xTx.sum().backward()
    if i == 0:
        print(x.grad)

# tensor([[8., 8., 8., 8.],
#         [8., 8., 8., 8.],
#         [8., 8., 8., 8.],
#         [8., 8., 8., 8.],
#         [8., 8., 8., 8.],
#         [8., 8., 8., 8.],
#         [8., 8., 8., 8.],
#         [8., 8., 8., 8.]])

    x.grad.zero_()

    # compute xTx using GatherLayer
    z = torch.cat(GatherLayer.apply(x))
    zTz = z.T @ z
    zTz.sum().backward()
    if i == 0:
        print(x.grad)

# tensor([[32., 32., 32., 32.],
#         [32., 32., 32., 32.],
#         [32., 32., 32., 32.],
#         [32., 32., 32., 32.],
#         [32., 32., 32., 32.],
#         [32., 32., 32., 32.],
#         [32., 32., 32., 32.],
#         [32., 32., 32., 32.]])

# xTx == zTz
# tensor([[240., 240., 240., 240.],
#         [240., 240., 240., 240.],
#         [240., 240., 240., 240.],
#         [240., 240., 240., 240.]], grad_fn=<MmBackward>)


if __name__ == '__main__':
    torch.multiprocessing.spawn(main_worker, nprocs=4)

Could this also be related to your scaling issue?

EDIT: it may actually be an error in GatherLayer.
@prigoyal, @QuentinDuval, shouldn't there be a rescaling by 1/world_size here? Since the default reduce op is SUM. https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py#L27

@jzbontar
Copy link

jzbontar commented Mar 24, 2021

EDIT: it may actually be an error in GatherLayer. @prigoyal, @QuentinDuval, shouldn't there be a rescaling by 1/world_size here?

Yes, right? I think so too. The following example shows that the torch.distributed.all_reduce solution computes gradients that are the same as the baseline one-process solution. Whereas the loss in the torch.distributed.nn.all_reduce solution needs to be scaled by 1 / world_size.

import torch
import torch.distributed.nn

# compute gradient of xTy**2
# the gradient wrt x is 2 * xTy * y
# the gradient wrt y is 2 * xTy * x

def main_worker(rank):
    torch.distributed.init_process_group(backend='gloo', init_method='tcp://localhost:12345', world_size=4, rank=rank)

    # torch.distributed.all_reduce solution
    x = torch.randn(4).requires_grad_()
    y = torch.randn(4).requires_grad_()
    xTy = torch.dot(x, y)
    torch.distributed.all_reduce(xTy)
    loss = xTy.pow(2)
    loss.backward()
    assert torch.equal(x.grad, 2 * xTy * y) # this is correct and matches the baseline
    assert torch.equal(y.grad, 2 * xTy * x)

    # torch.distributed.nn.all_reduce solution
    x = torch.randn(4).requires_grad_()
    y = torch.randn(4).requires_grad_()
    xTy_1gpu = torch.dot(x, y)
    xTy = torch.distributed.nn.all_reduce(xTy_1gpu)
    loss = xTy.pow(2)
    loss /= 4  # have to divide loss by world_size to match the baseline solution
    loss.backward()
    assert torch.equal(x.grad, 2 * xTy * y)
    assert torch.equal(y.grad, 2 * xTy * x)

if __name__ == '__main__':
    # single process baseline
    x = torch.randn(16).requires_grad_()
    y = torch.randn(16).requires_grad_()
    xTy = torch.dot(x, y)
    loss = xTy.pow(2)
    loss.backward()
    assert torch.equal(x.grad, 2 * xTy * y)
    assert torch.equal(y.grad, 2 * xTy * x)

    torch.multiprocessing.spawn(main_worker, nprocs=4)

Maybe we should really be using reduce instead of all_reduce?

@prigoyal
Copy link
Contributor

@prigoyal, @QuentinDuval, shouldn't there be a rescaling by 1/world_size here? Since the default reduce op is SUM. https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py#L27

@OlivierDehaene , in VISSL, https://github.com/facebookresearch/vissl/blob/master/vissl/trainer/train_steps/standard_train_step.py#L158 is the place where we call all_reduce_mean from classy vision which divides by the world size https://github.com/facebookresearch/ClassyVision/blob/master/classy_vision/generic/distributed_util.py#L69-L73

@facebook-github-bot
Copy link
Contributor

@OlivierDehaene has updated the pull request. You must reimport the pull request before landing.

@OlivierDehaene
Copy link
Contributor Author

@prigoyal,

I forgot I was on a fork and rebased on my fork instead of this one... Should be good now!

@facebook-github-bot
Copy link
Contributor

@prigoyal has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@OlivierDehaene
Copy link
Contributor Author

OlivierDehaene commented Apr 30, 2021

@prigoyal,

I can't see the details of the failing tests so IDK what to patch.

@prigoyal
Copy link
Contributor

@prigoyal,

I can't see the details of the failing tests so IDK what to patch.

Hi @OlivierDehaene , I will comment on relevant places in the PR shortly :)

Copy link
Contributor

@prigoyal prigoyal left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left some last remaining inline comments

Additionally, we need to run linter on the PR. Follow the instructions https://github.com/facebookresearch/vissl/blob/master/dev/README.md#practices-for-coding-quality to run the linter provided by VISSL and update the PR :)

@facebook-github-bot
Copy link
Contributor

@OlivierDehaene has updated the pull request. You must reimport the pull request before landing.

@facebook-github-bot
Copy link
Contributor

@OlivierDehaene has updated the pull request. You must reimport the pull request before landing.

@OlivierDehaene
Copy link
Contributor Author

@prigoyal,

Some lints fail on file that were not touched by this PR. I left them as is.

@prigoyal
Copy link
Contributor

@prigoyal,

Some lints fail on file that were not touched by this PR. I left them as is.

@OlivierDehaene , that sounds great. I'll work on linting the other files :)

@OlivierDehaene
Copy link
Contributor Author

Sorry I misunderstood what you meant with TRUNK and TRUNK_PARAMS. It should be good now!

@facebook-github-bot
Copy link
Contributor

@prigoyal has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Apr 30, 2021
Summary:
## Required (TBC)

- [X] BarlowTwinsLoss and Criterion
- [x] Documentation
  - [X] Loss
  - [x] SSL Approaches + Index
  - [x] Model Zoo
  - [x] Project
- [x] Default configs
    - [x] pretrain
    - [X] test/integration
    - [X] debugging/pretrain
- [x] Benchmarks
  - [x] ImageNet: 70.75 for 300 epochs
  - [x] Imagenette 160: 88.8 Top1 accuracy

closes #229

Pull Request resolved: #230

Reviewed By: iseessel

Differential Revision: D28118605

Pulled By: prigoyal

fbshipit-source-id: 4436d6fd9d115b80ef5c5396318caa3cb26faadb
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement Barlow Twins
7 participants