Mixture Invariant Training #320
Conversation
added mix_it and mix_it_gen loss wrappers
add test for mix_it and mix_it_gen
Thank you very much for your pull request! I also have a very naive implementation of MixIT on my local machine, but you beat me to it ;)! I think we can merge the two together (especially for the mixing matrix generation part). Also @mpariente, how do you feel about having MixIT as an extension of the current PITLossWrapper?
Yes, sure, we can definitely merge them! I didn't focus too much on optimizing it and maybe you have a faster implementation. I tested it with the MUSDB18 dataset and it seems to work!
Thank you very much for the implementation, that's great! PIT and MixIT are essentially different training paradigms (though related), so I also think making a separate class would be a good idea. It seems that the first commit in this PR reverts recent changes made to pit_wrapper.py. Note to self: we start to have several extensions of PITLossWrapper.
Yes, I am sorry, I was working on an old version of the file and didn't realize it until the PR. I'll fix this.
updated to the last version of black compliance, doc rendering and others
This is much better now, thank you!
There are a few things to do before merging it:
- Re-add the tests that were deleted,
- Making a separate class (MixITLossWrapper?),
- Checking with @popcornell for feedback/examples/common code etc.,
- Factoring out common code between both methods.
Also, it would be amazing to have a notebook to show the mixing matrices (partitions, as you call them?).
Would you be ok to do that?
It would be amazing!
asteroid/losses/pit_wrapper.py
# Generate all the possible partitions
parts = all_combinations(range(N))
for p, partition in enumerate(parts):
I guess this could be factored out between the two methods?
Not really, because the two functions are different...
If you run this little example, the difference will be clearer I guess:

from itertools import combinations

# Example
n_mix = 2       # number of mixtures
n_est = 4       # number of estimated sources
n_src = 4 // 2  # sources assigned to each mixture

# Mixit: partitions with a fixed number of sources per mixture
def combs(lst, k, l):
    if l == 0:
        yield []
    else:
        for c in combinations(lst, k):
            rest = [x for x in lst if x not in c]
            for r in combs(rest, k, l - 1):
                yield [list(c), *r]

parts = list(combs(range(n_est), n_src, n_mix))
print('Mixit')
for partition in parts:
    print(partition)

# Mixit_gen: all two-way splits, with any number of sources per mixture
def all_combinations(lst):
    all_combinations = []
    for k in range(len(lst) + 1):
        for c in combinations(lst, k):
            rest = [x for x in lst if x not in c]
            all_combinations.append([list(c), rest])
    return all_combinations

parts = all_combinations(range(n_est))
print('Mixit_gen')
for partition in parts:
    print(partition)
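For reference, this is the snippet's output: Mixit yields the 6 equal-sized partitions, while Mixit_gen yields all 2^4 = 16 two-way splits, including the unequal and empty ones:

Mixit
[[0, 1], [2, 3]]
[[0, 2], [1, 3]]
[[0, 3], [1, 2]]
[[1, 2], [0, 3]]
[[1, 3], [0, 2]]
[[2, 3], [0, 1]]
Mixit_gen
[[], [0, 1, 2, 3]]
[[0], [1, 2, 3]]
[[1], [0, 2, 3]]
...
[[0, 1, 2, 3], []]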
Thanks for the example.
I meant, the code under it.
Yes, the mixing matrices/partitions are different, but the loop seems quite similar, right?
Yes, then the loop is the same!
Also, the number of sources and the number of mixtures are pretty much always known beforehand, so maybe we can cache the mixing matrix at the beginning.
Yes, caching the mixing matrix will be useful!
Let me know if you have any questions about what's left to do, or if you need help with any of these tasks; we'll be happy to help.
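Since the number of estimates and mixtures is known when the wrapper is built, the caching mentioned above could look roughly like the sketch below; the class name and attribute are mine, not from the PR, and the generator is the same one as in the example above.

from itertools import combinations

def combs(lst, k, l):
    # Recursive partition generator, identical to the example above.
    if l == 0:
        yield []
    else:
        for c in combinations(lst, k):
            rest = [x for x in lst if x not in c]
            for r in combs(rest, k, l - 1):
                yield [list(c), *r]

class CachedMixITPartitions:
    # Hypothetical helper: the shapes are fixed at construction time,
    # so the partitions are generated once and reused at every training step.
    def __init__(self, n_est, n_src, n_mix=2):
        self.partitions = list(combs(list(range(n_est)), n_src, n_mix))

# e.g. 4 estimated sources, 2 per mixture, 2 mixtures -> 6 cached partitions
print(len(CachedMixITPartitions(4, 2, 2).partitions))  # 6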
mixit related tests removed
remove mixit-related stuff
Now there is a new class, MixITLossWrapper.
Cool, thanks!
Hey @giorgiacantisani, I made some changes to factor out the loop and fix the tests so that they pass. Looking at the partitions that get generated:

[[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
[[0, 1, 2], [3, 4, 6], [5, 7, 8]],
[[0, 1, 2], [3, 4, 7], [5, 6, 8]],
[[0, 1, 2], [3, 4, 8], [5, 6, 7]],
[[0, 1, 2], [3, 5, 6], [4, 7, 8]],
[[0, 1, 2], [3, 5, 7], [4, 6, 8]],
[[0, 1, 2], [3, 5, 8], [4, 6, 7]],
[[0, 1, 2], [3, 6, 7], [4, 5, 8]],
[[0, 1, 2], [3, 6, 8], [4, 5, 7]],
[[0, 1, 2], [3, 7, 8], [4, 5, 6]],
[[0, 1, 2], [4, 5, 6], [3, 7, 8]],
[[0, 1, 2], [4, 5, 7], [3, 6, 8]],
[[0, 1, 2], [4, 5, 8], [3, 6, 7]],
[[0, 1, 2], [4, 6, 7], [3, 5, 8]],
[[0, 1, 2], [4, 6, 8], [3, 5, 7]],
[[0, 1, 2], [4, 7, 8], [3, 5, 6]],
[[0, 1, 2], [5, 6, 7], [3, 4, 8]],
[[0, 1, 2], [5, 6, 8], [3, 4, 7]],
[[0, 1, 2], [5, 7, 8], [3, 4, 6]],
[[0, 1, 2], [6, 7, 8], [3, 4, 5]],
[[0, 1, 3], [2, 4, 5], [6, 7, 8]],
...

So we compute the sum of ests (0, 1, 2), and the loss compared to mix[:, 0], 20 times, right?

@giorgiacantisani @popcornell: I'm waiting for your approval of how it looks now and I'll merge.
Can you wait till I test it on WSJ0-2mix (Saturday)?
Ok, sure.
Maybe @etzinis would like to review actually? ^^
Yes, that's correct. I didn't even try to optimize it because normally the number of sources and mixes isn't too high...
But I agree that it would be nice to optimize it! For the rest, all the new commits look good to me :)
Seems good to me btw. However, I do not know why there is all this code complexity going around for generating the combinations, which can also be done in a single line:

for comb_ind in itertools.product(range(num_mixtures), repeat=num_estimated_sources)

E.g. for 2 mixtures and 4 estimated sources:

Out[2]: [(0, 0, 0, 0), ...
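For completeness, a runnable version of that one-liner for the 2-mixture / 4-source case (the variable names follow the comment above):

import itertools

num_mixtures, num_estimated_sources = 2, 4
assignments = list(itertools.product(range(num_mixtures), repeat=num_estimated_sources))
print(len(assignments))  # 16 == num_mixtures ** num_estimated_sources
print(assignments[:3])   # [(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 1, 0)]
# Each tuple assigns every estimated source to one of the mixtures.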
Building on top of my previous comment: each estimated source can only correspond to 1 mixture. So first you generate all possible assignments, compute the loss for each one of them, and finally take the min across all combinations.
Thanks for your feedback!
This, I understand. But I don't understand the role of the output you showed in the previous snippet...
Ok, it's the mixture assignment of the source ^^
I think I understand it. In my implementation I generate the combinations as @etzinis suggests and get the other row by using not (the mixing matrix is 2xM, where M is the max number of sources). Then I use it literally as a mixing matrix: I simply multiply it by the estimated sources. No fancy indexing, but lots of useless multiplications by zeros. Probably a masking approach gets the best of both worlds. Example:

# Snippet from the class; assumes itertools, numpy (np) and torch are imported and that
# preds (n_sources, time), mixtures (2, time), n_sources and self.loss_func are in scope.
combos = torch.Tensor(np.array(list(itertools.product([0, 1], repeat=n_sources)))).to(preds.device)
loss = None
assigned_perm = None
for i in range(len(combos)):
    # we evaluate the sisdr loss for this assignment
    c_loss = self.loss_func(torch.sum(combos[i].unsqueeze(-1) * preds, 0), mixtures[0])
    c_loss += self.loss_func(torch.sum((1. - combos[i]).unsqueeze(-1) * preds, 0), mixtures[1])
    c_loss = c_loss.mean()
    if loss is None or loss > c_loss:
        loss = c_loss
        assigned_perm = torch.stack((combos[i], 1. - combos[i]))
@mpariente Every time you need to obtain all possible assignments. So if your model outputs M sources, then you need to have all possible assignments of estimated sources to mixtures, namely num_mixtures^M. Btw generating the mixing matrices is exactly the same as generating the index assignments (I would say a bit faster since you need to do it only once in init). @popcornell has a nice implementation imho.
@etzinis thanks for the feedback, because the current way, without ...
So maybe we should switch to computing all the assignments all the time, in a more efficient manner.
@popcornell Cool example, this works for two mixtures, but does it generalize well to more? I currently don't have a lot of time to put into rewriting this, sadly.
I would suggest this: in most of the cases you are going to use 2 mixtures, so there is no need to implement it for more (you can always do that later if you need to). Assuming that each mixture has the same number of sources restricts the model a little bit in finding a good optimum, from my experience. However, if the number of sources is quite large, then it would not mean much (you just compute a few less combinations).
The current code can compute the loss for any number of mixtures, but with a fixed number of sources assigned to each.
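For reference, the product-based construction discussed above also extends to more than two mixtures; a minimal sketch (the function name and tensor layout are mine, not from this PR):

import itertools
import torch

def all_mixing_matrices(n_mix, n_est):
    # Enumerate every assignment of n_est estimated sources to n_mix mixtures
    # as 0/1 mixing matrices of shape (n_mix, n_est).
    mats = []
    for assignment in itertools.product(range(n_mix), repeat=n_est):
        mat = torch.zeros(n_mix, n_est)
        for est_idx, mix_idx in enumerate(assignment):
            mat[mix_idx, est_idx] = 1.0
        mats.append(mat)
    return torch.stack(mats)  # (n_mix ** n_est, n_mix, n_est)

# 2 mixtures, 4 estimates -> 16 candidate mixing matrices, matching the comments above
print(all_mixing_matrices(2, 4).shape)  # torch.Size([16, 2, 4])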
I think what @giorgiacantisani already implemented is really good and can be merged (once tested on a dataset, which I'll do).
I agree that this is very good! Maybe we should make the generalized version the default?
In my opinion yes, the current implementation is also speedy enough with generalized=True.
Thanks to @giorgiacantisani and everyone involved! I'll wait for the tests and merge it. I made the generalized version the default.
@popcornell Can you open the PR for the MixIT recipe? 😃
you're welcome :) thank you for reviewing my code!
It's perfectly fine for me.
Hi!
I implemented a mixture invariant training loss wrapper inspired by the paper "Unsupervised sound separation using mixtures of mixtures." arXiv preprint arXiv:2006.12701 (2020).
Specifically, there are two new options for PITLossWrapper:
- mix_it: Find the best partition of the estimated sources that gives the minimum loss for the MixIT training paradigm. Valid for any number of mixtures as long as they contain the same number of sources.
- mix_it_gen: Find the best partition of the estimated sources that gives the minimum loss for the MixIT training paradigm. Valid only for two mixtures, but those mixtures do not necessarily have to contain the same number of sources. The case where one of the mixtures is silent is also allowed.
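To make the intended call pattern concrete, here is a hypothetical usage sketch based on the MixITLossWrapper class and the generalized flag discussed above; the choice of multisrc_neg_sisdr as the wrapped loss and the exact tensor shapes are assumptions on my part, not confirmed by this PR:

import torch
from asteroid.losses import MixITLossWrapper, multisrc_neg_sisdr

# generalized=True corresponds to the mix_it_gen behaviour described above (assumption).
loss_func = MixITLossWrapper(multisrc_neg_sisdr, generalized=True)

est_sources = torch.randn(8, 4, 16000)  # (batch, n_est, time): 4 estimated sources
mixtures = torch.randn(8, 2, 16000)     # (batch, n_mix, time): 2 mixtures of mixtures
loss = loss_func(est_sources, mixtures)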