Fix DCCRN and DCUNet-Large #276
Conversation
@kssk16 Note that our LSTM is of size 640 rather than 512 because we're not stripping the Nyquist frequency, so we have 257 rather than 256 frequency bins. Maybe we should change this as well so the implementation matches the paper.
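For illustration, a minimal sketch of where the extra bin comes from (the hop length here is just an example value, not the model's setting):

```python
import torch

# n_fft = 512 gives n_fft // 2 + 1 = 257 frequency bins; stripping the
# last (Nyquist) bin would leave the 256 the paper assumes.
x = torch.randn(1, 16000)
spec = torch.stft(x, n_fft=512, hop_length=100, return_complex=True)
print(spec.shape)          # torch.Size([1, 257, 161])
print(spec[:, :-1].shape)  # torch.Size([1, 256, 161])
```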
@jonashaag Thanks for letting me know.
More differences:
OK, this should be ready to review. The DCCRN implementation isn't 100% identical to the paper's, but it is very similar. Differences I know of: using a 512 FFT instead of 400 + 2 * 56 zero padding; slightly different padding in the T dimension. While fixing the DCCRN implementation I realized that you don't need the …

I also included a fix to the DCUNet architectures, which I realized were wrong. In general, I'd appreciate it if someone could compare my implementations to the papers, because I'm quite unsure whether there are more differences that I didn't spot. Code can be found here:

https://paperswithcode.com/paper/dccrn-deep-complex-convolution-recurrent-1
https://paperswithcode.com/paper/phase-aware-speech-enhancement-with-deep-1

Commit fc0232c is sort of unrelated, we can exclude it.
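To make the FFT difference concrete, here is a rough sketch, assuming the paper's "400 + 2 * 56 zero padding" means a 400-sample window zero-padded by 56 samples on each side to a 512-point FFT (the hop length is illustrative):

```python
import torch

x = torch.randn(1, 16000)

# Paper-style: 400-sample window, zero-padded to a 512-point FFT.
# torch.stft pads the window with (512 - 400) / 2 = 56 zeros per side
# when win_length < n_fft.
paper_spec = torch.stft(x, n_fft=512, win_length=400, hop_length=100,
                        window=torch.hann_window(400), return_complex=True)

# This implementation: a full 512-sample window instead.
ours_spec = torch.stft(x, n_fft=512, win_length=512, hop_length=100,
                       window=torch.hann_window(512), return_complex=True)
print(paper_spec.shape, ours_spec.shape)  # same shapes, different analysis windows
```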
Maybe @kssk16, @huyanxin, @sweetcocoa, and @chanil1218 would like to have a look? That would be nice!
This looks great IMO!
I didn't check the implementation details and the correspondence with the original code very thoroughly, though.
tests/jit/jit_models_test.py
```diff
@@ -178,7 +194,7 @@ def get_default_device():
     "n_filters": 32,
     "kernel_size": 21,
 }
-model = model_def(**params, fb_name="free")
+model = model_def(**params)
```
Oh, I'll need to remove that after this is merged.
Co-authored-by: Pariente Manuel <pariente.mnl@gmail.com>
The error was `RuntimeError: Tracer cannot set value trace for type Bool. Supported types are tensor, tensor list, and tuple of tensors.`
If I print in BaseUNet's forward in this way:

```python
def forward(self, x):
    enc_outs = []
    for idx, enc in enumerate(self.encoders):
        x = enc(x)
        enc_outs.append(x)
        print(f"Encoding {idx}", x.shape)
    for idx, (enc_out, dec) in enumerate(zip(reversed(enc_outs[:-1]), self.decoders)):
        x = dec(x)
        print(f"Decoding {idx}", x.shape)
        x = torch.cat([x, enc_out], dim=1)
    return self.output_layer(x)
```

I get output where the time dimension is not divided by two at every iteration, but just loses one. Is this intended? At least, it doesn't seem consistent with the docstring, nor with your comments suggesting that the time dim has to be divisible by the product of the strides. Are the strides on both the freq and time dimensions? By the way, for DCCRNet, the jit test passes even with different sizes, which is great news!
Could it be that you confused DCUNet and DCCRNet here?
The paddings are different in DCUNet. In DCCRNet we use the "official" padding (or something very similar): the encoders each lose 1 frame in the T dim and the decoders add it back using output padding. In DCUNet, there is no official implementation available, so I used what's most common among other implementations, which seems to be padding so that for each encoder, …
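To make the two conventions concrete, here's a rough sketch (the kernel/stride numbers are illustrative, not the exact model hyperparameters):

```python
import torch
from torch import nn

x = torch.randn(1, 2, 256, 100)  # (batch, channels, freq, time)

# DCCRN-style: no padding along T, so each encoder loses exactly 1 frame,
# and the matching decoder adds it back with output_padding.
dccrn_enc = nn.Conv2d(2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 0))
print(dccrn_enc(x).shape)   # time: 100 -> 99

# DCUNet-style: pad T so each encoder divides it by its stride, which is why
# the input length must be divisible by the product of the strides.
dcunet_enc = nn.Conv2d(2, 32, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))
print(dcunet_enc(x).shape)  # time: 100 -> 50
```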
It is completely possible, yes! Given that DCCRNet is much more flexible, I guess it makes sense to isolate the tests of DCUNet, right?
IMO, this should be good. I found the correct durations for DCUNet by hand as well, so we test 3 shapes and the traced model is consistent, which is great! A few things: …
Extending DCCRNet to separation shouldn't be too hard; we just need to overwrite the second entry of the last line in the architecture file that is passed to the complex conv, right? So we could easily add a …
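Hypothetically, something along these lines (the helper name and the tuple layout are assumptions for illustration, not actual asteroid code):

```python
# Hypothetical sketch: scale the second entry (output channels) of the last
# architecture line so the masker emits one complex mask per source.
def with_n_src(architecture, n_src):
    *body, (in_ch, out_ch, *rest) = architecture
    return (*body, (in_ch, out_ch * n_src, *rest))

# e.g. with_n_src(arch, n_src=2) doubles the last entry's output channels.
```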
…rough the intended usage: the main architectures
Ok, done! I also changed the args specifications to all be tuples, because I accidentally modified some of the lists (they are mutable), which is a terrible thing to debug.
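The failure mode being avoided, in miniature (the names are made up for the demo):

```python
# A shared list spec: any in-place edit leaks into every later use.
SHARED_SPEC = [[1, 32], [32, 64]]

def tweak(spec=SHARED_SPEC):
    spec[0][0] = 2          # silently corrupts the module-level default
    return spec

tweak()
print(SHARED_SPEC)          # [[2, 32], [32, 64]] -- changed behind your back

# With nested tuples, the same mistake fails loudly at the mutation site:
SAFE_SPEC = ((1, 32), (32, 64))
# SAFE_SPEC[0] = (2, 32)    # TypeError: 'tuple' object does not support item assignment
```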
I created a PR against PyTorch for better version checks: pytorch/pytorch#48414
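For the record, the underlying pitfall, shown with the third-party packaging module as one common workaround:

```python
import torch
from packaging import version

# Lexicographic string comparison gets multi-digit versions wrong:
print("1.10.0" >= "1.8.0")                                # False (!)

# Parsed versions compare numerically and tolerate tags like "+cu102":
print(version.parse("1.10.0") >= version.parse("1.8.0"))  # True
print(version.parse(torch.__version__) >= version.parse("1.8.0"))
```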
Co-authored-by: Pariente Manuel <pariente.mnl@gmail.com>
I'll review the whole PR again today, but I think it's ready to be merged!
Thank you! Much appreciated. Somehow this PR got MUCH bigger than expected :-D
This reverts commit 4f0a2c6.
Found a missing shape check for DCCRN, added tests for the shape checks, and also fixed the PyTorch 1.8 issue (it only breaks if you do …).
tests/models/models_test.py
```python
# DCUMaskNet should fail with wrong frequency dimensions
DCUNet("mini").masker(torch.zeros((1, 9, 17), dtype=torch.complex64))
with pytest.raises(TypeError):
    DCUNet("mini").masker(torch.zeros((1, 42, 17), dtype=torch.complex64))

# DCUMaskNet should fail with wrong time dimensions if fix_length_mode is not used
DCUNet("mini", fix_length_mode="pad").masker(torch.zeros((1, 9, 17), dtype=torch.complex64))
DCUNet("mini", fix_length_mode="trim").masker(torch.zeros((1, 9, 17), dtype=torch.complex64))
with pytest.raises(TypeError):
    DCUNet("mini").masker(torch.zeros((1, 9, 16), dtype=torch.complex64))
```
Cool!
So the trimming and the padding can be done with a single function, …
It's quite dangerous to push to the same branch at the same time (my bad 😅), but we did well ^^
Indeed, it seems so! https://github.com/pytorch/pytorch/blob/367426494759ddde0896665ed55c6f9af2870cf0/aten/src/ATen/native/ConstantPadNd.cpp#L25
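A quick check of that behavior (negative pad amounts trim instead of pad):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 10)
print(F.pad(x, (0, 3)).shape)   # torch.Size([1, 13]): right-pad
print(F.pad(x, (0, -3)).shape)  # torch.Size([1, 7]):  right-trim
```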
So we could remove the trim functions and change the pad docstring to say "right-pad or right-trim". Btw I won't be pushing any code for the next few hours, so feel free to push :P
I'll do that.
I re-reviewed everything and it LGTM.
Also did a quick review, LGTM, go ahead :)
This is the real DCCRN architecture. I got some of the details wrong in the initial version.