Fix DCCRN and DCUNet-Large #276

Merged: 24 commits from fix-dccrn into master, Nov 24, 2020
Conversation

jonashaag (Collaborator)

This is the real DCCRN architecture. I got some of the details wrong in the initial version.

  • Are we simply going to release the fixed version without any backwards compatibility? I think that's reasonable, since we have no pretrained models so far and it's only been a few days since the release.
  • The code for the concatenative skip connection with the intermediate layer is a draft. I'm not sure how to do it in a clean way just yet. The way the intermediate layer works in this new version requires that the U-Net parameter calculation be aware of its existence, so we can't really keep them entirely separate anymore. Maybe we have to add the concept of this concatenative intermediate layer to the U-Net parameter calculation code. If you have any ideas, that would be helpful as well!

jonashaag (Collaborator, Author)

@kssk16 Note that our LSTM is of size 640 rather than 512 because we're not stripping the Nyquist frequency, so we have 257 rather than 256 frequency bins.

Maybe we should change this as well so the implementation matches the paper.
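
For reference, a quick check of those numbers (a sketch only; the 128-channel figure at the deepest layer is inferred from the encoder shapes printed later in this thread, not taken from the paper):

    n_fft = 512
    print(n_fft // 2 + 1)  # 257 frequency bins when the Nyquist bin is kept
    print(n_fft // 2)      # 256 when it is stripped, as in the paper

    # Assuming each of the six stride-2 encoders roughly halves (rounding up)
    # the frequency dim: 257 -> 129 -> 65 -> 33 -> 17 -> 9 -> 5 bins,
    # vs. 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 without the Nyquist bin.
    print(5 * 128)  # 640 -- our LSTM input size
    print(4 * 128)  # 512 -- the paper's LSTM input size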

kssk16 commented Oct 27, 2020

> @kssk16 Note that our LSTM is of size 640 rather than 512 because we're not stripping the Nyquist frequency, so we have 257 rather than 256 frequency bins.
>
> Maybe we should change this as well so the implementation matches the paper.

@jonashaag Thanks for letting me know.

jonashaag (Collaborator, Author) commented Oct 27, 2020

More differences:

  • They use PReLU, not leaky ReLU.
  • They use (2, 1) padding for all encoders and (2, 0) padding for all decoders.
  • In the output layer they use output_padding=(1, 0).
  • They use a window of size 400, pad it (centered) to 512, and then use a 512-point FFT. We don't do the 400 thing (see the sketch below).
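
For concreteness, here is roughly what that setup looks like with torch.stft, which zero-pads a shorter window (centered) up to n_fft by itself; the hop length and window type here are assumptions, not taken from this PR:

    import torch

    x = torch.randn(1, 16000)  # one second of 16 kHz audio (illustrative)

    # 400-sample (25 ms) window, centered inside a 512-point FFT frame.
    spec = torch.stft(
        x,
        n_fft=512,
        win_length=400,     # torch.stft center-pads the window to n_fft
        hop_length=100,     # assumed hop, not specified in this thread
        window=torch.hann_window(400),
        return_complex=True,
    )
    print(spec.shape)  # torch.Size([1, 257, 161]) -- 257 = 512 // 2 + 1 bins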

mpariente mentioned this pull request Nov 2, 2020
jonashaag (Collaborator, Author)

OK, this should be ready to review.

The DCCRN implementation isn't 100% identical to the paper's, but it is very similar. Differences I know of: using a 512-point FFT instead of 400 + 2 * 56 zero padding; slightly different padding in the T dimension.

While fixing the DCCRN implementation, I realized that you don't need the intermediate_layer concept, because you can simply replace it with an "asymmetric" U-Net structure where the deepest layer's "partner" is the identity (a toy sketch follows below). In fact, the intermediate_layer concept doesn't actually fit the DCCRN architecture, so I removed it.
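
To make the "identity partner" idea concrete, here is a toy sketch (illustrative names and shapes, not the actual Asteroid code): the deepest encoder gets nn.Identity() as its decoder partner, so no separate intermediate_layer is needed.

    import torch
    from torch import nn

    class TinyAsymmetricUNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoders = nn.ModuleList([
                nn.Conv1d(1, 8, 3, stride=2, padding=1),
                nn.Conv1d(8, 16, 3, stride=2, padding=1),
                # Deepest layer, playing the role of the old intermediate_layer:
                nn.Conv1d(16, 32, 3, stride=1, padding=1),
            ])
            self.decoders = nn.ModuleList([
                nn.Identity(),  # the deepest layer's "partner"
                # 32 + 16 input channels: identity output concatenated with enc_out.
                nn.ConvTranspose1d(32 + 16, 8, 3, stride=2, padding=1, output_padding=1),
            ])
            self.output_layer = nn.ConvTranspose1d(8 + 8, 1, 3, stride=2, padding=1, output_padding=1)

        def forward(self, x):
            enc_outs = []
            for enc in self.encoders:
                x = enc(x)
                enc_outs.append(x)
            for enc_out, dec in zip(reversed(enc_outs[:-1]), self.decoders):
                x = torch.cat([dec(x), enc_out], dim=1)  # concatenative skip
            return self.output_layer(x)

    print(TinyAsymmetricUNet()(torch.randn(1, 1, 16)).shape)  # torch.Size([1, 1, 16])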

I also included a fix for the DCUNet architectures, which I realized were wrong.

In general, I'd appreciate it if someone could compare my implementations to the papers, because I'm quite unsure whether there are more differences that I didn't spot. Code can be found here: https://paperswithcode.com/paper/dccrn-deep-complex-convolution-recurrent-1 and https://paperswithcode.com/paper/phase-aware-speech-enhancement-with-deep-1

Commit fc0232c is sort of unrelated, we can exclude it.

@jonashaag jonashaag changed the title [Draft] Fix DCCRN Fix DCCRN Nov 19, 2020
@jonashaag jonashaag changed the title Fix DCCRN Fix DCCRN and DCUNet Nov 19, 2020
@jonashaag jonashaag changed the title Fix DCCRN and DCUNet Fix DCCRN and DCUNet-Large Nov 19, 2020
@jonashaag jonashaag marked this pull request as ready for review November 19, 2020 10:20
mpariente (Collaborator)

Maybe @kssk16 @huyanxin @sweetcocoa and @chanil1218 would like to have a look? That would be nice!

mpariente (Collaborator) left a comment

This looks great IMO !
I didn't check the implementation details and the correspondence with the original code very thoroughly though.

Review threads (resolved): asteroid/complex_nn.py, asteroid/models/dccrnet.py, asteroid/models/dcunet.py, tests/jit/jit_models_test.py
@@ -178,7 +194,7 @@ def get_default_device():
"n_filters": 32,
"kernel_size": 21,
}
-model = model_def(**params, fb_name="free")
+model = model_def(**params)
Collaborator

Oh, I'll need to remove that after this is merged..

More review threads on tests/jit/jit_models_test.py (resolved).
jonashaag and others added 2 commits November 20, 2020 14:07
Co-authored-by: Pariente Manuel <pariente.mnl@gmail.com>
Error was RuntimeError: Tracer cannot set value trace for type Bool. Supported types are tensor, tensor list, and tuple of tensors.
mpariente (Collaborator)

If I print shapes in BaseUNet's forward in this way

    def forward(self, x):
        enc_outs = []
        for idx, enc in enumerate(self.encoders):
            x = enc(x)
            enc_outs.append(x)  # saved for the concatenative skip connections
            print(f"Encoding {idx}", x.shape)
        for idx, (enc_out, dec) in enumerate(zip(reversed(enc_outs[:-1]), self.decoders)):
            x = dec(x)
            print(f"Decoding {idx}", x.shape)
            x = torch.cat([x, enc_out], dim=1)  # concatenative skip connection
        return self.output_layer(x)

I get

print(tf_rep.shape) -> torch.Size([1, 256, 36])
Encoding 0 torch.Size([1, 16, 128, 35])
Encoding 1 torch.Size([1, 32, 64, 34])
Encoding 2 torch.Size([1, 64, 32, 33])
Encoding 3 torch.Size([1, 128, 16, 32])
Encoding 4 torch.Size([1, 128, 8, 31])
Encoding 5 torch.Size([1, 128, 4, 30])
Encoding 6 torch.Size([1, 128, 4, 30])
Decoding 0 torch.Size([1, 128, 4, 30])
Decoding 1 torch.Size([1, 128, 8, 31])
Decoding 2 torch.Size([1, 128, 16, 32])
Decoding 3 torch.Size([1, 64, 32, 33])
Decoding 4 torch.Size([1, 32, 64, 34])
Decoding 5 torch.Size([1, 16, 128, 35])

where the time dimension is not divided by two at every iteration but just loses one. Is this intended? At least it doesn't seem consistent with the docstring, or with your comments suggesting that the time dim has to be divisible by the product of the strides. Are the strides applied to both the freq and time dimensions?

By the way, for DCCRNet, the jit test passes even with different sizes, which is great news !

jonashaag (Collaborator, Author)

Could it be that you confused DCUNet and DCCRNet here?

$ python -c 'import asteroid.models as m, torch; m.DCUNet("DCUNet-10")(torch.randn(1,50000))'
Encoding 0 torch.Size([1, 32, 129, 97])
Encoding 1 torch.Size([1, 64, 65, 49])
Encoding 2 torch.Size([1, 64, 33, 25])
Encoding 3 torch.Size([1, 64, 17, 13])
Encoding 4 torch.Size([1, 64, 9, 13])
Decoding 0 torch.Size([1, 64, 17, 13])
Decoding 1 torch.Size([1, 64, 33, 25])
Decoding 2 torch.Size([1, 64, 65, 49])
Decoding 3 torch.Size([1, 32, 129, 97])
...

The paddings are different in DCUNet. In DCCRNet we use the "official" padding (or something very similar), that is, the encoders each lose one frame in the T dim and the decoders add it back using output padding.

In DCUNet, there is no official implementation available, so I used what's most common among other implementations, which seems to be padding so that each encoder roughly maps T -> (T+1)/2, for example 17 -> 9 -> 5 -> 3 -> 2 -> 2 (see the sketch below). We could also use some other padding; I'm not an expert in U-Nets (not an expert in anything in deep learning, really 😅).
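
Spelled out, one stride-2 padding choice that reproduces that chain is T -> T // 2 + 1 (a back-of-the-envelope sketch; the actual DCUNet paddings may differ per layer):

    def enc_time_dim(t: int) -> int:
        # Roughly T -> (T + 1) / 2: e.g. stride 2 with 2 * padding == kernel_size
        # gives floor(T / 2) + 1.
        return t // 2 + 1

    t, chain = 17, [17]
    for _ in range(5):
        t = enc_time_dim(t)
        chain.append(t)
    print(chain)  # [17, 9, 5, 3, 2, 2]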

mpariente (Collaborator)

It is completely possible, yes !
I'm not an expert in either of these architectures and haven't read the papers in detail recently.

Given that DCCRNet is much more flexible, I guess it makes sense to isolate the tests of DCUNet, right?
I'll do that

mpariente (Collaborator)

IMO, this should be good.

I found the correct durations for DCUNet by hand as well, so we test 3 shapes and the traced model is consistent, which is great !

A few things:

  • Having to guess the input shape is very bad, and the error message is only half helpful because it is in the frequency domain, so it doesn't directly help with getting the right shape in the time domain. How can we improve that? Do we want to improve that?
  • Those tests are pretty slow, so it would be nice to have "mini" architectures.


Extending DCCRNet to separation shouldn't be too hard: we just need to overwrite the second entry of the last line in the architecture file that is passed to the complex conv, right? So we could easily add an n_src arg to it and enable separation. Am I missing something?


jonashaag (Collaborator, Author)

Ok, done! I also changed the args specifications to be all tuples, because I accidentally modified some of the lists (they are mutable), which is a terrible thing to debug (the failure mode is sketched below).
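
The failure mode in miniature (hypothetical names, not the actual Asteroid spec): a shared mutable list gets modified in place by one instantiation and silently corrupts every later one, whereas tuples turn the same mistake into an immediate TypeError.

    DCUNET_SPEC = [[1, 32], [32, 64]]  # shared, mutable architecture spec

    def build(spec, n_src):
        spec[-1][1] *= n_src  # oops: mutates the shared spec in place
        return spec

    build(DCUNET_SPEC, n_src=2)
    print(DCUNET_SPEC)  # [[1, 32], [32, 128]] -- the "constant" has changed

    SAFE_SPEC = ((1, 32), (32, 64))
    # SAFE_SPEC[-1][1] *= 2  # would raise TypeError: tuples are immutable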

jonashaag (Collaborator, Author)

I created a PR against PyTorch for better version checks pytorch/pytorch#48414

Review threads (resolved): asteroid/masknn/base.py, asteroid/models/dccrnet.py, asteroid/utils/test_utils.py
jonashaag and others added 2 commits November 24, 2020 11:18
Co-authored-by: Pariente Manuel <pariente.mnl@gmail.com>
mpariente (Collaborator)

I'll review the whole PR again today but I think it's ready to be merged !

jonashaag (Collaborator, Author)

> I'll review the whole PR again today

Thank you! Much appreciated. Somehow this PR got MUCH bigger than expected :-D

jonashaag (Collaborator, Author)

Found a missing shape check for DCCRN, added tests for the shape checks, and also fixed the PyTorch 1.8 issue (it only breaks if you do foo[:x] where x == len(foo), i.e. a no-op slice; a guarded version is sketched below).
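
A guarded version of that pattern might look like this (a hypothetical helper sketching the idea, not the actual fix):

    import torch

    def trim_to(tensor: torch.Tensor, length: int) -> torch.Tensor:
        # Only slice when it actually shortens the tensor, so a traced model
        # never executes the no-op slice foo[:x] with x == len(foo).
        if tensor.shape[-1] > length:
            tensor = tensor[..., :length]
        return tensor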

Comment on lines 151 to 161
# DCUMaskNet should fail with wrong frequency dimensions.
DCUNet("mini").masker(torch.zeros((1, 9, 17), dtype=torch.complex64))
with pytest.raises(TypeError):
    DCUNet("mini").masker(torch.zeros((1, 42, 17), dtype=torch.complex64))

# DCUMaskNet should fail with wrong time dimensions if fix_length_mode is not used.
DCUNet("mini", fix_length_mode="pad").masker(torch.zeros((1, 9, 17), dtype=torch.complex64))
DCUNet("mini", fix_length_mode="trim").masker(torch.zeros((1, 9, 17), dtype=torch.complex64))
with pytest.raises(TypeError):
    DCUNet("mini").masker(torch.zeros((1, 9, 16), dtype=torch.complex64))

Collaborator

Cool !

mpariente (Collaborator)

So the trimming and padding can be done with a single function; trim_x_to_y and pad_or_trim_x_to_y are not needed, it seems.

mpariente (Collaborator)

It's quite dangerous to push to the same branch at the same time (my bad 😅), but we did well ^^

jonashaag (Collaborator, Author)

> So the trimming and padding can be done with a single function

Indeed it seems so! https://github.com/pytorch/pytorch/blob/367426494759ddde0896665ed55c6f9af2870cf0/aten/src/ATen/native/ConstantPadNd.cpp#L25
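
For illustration: F.pad accepts negative padding amounts and trims instead, so a single call covers both directions.

    import torch
    import torch.nn.functional as F

    x = torch.arange(10.0)
    print(F.pad(x, (0, 3)).shape)   # torch.Size([13]) -- right-pad by 3
    print(F.pad(x, (0, -3)).shape)  # torch.Size([7])  -- negative pad right-trims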

jonashaag (Collaborator, Author)

So we could remove the trim functions and change the pad docstring to say "right-pad or right-trim".

Btw I won't be pushing any code for the next few hours, so feel free to push :P

mpariente (Collaborator)

I'll do that.

mpariente (Collaborator)

I re-reviewed everything and it LGTM.
Waiting for your approval to merge it.

jonashaag (Collaborator, Author)

Also did a quick review, LGTM, go ahead :)

@mpariente mpariente merged commit cbd0bd0 into master Nov 24, 2020
@mpariente mpariente deleted the fix-dccrn branch November 24, 2020 13:59