
Mixture Invariant Training #320

Merged
merged 20 commits into asteroid-team:master on Nov 16, 2020

Conversation

giorgiacantisani
Contributor

Hi!

I implemented a mixture invariant training (MixIT) loss wrapper inspired by the paper "Unsupervised Sound Separation Using Mixtures of Mixtures" (arXiv:2006.12701, 2020).

Specifically, there are two new options for PITLossWrapper:

  1. mix_it: Find the best partition of the estimated sources that gives the minimum loss for the MixIT training paradigm. Valid for any number of mixtures, as long as they all contain the same number of sources.
  2. mix_it_gen: Find the best partition of the estimated sources that gives the minimum loss for the MixIT training paradigm. Valid only for two mixtures, but they do not necessarily have to contain the same number of sources; the case where one mixture is silent is also allowed (see the usage sketch below).
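
A minimal usage sketch, assuming the new modes are selected through PITLossWrapper's existing pit_from argument and that the wrapped loss takes a (summed estimate, mixture) pair and returns a per-batch loss; the shapes and the toy loss function below are illustrative only:

import torch
from asteroid.losses import PITLossWrapper

def neg_snr(est, target, eps=1e-8):
    # Toy loss between a summed estimate and a reference mixture (per batch item).
    noise = target - est
    return -10 * torch.log10((target ** 2).sum(-1) / ((noise ** 2).sum(-1) + eps))

# Hypothetical wiring of the new option; "mix_it_gen" would be used the same way.
mixit_loss = PITLossWrapper(neg_snr, pit_from="mix_it")

est_sources = torch.randn(8, 4, 16000)  # (batch, n_est, time)
mixtures = torch.randn(8, 2, 16000)     # (batch, n_mix, time)
loss = mixit_loss(est_sources, mixtures)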

added mix_it and mix_it_gen loss wrappers
add test for mix_it and mix_it_gen
@popcornell
Collaborator

Thank you very much for your pull request !

I also have a very naive implementation of MixIT on my local machine, but you beat me to it ;)! I think we can merge the two together (especially for the mixing matrix generation part).
Did you also have time to test it on some dataset? Do you mind if I push some examples using MixIT here?

Also, @mpariente, how do you feel about having MixIT as an extension of the current PITLossWrapper? Theoretically it is sound, but it could be confusing for some users.

@giorgiacantisani
Contributor Author

Yes, sure, we can definitely merge them! I didn't focus too much on optimizing it, and maybe you have a faster implementation.

I tested it with the MUSDB18 dataset and it seems to work!

@mpariente
Collaborator

Thank you very much for the implementation, that's great !

PIT and MixIT are essentially different training paradigms (though related), so I also think making a separate class would be a good idea.

It seems that the first commit in this PR reverts recent changes made to PITLossWrapper (black compliance, doc rendering and others). We want to keep those changes, not revert them, plus it makes it more difficult to review the proposed changes.
Could you start again from master and only apply changes related to MixIT please?
(No need to make a separate class for now, we'll converge on that afterwards)

Note to self: we are starting to have several extensions of PITLossWrapper; start thinking about factoring out common logic/methods.

@giorgiacantisani
Contributor Author

Yes, I am sorry, I was working on an old version of the file and didn't realize it until the PR. I'll fix this.

updated to the last version of black compliance, doc rendering and others
@mpariente left a comment
Collaborator

This is much better now, thank you!

There are a few things to do before merging it:

  • Re-add the tests that were deleted,
  • Making a separate class (MixITLossWrapper?)
  • Checking with @popcornell for feedback/examples/common code etc..
  • Factoring out common code between both methods.

Also, it would be amazing to have a notebook to show the mixing matrices (partitions, as you call them?).

Would you be ok to do that?
It would be amazing !

tests/losses/pit_wrapper_test.py
asteroid/losses/pit_wrapper.py

# Generate all the possible partitions
parts = all_combinations(range(N))
for p, partition in enumerate(parts):
Collaborator

I guess this could be factored out between the two methods?

Contributor Author

Not really, because the two functions are different...

Contributor Author

If you run this little example, the difference will be clearer, I guess:

# Example
from itertools import combinations

n_mix = 2
n_est = 4
n_src = n_est // n_mix  # sources assigned to each mixture

# Mixit
def combs(lst, k, l):
    if l == 0:
        yield []
    else:
        for c in combinations(lst, k):
            rest = [x for x in lst if x not in c]
            for r in combs(rest, k, l - 1):
                yield [list(c), *r]

parts = list(combs(range(n_est), n_src, n_mix))
print('Mixit')
for partition in parts:
    print(partition)

# Mixit_gen
def all_combinations(lst):
    all_combinations = []
    for k in range(len(lst) + 1):
        for c in combinations(lst, k):
            rest = [x for x in lst if x not in c]
            all_combinations.append([list(c), rest])
    return all_combinations

parts = all_combinations(range(n_est))
print('Mixit_gen')
for partition in parts:
    print(partition)
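
Concretely, for these settings the two schemes produce different sets of partitions:

# Continuing the example above: count the partitions each scheme produces.
print(len(list(combs(range(n_est), n_src, n_mix))))  # 6 balanced partitions (mix_it)
print(len(all_combinations(range(n_est))))           # 16 splits, incl. unbalanced and empty groups (mix_it_gen)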

Collaborator

Thanks for the example.
I meant the code under it.
Yes, the mixing matrices/partitions are different, but the loop seems quite similar, right?

Contributor Author

Yes, then the loop is the same!

Collaborator

Also, the number of sources and the number of mixtures are pretty much always known beforehand, so maybe we can cache the mixing matrix at the beginning.

Collaborator

Yes, caching the mixing matrix will be useful!

asteroid/losses/pit_wrapper.py
@mpariente
Collaborator

Let me know if you have any questions on what's left to do, or if you need help with any of these tasks; we'll be happy to help.

@giorgiacantisani
Contributor Author

Now there is a new class MixITLossWrapper with its own tests!

@mpariente
Collaborator

Cool, thanks !
I'll have a look soon.

@mpariente
Collaborator

Hey @giorgiacantisani, I made some changes to factor out the loop and fix the tests so that they pass.
I checked and the partition generation is negligible compared to the actual loss computation, so I don't think we need to cache it.
OTOH, I think this could be optimized further.
For example, in the n_mix=3, n_src=3 case, there are 1680 partitions, and they look like this:

 [[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
 [[0, 1, 2], [3, 4, 6], [5, 7, 8]],
 [[0, 1, 2], [3, 4, 7], [5, 6, 8]],
 [[0, 1, 2], [3, 4, 8], [5, 6, 7]],
 [[0, 1, 2], [3, 5, 6], [4, 7, 8]],
 [[0, 1, 2], [3, 5, 7], [4, 6, 8]],
 [[0, 1, 2], [3, 5, 8], [4, 6, 7]],
 [[0, 1, 2], [3, 6, 7], [4, 5, 8]],
 [[0, 1, 2], [3, 6, 8], [4, 5, 7]],
 [[0, 1, 2], [3, 7, 8], [4, 5, 6]],
 [[0, 1, 2], [4, 5, 6], [3, 7, 8]],
 [[0, 1, 2], [4, 5, 7], [3, 6, 8]],
 [[0, 1, 2], [4, 5, 8], [3, 6, 7]],
 [[0, 1, 2], [4, 6, 7], [3, 5, 8]],
 [[0, 1, 2], [4, 6, 8], [3, 5, 7]],
 [[0, 1, 2], [4, 7, 8], [3, 5, 6]],
 [[0, 1, 2], [5, 6, 7], [3, 4, 8]],
 [[0, 1, 2], [5, 6, 8], [3, 4, 7]],
 [[0, 1, 2], [5, 7, 8], [3, 4, 6]],
 [[0, 1, 2], [6, 7, 8], [3, 4, 5]],
 [[0, 1, 3], [2, 4, 5], [6, 7, 8]],
...

So we compute the sum of ests (0, 1, 2), and the loss compared to mix[:, 0] 20 times, right?
I don't think I'll spend time optimizing it now but let's keep it in mind 😉

@giorgiacantisani @popcornell : I'm waiting for your approval of how it looks now and I'll merge
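
For the record, a minimal sketch of that memoization idea (compute each group-vs-mixture loss once and reuse it across partitions); the partition list and loss_func signature are placeholders, not the code in this PR:

import torch

def mixit_loss_memoized(est_sources, mixtures, partitions, loss_func):
    # est_sources: (batch, n_est, time), mixtures: (batch, n_mix, time).
    # partitions: e.g. [[[0, 1, 2], [3, 4, 5], [6, 7, 8]], ...]
    # loss_func(summed_estimate, mixture) is assumed to return a (batch,) loss.
    cache = {}

    def group_loss(group, mix_idx):
        key = (tuple(group), mix_idx)
        if key not in cache:
            summed = est_sources[:, list(group)].sum(dim=1)
            cache[key] = loss_func(summed, mixtures[:, mix_idx])
        return cache[key]

    # (n_partitions, batch) matrix of losses; each group sum is computed only once.
    part_losses = torch.stack(
        [sum(group_loss(g, j) for j, g in enumerate(p)) for p in partitions]
    )
    return part_losses.min(dim=0).values.mean()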

@popcornell
Collaborator

Can you wait till I test it on WSJ0-2mix (Saturday)?

@mpariente
Collaborator

Ok, sure.

@mpariente
Collaborator

Maybe @etzinis would like to review actually? ^^

@giorgiacantisani
Contributor Author

So we compute the sum of ests (0, 1, 2), and the loss compared to mix[:, 0] 20 times, right?
I don't think I'll spend time optimizing it now but let's keep it in mind 😉

Yes, that's correct. I didn't even try to optimize it because normally the number of sources and mixes isn't too high...

@giorgiacantisani
Contributor Author

But I agree that it would be nice to optimize it! For the rest, all the new commits look good to me :)

@etzinis
Contributor

etzinis commented Nov 12, 2020

Seems good to me btw. However, I do not know why there is all this code complexity for generating the combinations, which can also be done in a single line:

for comb_ind in itertools.product(range(num_mixtures), repeat=num_estimated_sources)

E.g. for 2 mixtures and 4 estimated sources
In [2]: [i for i in itertools.product(range(2), repeat=4)]
Out[2]:
[(0, 0, 0, 0),
 (0, 0, 0, 1),
 (0, 0, 1, 0),
 (0, 0, 1, 1),
 (0, 1, 0, 0),
 (0, 1, 0, 1),
 (0, 1, 1, 0),
 (0, 1, 1, 1),
 (1, 0, 0, 0),
 (1, 0, 0, 1),
 (1, 0, 1, 0),
 (1, 0, 1, 1),
 (1, 1, 0, 0),
 (1, 1, 0, 1),
 (1, 1, 1, 0),
 (1, 1, 1, 1)]

@etzinis
Contributor

etzinis commented Nov 12, 2020

Building on top of my previous comment: Each estimated source can only correspond to 1 mixture. So first you generate all possible assignments, compute the loss in each one of them and finally compute the min across all combinations.

@mpariente
Collaborator

Thanks for your feedback!

Each estimated source can only correspond to 1 mixture. So first you generate all possible assignments, compute the loss in each one of them and finally compute the min across all combinations.

This, I understand. But I don't understand the role of the output you showed in the previous snippet.
Can you please elaborate?

@mpariente
Collaborator

Ok, it's the mixture assignment of the source ^^
I didn't spend much time on this, so I'm not sure: do you usually consider all possible assignments? Or do you assume that mixtures have the same number of sources?
What is the max number of sources and mixtures you tried MixIT with?

@popcornell
Collaborator

popcornell commented Nov 12, 2020

This, I understand. But I don't understand the role of the output you showed in the previous snippet.
Can you please elaborate?

I think I understand it. In my implementation, I generate the combinations as @etzinis suggests and get the other row by negation (the mixing matrix is 2 x M, where M is the max number of sources). Then I use it literally as a mixing matrix: I simply multiply it by the estimated sources. No fancy indexing, but lots of useless multiplications by zero. Probably a masking approach gets the best of both worlds.

Example:

# Fragment from inside a loss module's forward(); assumes the usual
# `import itertools`, `import numpy as np`, `import torch` at module level.
# `preds` are the n_sources estimated sources, `mixtures` the two input mixtures.
combos = torch.Tensor(np.array(list(itertools.product([0, 1], repeat=n_sources)))).to(preds.device)
loss = None
assigned_perm = None
for i in range(len(combos)):
    # Evaluate the SI-SDR loss for this assignment: row i sums one subset of
    # sources against mixture 0, its complement against mixture 1.
    c_loss = self.loss_func(torch.sum(combos[i].unsqueeze(-1) * preds, 0), mixtures[0])
    c_loss += self.loss_func(torch.sum((1.0 - combos[i]).unsqueeze(-1) * preds, 0), mixtures[1])
    c_loss = c_loss.mean()
    if loss is None or loss > c_loss:
        loss = c_loss
        assigned_perm = torch.stack((combos[i], 1.0 - combos[i]))

@etzinis
Contributor

etzinis commented Nov 12, 2020

@mpariente Every time, you need to obtain all possible assignments. So if your model outputs M sources, then you need all possible assignments of estimated sources to mixtures, namely num_mixtures^M. Btw, generating the mixing matrices is exactly the same as generating the index assignments (I would say a bit faster, since you need to do it only once in init). @popcornell has a nice implementation imho.
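
A minimal sketch of doing that generation once at construction time, along the lines of the snippet above (two mixtures only; the class name and wiring are illustrative, not the PR's API):

import itertools

import torch

class TwoMixMixITLoss(torch.nn.Module):
    # Illustrative only: precompute all 2 x n_est mixing matrices in __init__.
    def __init__(self, loss_func, n_est):
        super().__init__()
        self.loss_func = loss_func  # assumed to return a (batch,) loss
        rows = torch.tensor(list(itertools.product([0.0, 1.0], repeat=n_est)))
        # Each assignment is a 2 x n_est mixing matrix: a row and its complement.
        self.register_buffer("mix_mats", torch.stack([rows, 1.0 - rows], dim=1))

    def forward(self, preds, mixtures):
        # preds: (batch, n_est, time), mixtures: (batch, 2, time)
        losses = []
        for mat in self.mix_mats:  # (2, n_est)
            ests = torch.einsum("mn,bnt->bmt", mat, preds)  # summed estimates per mixture
            losses.append(
                self.loss_func(ests[:, 0], mixtures[:, 0])
                + self.loss_func(ests[:, 1], mixtures[:, 1])
            )
        return torch.stack(losses).min(dim=0).values.mean()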

@mpariente
Collaborator

@etzinis thanks for the feedback, because the current way, without generalized=True, doesn't consider all assignments, only the ones that assign an equal number of sources to each mixture (if I understood correctly).

So maybe we should switch to computing all the assignments all the time, in a more efficient manner.

@mpariente
Collaborator

@popcornell Cool example; this works for two mixtures, but does it generalize well to more?

I currently don't have a lot of time to put into rewriting this, sadly.

@etzinis
Contributor

etzinis commented Nov 12, 2020

I would suggest this: in most cases you are going to use 2 mixtures, so there is no need to implement it for more (you can always do that later if you need to). Assuming that each mixture has the same number of sources restricts the model a little bit in finding a good optimum, in my experience. However, if the number of sources is quite large, then it does not matter much (you just compute somewhat fewer combinations).

@mpariente
Collaborator

The current code can compute the loss for any number of mixtures, but with a fixed number of sources assigned to each.
I'm not sure that I'd like to limit users to two mixtures only, especially if the general case is already implemented.

@popcornell
Collaborator

popcornell commented Nov 12, 2020

I think what @giorgiacantisani already implemented is really good and can be merged (once tested on a dataset, which I'll do).
Let's keep these observations for optimizing the loss in a future second pull request.

@mpariente
Collaborator

I agree that this is very good!
I checked again and the current code approaches the problem differently: it produces the indices the other way around, from the mixture to the sources.

Maybe we should make generalized the default (where possible), WDYT?

@popcornell
Collaborator

popcornell commented Nov 14, 2020

In my opinion yes; the current implementation is also speedy enough with generalized=True.
I've tested it and it works, so I think we can merge.
I have a recipe ready; I can merge it afterwards (it is here: https://github.com/mpariente/asteroid/tree/mixit_recipe)

@mpariente
Collaborator

Thanks to @giorgiacantisani and everyone involved! I'll wait for the tests and merge it.

I made generalized the default because this seems to be the intended behavior from the original authors; hope you're OK with this, @giorgiacantisani.

@mpariente merged commit aedb22f into asteroid-team:master on Nov 16, 2020
@mpariente
Collaborator

@popcornell Can you open the PR for the MixIT recipe? 😃

@giorgiacantisani
Contributor Author

Thanks to @giorgiacantisani and everyone involved! I'll wait for the tests and merge it.

You're welcome :) Thank you for reviewing my code!

I made generalized the default because this seems to be the intended behavior from the original authors; hope you're OK with this, @giorgiacantisani.

It's perfectly fine with me; generalized was not the default only because I added it later, not for any specific reason!

@popcornell mentioned this pull request on Jul 26, 2021