
[Unification] Generalize TransformerEncoder #240

Closed
wants to merge 9 commits

Conversation

@RdoubleA (Contributor) commented Aug 3, 2022

Stack from ghstack (oldest at bottom):

Differential Revision: D38506881

Summary

Add a general TransformerEncoder class that simply stacks n layers of our custom TransformerEncoderLayer. Repurpose FLAVATransformerOutput and use it as TransformerOutput for this class.
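Roughly, the encoder just clones one encoder layer n times and feeds each layer's output into the next. A minimal sketch (the constructor signature here is illustrative, not the final API):

```python
import copy

from torch import Tensor, nn


class TransformerEncoder(nn.Module):
    """Minimal sketch: stack n_layer copies of a single encoder layer."""

    def __init__(self, layer: nn.Module, n_layer: int):
        super().__init__()
        # Deep-copy the layer so each stacked layer gets its own parameters
        self.layer = nn.ModuleList(
            [copy.deepcopy(layer) for _ in range(n_layer)]
        )

    def forward(self, hidden_states: Tensor) -> Tensor:
        # Feed the output of each layer into the next
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states)
        return hidden_states
```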

Test plan

Newly added unit tests: pytest test/modules/layers/test_transformer.py -vv

===================================================== test session starts ======================================================
platform linux -- Python 3.9.12, pytest-7.1.1, pluggy-1.0.0 -- /fsx/users/rafiayub/conda/envs/torchmm/bin/python
cachedir: .pytest_cache
rootdir: /data/home/rafiayub/torchmultimodal, configfile: pyproject.toml
plugins: hydra-core-1.1.2, cov-3.0.0, mock-3.8.2
collected 5 items                                                                                                              

test/modules/layers/test_transformer.py::TestTransformerEncoderLayer::test_forward_prenorm PASSED                        [ 20%]
test/modules/layers/test_transformer.py::TestTransformerEncoderLayer::test_forward_postnorm PASSED                       [ 40%]
test/modules/layers/test_transformer.py::TestTransformerCrossAttentionLayer::test_forward_prenorm PASSED                 [ 60%]
test/modules/layers/test_transformer.py::TestTransformerCrossAttentionLayer::test_forward_postnorm PASSED                [ 80%]
test/modules/layers/test_transformer.py::TestTransformerEncoder::test_forward PASSED                                     [100%]

====================================================== 5 passed in 1.59s =======================================================

@facebook-github-bot added the CLA Signed label Aug 3, 2022
@RdoubleA marked this pull request as draft August 3, 2022 16:14
RdoubleA added a commit that referenced this pull request Aug 3, 2022
ghstack-source-id: 6971e2c9faff1679bb6e19c93bc12492d7f8f59d
Pull Request resolved: #240
RdoubleA added a commit that referenced this pull request Aug 3, 2022
ghstack-source-id: 5372493620f0413999840da896ab072b38e571ed
Pull Request resolved: #240
@RdoubleA marked this pull request as ready for review August 3, 2022 17:31
@codecov-commenter commented Aug 5, 2022

Codecov Report

❗ No coverage uploaded for pull request base (gh/RdoubleA/29/base@d8d9305).
The diff coverage is n/a.

@@                  Coverage Diff                   @@
##             gh/RdoubleA/29/base     #240   +/-   ##
======================================================
  Coverage                       ?   92.16%           
======================================================
  Files                          ?       50           
  Lines                          ?     3077           
  Branches                       ?        0           
======================================================
  Hits                           ?     2836           
  Misses                         ?      241           
  Partials                       ?        0           


from torchmultimodal.utils.common import get_clones


TransformerOutput = namedtuple(

Should we just uniformly subclass from NamedTuple for model outputs?
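For comparison, a typed NamedTuple subclass would look something like this (the field names below are illustrative, not the actual TransformerOutput fields):

```python
from typing import NamedTuple, Optional, Tuple

from torch import Tensor


class TransformerOutput(NamedTuple):
    # Illustrative fields only; the real class repurposes FLAVATransformerOutput
    last_hidden_state: Optional[Tensor] = None
    hidden_states: Optional[Tuple[Tensor, ...]] = None
    attentions: Optional[Tuple[Tensor, ...]] = None
```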

if return_hidden_states:
    all_hidden_states = all_hidden_states + (hidden_states,)

layer_head_mask = head_mask[i] if head_mask is not None else None

head_masks can be different across the layers?
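For context, indexing head_mask[i] implies the mask carries a leading per-layer dimension, along these lines (the shape is an assumption, not spelled out in the diff):

```python
import torch

num_layers, num_heads = 6, 8
# One row per layer, so each layer can mask a different set of heads
head_mask = torch.ones(num_layers, num_heads)
head_mask[2, 0] = 0.0  # e.g. disable head 0 in layer 2 only

# Inside the encoder loop, head_mask[i] selects the mask for layer i
layer_head_mask = head_mask[2]
```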

return_attn_weights=True,
)

hidden_states = layer_outputs[0]

If layer_outputs only contains the hidden_states, this will try to index the 0-th element instead of returning the full hidden_states, right?

        if return_attn_weights:
            return outputs, attn_weights
        else:
            return outputs

Consider adding a unit test for when these kwargs are False as well?
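A sketch of what such a test could look like (build_encoder and the output field names are placeholders, not part of this PR):

```python
import torch


def test_forward_no_optional_outputs():
    # Hypothetical test: with both flags off, the optional output fields
    # should stay empty (field names assume a TransformerOutput-style return).
    encoder = build_encoder()  # assumed test helper
    inputs = torch.randn(2, 3, 16)
    out = encoder(inputs, return_hidden_states=False, return_attn_weights=False)
    assert out.hidden_states is None
    assert out.attentions is None
```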

output = encoder(inputs, return_hidden_states=True, return_attn_weights=True)


Not sure if this has been addressed as the encoder layer doesn't always return two tensors, right?

layer_outputs, attn_weights = layer_module(

@RdoubleA (author):

return_attn_weights is set to True, so it is guaranteed to return two outputs

@langong347 (Contributor) commented Aug 9, 2022

Then we should probably ask TransformerEncoderLayer to always return two outputs anyways.

@RdoubleA (author):

hmm, maybe I'll modify to handle both single and double outputs
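One possible way to tolerate both shapes (a sketch only; unpack_layer_outputs is a hypothetical helper, not the committed change):

```python
from typing import Optional, Tuple, Union

from torch import Tensor


def unpack_layer_outputs(
    layer_outputs: Union[Tensor, Tuple[Tensor, Tensor]]
) -> Tuple[Tensor, Optional[Tensor]]:
    # Accept either a bare hidden-states tensor or a
    # (hidden_states, attn_weights) tuple from the encoder layer.
    if isinstance(layer_outputs, tuple):
        hidden_states, attn_weights = layer_outputs
    else:
        hidden_states, attn_weights = layer_outputs, None
    return hidden_states, attn_weights
```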

@@ -22,3 +23,13 @@ def forward(self, x: Tensor) -> Tensor:
self.eps,
)
return output.type_as(x)


def fp32layernorm(x: Tensor, layernorm: nn.Module) -> Tensor:

Why do we need the functional in addition to the Fp32LayerNorm class?
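For context, the functional form presumably just runs an existing layernorm module in fp32 and casts the result back, roughly like this (inferred from the Fp32LayerNorm forward above, not necessarily the committed body):

```python
from torch import Tensor, nn


def fp32layernorm(x: Tensor, layernorm: nn.Module) -> Tensor:
    # Run the given layernorm in fp32 for numerical stability, then cast
    # the result back to the input's original dtype.
    output = layernorm(x.float())
    return output.type_as(x)
```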

@RdoubleA commented Aug 8, 2022

@RdoubleA has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@RdoubleA commented Aug 9, 2022

@RdoubleA has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

for _ in range(n_layer)
]
)
self.num_layers = n_layer

not used?

@RdoubleA (author):

hmmm I think it was used in CLIPTextEncoder. I won't be able to get to unifying that component but maybe we should keep self.num_layers for when we do?

Reviewer (Contributor):

i think the caller can maintain the num layers. no need to add unused instance variables here

@RdoubleA (author):

aight I'll remove it

all_hidden_states: Tuple[Tensor, ...] = () if return_hidden_states else None
all_self_attentions: Tuple[Tensor, ...] = () if return_attn_weights else None

for i, layer_module in enumerate(self.layer):

you don't need enumeration


for i, layer_module in enumerate(self.layer):
    if return_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

hm will this add the original input hidden states too?

@RdoubleA (author):

this is straight from the original FLAVATransformerEncoder:

all_hidden_states.append(hidden_states)
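For reference, a toy illustration of what that ordering implies (not code from this PR): because the append happens before the layer call, the first collected entry is the original input.

```python
import torch
from torch import nn

# Stand-in layers purely for illustration
layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
hidden_states = torch.randn(1, 4)
original_input = hidden_states

all_hidden_states = ()
for layer_module in layers:
    # Append first, then run the layer
    all_hidden_states = all_hidden_states + (hidden_states,)
    hidden_states = layer_module(hidden_states)

assert all_hidden_states[0] is original_input
# The final layer's output lives only in hidden_states unless it is also
# appended after the loop.
```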

@RdoubleA has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@facebook-github-bot deleted the gh/RdoubleA/29/head branch August 15, 2022 14:16
Labels: CLA Signed
5 participants